mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support ltx2.3 inference (#1332)
This commit is contained in:
@@ -718,6 +718,66 @@ ltx2_series = [
|
||||
"model_name": "ltx2_latent_upsampler",
|
||||
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
|
||||
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
|
||||
"model_name": "ltx2_dit",
|
||||
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
|
||||
"extra_kwargs": {"apply_gated_attention": True, "cross_attention_adaln": True, "caption_channels": None},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
|
||||
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
|
||||
"model_name": "ltx2_video_vae_encoder",
|
||||
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
|
||||
"extra_kwargs": {"encoder_version": "ltx-2.3"},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
|
||||
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
|
||||
"model_name": "ltx2_video_vae_decoder",
|
||||
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
|
||||
"extra_kwargs": {"decoder_version": "ltx-2.3"},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
|
||||
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
|
||||
"model_name": "ltx2_audio_vae_decoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
|
||||
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
|
||||
"model_name": "ltx2_audio_vocoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2VocoderWithBWE",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
|
||||
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
|
||||
"model_name": "ltx2_audio_vae_encoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
|
||||
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
|
||||
"model_name": "ltx2_text_encoder_post_modules",
|
||||
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
|
||||
"extra_kwargs": {"seperated_audio_video": True, "embedding_dim_gemma": 3840, "num_layers_gemma": 49, "video_attetion_heads": 32, "video_attention_head_dim": 128, "audio_attention_heads": 32, "audio_attention_head_dim": 64, "num_connetor_layers": 8, "apply_gated_attention": True},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
|
||||
"model_hash": "aed408774d694a2452f69936c32febb5",
|
||||
"model_name": "ltx2_latent_upsampler",
|
||||
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
|
||||
"extra_kwargs": {"rational_resampler": False},
|
||||
},
|
||||
]
|
||||
anima_series = [
|
||||
{
|
||||
|
||||
@@ -1279,9 +1279,268 @@ class LTX2AudioDecoder(torch.nn.Module):
|
||||
return torch.tanh(h) if self.tanh_out else h
|
||||
|
||||
|
||||
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
|
||||
# Adopted from https://github.com/NVIDIA/BigVGAN
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sinc(x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.where(
|
||||
x == 0,
|
||||
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
||||
torch.sin(math.pi * x) / math.pi / x,
|
||||
)
|
||||
|
||||
|
||||
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
|
||||
even = kernel_size % 2 == 0
|
||||
half_size = kernel_size // 2
|
||||
delta_f = 4 * half_width
|
||||
amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if amplitude > 50.0:
|
||||
beta = 0.1102 * (amplitude - 8.7)
|
||||
elif amplitude >= 21.0:
|
||||
beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
|
||||
else:
|
||||
beta = 0.0
|
||||
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||
time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size
|
||||
if cutoff == 0:
|
||||
filter_ = torch.zeros_like(time)
|
||||
else:
|
||||
filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
|
||||
filter_ /= filter_.sum()
|
||||
return filter_.view(1, 1, kernel_size)
|
||||
|
||||
|
||||
class LowPassFilter1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cutoff: float = 0.5,
|
||||
half_width: float = 0.6,
|
||||
stride: int = 1,
|
||||
padding: bool = True,
|
||||
padding_mode: str = "replicate",
|
||||
kernel_size: int = 12,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if cutoff < -0.0:
|
||||
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||
if cutoff > 0.5:
|
||||
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||
self.kernel_size = kernel_size
|
||||
self.even = kernel_size % 2 == 0
|
||||
self.pad_left = kernel_size // 2 - int(self.even)
|
||||
self.pad_right = kernel_size // 2
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.padding_mode = padding_mode
|
||||
self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, n_channels, _ = x.shape
|
||||
if self.padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||
return F.conv1d(x, self.filter.expand(n_channels, -1, -1), stride=self.stride, groups=n_channels)
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ratio: int = 2,
|
||||
kernel_size: int | None = None,
|
||||
persistent: bool = True,
|
||||
window_type: str = "kaiser",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.stride = ratio
|
||||
|
||||
if window_type == "hann":
|
||||
# Hann-windowed sinc filter equivalent to torchaudio.functional.resample
|
||||
rolloff = 0.99
|
||||
lowpass_filter_width = 6
|
||||
width = math.ceil(lowpass_filter_width / rolloff)
|
||||
self.kernel_size = 2 * width * ratio + 1
|
||||
self.pad = width
|
||||
self.pad_left = 2 * width * ratio
|
||||
self.pad_right = self.kernel_size - ratio
|
||||
time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
|
||||
time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
|
||||
window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
|
||||
sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
|
||||
else:
|
||||
# Kaiser-windowed sinc filter (BigVGAN default).
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
sinc_filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
kernel_size=self.kernel_size,
|
||||
)
|
||||
|
||||
self.register_buffer("filter", sinc_filter, persistent=persistent)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, n_channels, _ = x.shape
|
||||
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
||||
filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1)
|
||||
x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels)
|
||||
return x[..., self.pad_left : -self.pad_right]
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None:
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.lowpass = LowPassFilter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
stride=ratio,
|
||||
kernel_size=self.kernel_size,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.lowpass(x)
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
activation: nn.Module,
|
||||
up_ratio: int = 2,
|
||||
down_ratio: int = 2,
|
||||
up_kernel_size: int = 12,
|
||||
down_kernel_size: int = 12,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.act = activation
|
||||
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.upsample(x)
|
||||
x = self.act(x)
|
||||
return self.downsample(x)
|
||||
|
||||
|
||||
class Snake(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
alpha: float = 1.0,
|
||||
alpha_trainable: bool = True,
|
||||
alpha_logscale: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.eps = 1e-9
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
return x + (1.0 / (alpha + self.eps)) * torch.sin(x * alpha).pow(2)
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
alpha: float = 1.0,
|
||||
alpha_trainable: bool = True,
|
||||
alpha_logscale: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.beta = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
|
||||
self.beta.requires_grad = alpha_trainable
|
||||
self.eps = 1e-9
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
|
||||
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
beta = torch.exp(beta)
|
||||
return x + (1.0 / (beta + self.eps)) * torch.sin(x * alpha).pow(2)
|
||||
|
||||
|
||||
class AMPBlock1(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int = 3,
|
||||
dilation: tuple[int, int, int] = (1, 3, 5),
|
||||
activation: str = "snake",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
act_cls = SnakeBeta if activation == "snakebeta" else Snake
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
),
|
||||
nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
),
|
||||
nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
|
||||
nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
|
||||
nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
|
||||
]
|
||||
)
|
||||
|
||||
self.acts1 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs1))])
|
||||
self.acts2 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs2))])
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True):
|
||||
xt = a1(x)
|
||||
xt = c1(xt)
|
||||
xt = a2(xt)
|
||||
xt = c2(xt)
|
||||
x = x + xt
|
||||
return x
|
||||
|
||||
|
||||
class LTX2Vocoder(torch.nn.Module):
|
||||
"""
|
||||
Vocoder model for synthesizing audio from Mel spectrograms.
|
||||
LTX2Vocoder model for synthesizing audio from Mel spectrograms.
|
||||
Args:
|
||||
resblock_kernel_sizes: List of kernel sizes for the residual blocks.
|
||||
This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`.
|
||||
@@ -1293,28 +1552,33 @@ class LTX2Vocoder(torch.nn.Module):
|
||||
This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`.
|
||||
upsample_initial_channel: Initial number of channels for the upsampling layers.
|
||||
This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`.
|
||||
stereo: Whether to use stereo output.
|
||||
This value is read from the checkpoint at `config.vocoder.stereo`.
|
||||
resblock: Type of residual block to use.
|
||||
resblock: Type of residual block to use ("1", "2", or "AMP1").
|
||||
This value is read from the checkpoint at `config.vocoder.resblock`.
|
||||
output_sample_rate: Waveform sample rate.
|
||||
This value is read from the checkpoint at `config.vocoder.output_sample_rate`.
|
||||
output_sampling_rate: Waveform sample rate.
|
||||
This value is read from the checkpoint at `config.vocoder.output_sampling_rate`.
|
||||
activation: Activation type for BigVGAN v2 ("snake" or "snakebeta"). Only used when resblock="AMP1".
|
||||
use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True).
|
||||
apply_final_activation: Whether to apply the final tanh/clamp activation.
|
||||
use_bias_at_final: Whether to use bias in the final conv layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
def __init__( # noqa: PLR0913
|
||||
self,
|
||||
resblock_kernel_sizes: List[int] | None = [3, 7, 11],
|
||||
upsample_rates: List[int] | None = [6, 5, 2, 2, 2],
|
||||
upsample_kernel_sizes: List[int] | None = [16, 15, 8, 4, 4],
|
||||
resblock_dilation_sizes: List[List[int]] | None = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
upsample_initial_channel: int = 1024,
|
||||
stereo: bool = True,
|
||||
resblock: str = "1",
|
||||
output_sample_rate: int = 24000,
|
||||
):
|
||||
output_sampling_rate: int = 24000,
|
||||
activation: str = "snake",
|
||||
use_tanh_at_final: bool = True,
|
||||
apply_final_activation: bool = True,
|
||||
use_bias_at_final: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Initialize default values if not provided. Note that mutable default values are not supported.
|
||||
# Mutable default values are not supported as default arguments.
|
||||
if resblock_kernel_sizes is None:
|
||||
resblock_kernel_sizes = [3, 7, 11]
|
||||
if upsample_rates is None:
|
||||
@@ -1324,36 +1588,60 @@ class LTX2Vocoder(torch.nn.Module):
|
||||
if resblock_dilation_sizes is None:
|
||||
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.output_sampling_rate = output_sampling_rate
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
in_channels = 128 if stereo else 64
|
||||
self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
||||
self.use_tanh_at_final = use_tanh_at_final
|
||||
self.apply_final_activation = apply_final_activation
|
||||
self.is_amp = resblock == "AMP1"
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True)):
|
||||
self.ups.append(
|
||||
nn.ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=(kernel_size - stride) // 2,
|
||||
)
|
||||
# All production checkpoints are stereo: 128 input channels (2 stereo channels x 64 mel
|
||||
# bins each), 2 output channels.
|
||||
self.conv_pre = nn.Conv1d(
|
||||
in_channels=128,
|
||||
out_channels=upsample_initial_channel,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=3,
|
||||
)
|
||||
resblock_cls = ResBlock1 if resblock == "1" else AMPBlock1
|
||||
|
||||
self.ups = nn.ModuleList(
|
||||
nn.ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=(kernel_size - stride) // 2,
|
||||
)
|
||||
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True))
|
||||
)
|
||||
|
||||
final_channels = upsample_initial_channel // (2 ** len(upsample_rates))
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i, _ in enumerate(self.ups):
|
||||
|
||||
for i in range(len(upsample_rates)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):
|
||||
self.resblocks.append(resblock_class(ch, kernel_size, dilations))
|
||||
if self.is_amp:
|
||||
self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation))
|
||||
else:
|
||||
self.resblocks.append(resblock_cls(ch, kernel_size, dilations))
|
||||
|
||||
out_channels = 2 if stereo else 1
|
||||
final_channels = upsample_initial_channel // (2**self.num_upsamples)
|
||||
self.conv_post = nn.Conv1d(final_channels, out_channels, 7, 1, padding=3)
|
||||
if self.is_amp:
|
||||
self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels))
|
||||
else:
|
||||
self.act_post = nn.LeakyReLU()
|
||||
|
||||
self.upsample_factor = math.prod(layer.stride[0] for layer in self.ups)
|
||||
# All production checkpoints are stereo: this final conv maps `final_channels` to 2 output channels (stereo).
|
||||
self.conv_post = nn.Conv1d(
|
||||
in_channels=final_channels,
|
||||
out_channels=2,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=3,
|
||||
bias=use_bias_at_final,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
@@ -1374,7 +1662,8 @@ class LTX2Vocoder(torch.nn.Module):
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
if not self.is_amp:
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
start = i * self.num_kernels
|
||||
end = start + self.num_kernels
|
||||
@@ -1386,23 +1675,198 @@ class LTX2Vocoder(torch.nn.Module):
|
||||
[self.resblocks[idx](x) for idx in range(start, end)],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
x = block_outputs.mean(dim=0)
|
||||
|
||||
x = self.conv_post(F.leaky_relu(x))
|
||||
return torch.tanh(x)
|
||||
x = self.act_post(x)
|
||||
x = self.conv_post(x)
|
||||
|
||||
if self.apply_final_activation:
|
||||
x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def decode_audio(latent: torch.Tensor, audio_decoder: "LTX2AudioDecoder", vocoder: "LTX2Vocoder") -> torch.Tensor:
|
||||
class _STFTFn(nn.Module):
|
||||
"""Implements STFT as a convolution with precomputed DFT x Hann-window bases.
|
||||
The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
|
||||
Hann window are stored as buffers and loaded from the checkpoint. Using the exact
|
||||
bfloat16 bases from training ensures the mel values fed to the BWE generator are
|
||||
bit-identical to what it was trained on.
|
||||
"""
|
||||
Decode an audio latent representation using the provided audio decoder and vocoder.
|
||||
Args:
|
||||
latent: Input audio latent tensor.
|
||||
audio_decoder: Model to decode the latent to waveform features.
|
||||
vocoder: Model to convert decoded features to audio waveform.
|
||||
Returns:
|
||||
Decoded audio as a float tensor.
|
||||
|
||||
def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
|
||||
super().__init__()
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
n_freqs = filter_length // 2 + 1
|
||||
self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
||||
self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
||||
|
||||
def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute magnitude and phase spectrogram from a batch of waveforms.
|
||||
Applies causal (left-only) padding of win_length - hop_length samples so that
|
||||
each output frame depends only on past and present input — no lookahead.
|
||||
Args:
|
||||
y: Waveform tensor of shape (B, T).
|
||||
Returns:
|
||||
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
|
||||
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
|
||||
"""
|
||||
if y.dim() == 2:
|
||||
y = y.unsqueeze(1) # (B, 1, T)
|
||||
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
|
||||
y = F.pad(y, (left_pad, 0))
|
||||
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
|
||||
n_freqs = spec.shape[1] // 2
|
||||
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
|
||||
magnitude = torch.sqrt(real**2 + imag**2)
|
||||
phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
|
||||
return magnitude, phase
|
||||
|
||||
|
||||
class MelSTFT(nn.Module):
|
||||
"""Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
|
||||
Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
|
||||
waveform and projecting the linear magnitude spectrum onto the mel filterbank.
|
||||
The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
|
||||
(mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
|
||||
"""
|
||||
decoded_audio = audio_decoder(latent)
|
||||
decoded_audio = vocoder(decoded_audio).squeeze(0).float()
|
||||
return decoded_audio
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter_length: int,
|
||||
hop_length: int,
|
||||
win_length: int,
|
||||
n_mel_channels: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
|
||||
|
||||
# Initialized to zeros; load_state_dict overwrites with the checkpoint's
|
||||
# exact bfloat16 filterbank (vocoder.mel_stft.mel_basis, shape [n_mels, n_freqs]).
|
||||
n_freqs = filter_length // 2 + 1
|
||||
self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
|
||||
|
||||
def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Compute log-mel spectrogram and auxiliary spectral quantities.
|
||||
Args:
|
||||
y: Waveform tensor of shape (B, T).
|
||||
Returns:
|
||||
log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
|
||||
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
|
||||
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
|
||||
energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
|
||||
"""
|
||||
magnitude, phase = self.stft_fn(y)
|
||||
energy = torch.norm(magnitude, dim=1)
|
||||
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
|
||||
log_mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
return log_mel, magnitude, phase, energy
|
||||
|
||||
|
||||
class LTX2VocoderWithBWE(nn.Module):
|
||||
"""LTX2Vocoder with bandwidth extension (BWE) upsampling.
|
||||
Chains a mel-to-wav vocoder with a BWE module that upsamples the output
|
||||
to a higher sample rate. The BWE computes a mel spectrogram from the
|
||||
vocoder output, runs it through a second generator to predict a residual,
|
||||
and adds it to a sinc-resampled skip connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_sampling_rate: int = 16000,
|
||||
output_sampling_rate: int = 48000,
|
||||
hop_length: int = 80,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.vocoder = LTX2Vocoder(
|
||||
resblock_kernel_sizes=[3, 7, 11],
|
||||
upsample_rates=[5, 2, 2, 2, 2, 2],
|
||||
upsample_kernel_sizes=[11, 4, 4, 4, 4, 4],
|
||||
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
upsample_initial_channel=1536,
|
||||
resblock="AMP1",
|
||||
activation="snakebeta",
|
||||
use_tanh_at_final=False,
|
||||
apply_final_activation=True,
|
||||
use_bias_at_final=False,
|
||||
output_sampling_rate=input_sampling_rate,
|
||||
)
|
||||
self.bwe_generator = LTX2Vocoder(
|
||||
resblock_kernel_sizes=[3, 7, 11],
|
||||
upsample_rates=[6, 5, 2, 2, 2],
|
||||
upsample_kernel_sizes=[12, 11, 4, 4, 4],
|
||||
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
upsample_initial_channel=512,
|
||||
resblock="AMP1",
|
||||
activation="snakebeta",
|
||||
use_tanh_at_final=False,
|
||||
apply_final_activation=False,
|
||||
use_bias_at_final=False,
|
||||
output_sampling_rate=output_sampling_rate,
|
||||
)
|
||||
|
||||
self.mel_stft = MelSTFT(
|
||||
filter_length=512,
|
||||
hop_length=hop_length,
|
||||
win_length=512,
|
||||
n_mel_channels=64,
|
||||
)
|
||||
self.input_sampling_rate = input_sampling_rate
|
||||
self.output_sampling_rate = output_sampling_rate
|
||||
self.hop_length = hop_length
|
||||
# Compute the resampler on CPU so the sinc filter is materialized even when
|
||||
# the model is constructed on meta device (SingleGPUModelBuilder pattern).
|
||||
# The filter is not stored in the checkpoint (persistent=False).
|
||||
with torch.device("cpu"):
|
||||
self.resampler = UpSample1d(
|
||||
ratio=output_sampling_rate // input_sampling_rate, persistent=False, window_type="hann"
|
||||
)
|
||||
|
||||
@property
|
||||
def conv_pre(self) -> nn.Conv1d:
|
||||
return self.vocoder.conv_pre
|
||||
|
||||
@property
|
||||
def conv_post(self) -> nn.Conv1d:
|
||||
return self.vocoder.conv_post
|
||||
|
||||
def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute log-mel spectrogram from waveform using causal STFT bases.
|
||||
Args:
|
||||
audio: Waveform tensor of shape (B, C, T).
|
||||
Returns:
|
||||
mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames).
|
||||
"""
|
||||
batch, n_channels, _ = audio.shape
|
||||
flat = audio.reshape(batch * n_channels, -1) # (B*C, T)
|
||||
mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames)
|
||||
return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames)
|
||||
|
||||
def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
|
||||
"""Run the full vocoder + BWE forward pass.
|
||||
Args:
|
||||
mel_spec: Mel spectrogram of shape (B, 2, T, mel_bins) for stereo
|
||||
or (B, T, mel_bins) for mono. Same format as LTX2Vocoder.forward.
|
||||
Returns:
|
||||
Waveform tensor of shape (B, out_channels, T_out) clipped to [-1, 1].
|
||||
"""
|
||||
x = self.vocoder(mel_spec)
|
||||
_, _, length_low_rate = x.shape
|
||||
output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate
|
||||
|
||||
# Pad to multiple of hop_length for exact mel frame count
|
||||
remainder = length_low_rate % self.hop_length
|
||||
if remainder != 0:
|
||||
x = F.pad(x, (0, self.hop_length - remainder))
|
||||
|
||||
# Compute mel spectrogram from vocoder output: (B, C, n_mels, T_frames)
|
||||
mel = self._compute_mel(x)
|
||||
|
||||
# LTX2Vocoder.forward expects (B, C, T, mel_bins) — transpose before calling bwe_generator
|
||||
mel_for_bwe = mel.transpose(2, 3) # (B, C, T_frames, mel_bins)
|
||||
residual = self.bwe_generator(mel_for_bwe)
|
||||
skip = self.resampler(x)
|
||||
assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
|
||||
|
||||
return torch.clamp(residual + skip, -1, 1)[..., :output_length]
|
||||
|
||||
@@ -251,11 +251,27 @@ class Modality:
|
||||
Input data for a single modality (video or audio) in the transformer.
|
||||
Bundles the latent tokens, timestep embeddings, positional information,
|
||||
and text conditioning context for processing by the diffusion transformer.
|
||||
Attributes:
|
||||
latent: Patchified latent tokens, shape ``(B, T, D)`` where *B* is
|
||||
the batch size, *T* is the total number of tokens (noisy +
|
||||
conditioning), and *D* is the input dimension.
|
||||
timesteps: Per-token timestep embeddings, shape ``(B, T)``.
|
||||
positions: Positional coordinates, shape ``(B, 3, T)`` for video
|
||||
(time, height, width) or ``(B, 1, T)`` for audio.
|
||||
context: Text conditioning embeddings from the prompt encoder.
|
||||
enabled: Whether this modality is active in the current forward pass.
|
||||
context_mask: Optional mask for the text context tokens.
|
||||
attention_mask: Optional 2-D self-attention mask, shape ``(B, T, T)``.
|
||||
Values in ``[0, 1]`` where ``1`` = full attention and ``0`` = no
|
||||
attention. ``None`` means unrestricted (full) attention between
|
||||
all tokens. Built incrementally by conditioning items; see
|
||||
:class:`~ltx_core.conditioning.types.attention_strength_wrapper.ConditioningItemAttentionStrengthWrapper`.
|
||||
"""
|
||||
|
||||
latent: (
|
||||
torch.Tensor
|
||||
) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
|
||||
sigma: torch.Tensor # Shape: (B,). Current sigma value, used for cross-attention timestep calculation.
|
||||
timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps
|
||||
positions: (
|
||||
torch.Tensor
|
||||
@@ -263,6 +279,7 @@ class Modality:
|
||||
context: torch.Tensor
|
||||
enabled: bool = True
|
||||
context_mask: torch.Tensor | None = None
|
||||
attention_mask: torch.Tensor | None = None
|
||||
|
||||
|
||||
def to_denoised(
|
||||
|
||||
@@ -225,6 +225,17 @@ class BatchedPerturbationConfig:
|
||||
return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)])
|
||||
|
||||
|
||||
|
||||
ADALN_NUM_BASE_PARAMS = 6
|
||||
# Cross-attention AdaLN adds 3 more (scale, shift, gate) for the CA norm.
|
||||
ADALN_NUM_CROSS_ATTN_PARAMS = 3
|
||||
|
||||
|
||||
def adaln_embedding_coefficient(cross_attention_adaln: bool) -> int:
|
||||
"""Total number of AdaLN parameters per block."""
|
||||
return ADALN_NUM_BASE_PARAMS + (ADALN_NUM_CROSS_ATTN_PARAMS if cross_attention_adaln else 0)
|
||||
|
||||
|
||||
class AdaLayerNormSingle(torch.nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm single (adaLN-single).
|
||||
@@ -460,6 +471,7 @@ class Attention(torch.nn.Module):
|
||||
dim_head: int = 64,
|
||||
norm_eps: float = 1e-6,
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
apply_gated_attention: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.rope_type = rope_type
|
||||
@@ -477,6 +489,12 @@ class Attention(torch.nn.Module):
|
||||
self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True)
|
||||
self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True)
|
||||
|
||||
# Optional per-head gating
|
||||
if apply_gated_attention:
|
||||
self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
|
||||
else:
|
||||
self.to_gate_logits = None
|
||||
|
||||
self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity())
|
||||
|
||||
def forward(
|
||||
@@ -486,6 +504,8 @@ class Attention(torch.nn.Module):
|
||||
mask: torch.Tensor | None = None,
|
||||
pe: torch.Tensor | None = None,
|
||||
k_pe: torch.Tensor | None = None,
|
||||
perturbation_mask: torch.Tensor | None = None,
|
||||
all_perturbed: bool = False,
|
||||
) -> torch.Tensor:
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
@@ -517,6 +537,19 @@ class Attention(torch.nn.Module):
|
||||
|
||||
# Reshape back to original format
|
||||
out = out.flatten(2, 3)
|
||||
|
||||
# Apply per-head gating if enabled
|
||||
if self.to_gate_logits is not None:
|
||||
gate_logits = self.to_gate_logits(x) # (B, T, H)
|
||||
b, t, _ = out.shape
|
||||
# Reshape to (B, T, H, D) for per-head gating
|
||||
out = out.view(b, t, self.heads, self.dim_head)
|
||||
# Apply gating: 2 * sigmoid(x) so that zero-init gives identity (2 * 0.5 = 1.0)
|
||||
gates = 2.0 * torch.sigmoid(gate_logits) # (B, T, H)
|
||||
out = out * gates.unsqueeze(-1) # (B, T, H, D) * (B, T, H, 1)
|
||||
# Reshape back to (B, T, H*D)
|
||||
out = out.view(b, t, self.heads * self.dim_head)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
@@ -545,7 +578,6 @@ class PixArtAlphaTextProjection(torch.nn.Module):
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TransformerArgs:
|
||||
x: torch.Tensor
|
||||
@@ -558,7 +590,10 @@ class TransformerArgs:
|
||||
cross_scale_shift_timestep: torch.Tensor | None
|
||||
cross_gate_timestep: torch.Tensor | None
|
||||
enabled: bool
|
||||
|
||||
prompt_timestep: torch.Tensor | None = None
|
||||
self_attention_mask: torch.Tensor | None = (
|
||||
None # Additive log-space self-attention bias (B, 1, T, T), None = full attention
|
||||
)
|
||||
|
||||
|
||||
class TransformerArgsPreprocessor:
|
||||
@@ -566,7 +601,6 @@ class TransformerArgsPreprocessor:
|
||||
self,
|
||||
patchify_proj: torch.nn.Linear,
|
||||
adaln: AdaLayerNormSingle,
|
||||
caption_projection: PixArtAlphaTextProjection,
|
||||
inner_dim: int,
|
||||
max_pos: list[int],
|
||||
num_attention_heads: int,
|
||||
@@ -575,10 +609,11 @@ class TransformerArgsPreprocessor:
|
||||
double_precision_rope: bool,
|
||||
positional_embedding_theta: float,
|
||||
rope_type: LTXRopeType,
|
||||
caption_projection: torch.nn.Module | None = None,
|
||||
prompt_adaln: AdaLayerNormSingle | None = None,
|
||||
) -> None:
|
||||
self.patchify_proj = patchify_proj
|
||||
self.adaln = adaln
|
||||
self.caption_projection = caption_projection
|
||||
self.inner_dim = inner_dim
|
||||
self.max_pos = max_pos
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -587,18 +622,18 @@ class TransformerArgsPreprocessor:
|
||||
self.double_precision_rope = double_precision_rope
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.rope_type = rope_type
|
||||
self.caption_projection = caption_projection
|
||||
self.prompt_adaln = prompt_adaln
|
||||
|
||||
def _prepare_timestep(
|
||||
self, timestep: torch.Tensor, batch_size: int, hidden_dtype: torch.dtype
|
||||
self, timestep: torch.Tensor, adaln: AdaLayerNormSingle, batch_size: int, hidden_dtype: torch.dtype
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Prepare timestep embeddings."""
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep, embedded_timestep = self.adaln(
|
||||
timestep.flatten(),
|
||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||
timestep, embedded_timestep = adaln(
|
||||
timestep_scaled.flatten(),
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
||||
@@ -608,14 +643,12 @@ class TransformerArgsPreprocessor:
|
||||
self,
|
||||
context: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
) -> torch.Tensor:
|
||||
"""Prepare context for transformer blocks."""
|
||||
if self.caption_projection is not None:
|
||||
context = self.caption_projection(context)
|
||||
batch_size = x.shape[0]
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(batch_size, -1, x.shape[-1])
|
||||
|
||||
return context, attention_mask
|
||||
return context.view(batch_size, -1, x.shape[-1])
|
||||
|
||||
def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None:
|
||||
"""Prepare attention mask."""
|
||||
@@ -626,6 +659,34 @@ class TransformerArgsPreprocessor:
|
||||
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
) * torch.finfo(x_dtype).max
|
||||
|
||||
def _prepare_self_attention_mask(
|
||||
self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype
|
||||
) -> torch.Tensor | None:
|
||||
"""Prepare self-attention mask by converting [0,1] values to additive log-space bias.
|
||||
Input shape: (B, T, T) with values in [0, 1].
|
||||
Output shape: (B, 1, T, T) with 0.0 for full attention and a large negative value
|
||||
for masked positions.
|
||||
Positions with attention_mask <= 0 are fully masked (mapped to the dtype's minimum
|
||||
representable value). Strictly positive entries are converted via log-space for
|
||||
smooth attenuation, with small values clamped for numerical stability.
|
||||
Returns None if input is None (no masking).
|
||||
"""
|
||||
if attention_mask is None:
|
||||
return None
|
||||
|
||||
# Convert [0, 1] attention mask to additive log-space bias:
|
||||
# 1.0 -> log(1.0) = 0.0 (no bias, full attention)
|
||||
# 0.0 -> finfo.min (fully masked)
|
||||
finfo = torch.finfo(x_dtype)
|
||||
eps = finfo.tiny
|
||||
|
||||
bias = torch.full_like(attention_mask, finfo.min, dtype=x_dtype)
|
||||
positive = attention_mask > 0
|
||||
if positive.any():
|
||||
bias[positive] = torch.log(attention_mask[positive].clamp(min=eps)).to(x_dtype)
|
||||
|
||||
return bias.unsqueeze(1) # (B, 1, T, T) for head broadcast
|
||||
|
||||
def _prepare_positional_embeddings(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -653,11 +714,20 @@ class TransformerArgsPreprocessor:
|
||||
def prepare(
|
||||
self,
|
||||
modality: Modality,
|
||||
cross_modality: Modality | None = None, # noqa: ARG002
|
||||
) -> TransformerArgs:
|
||||
x = self.patchify_proj(modality.latent)
|
||||
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], modality.latent.dtype)
|
||||
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
||||
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
||||
batch_size = x.shape[0]
|
||||
timestep, embedded_timestep = self._prepare_timestep(
|
||||
modality.timesteps, self.adaln, batch_size, modality.latent.dtype
|
||||
)
|
||||
prompt_timestep = None
|
||||
if self.prompt_adaln is not None:
|
||||
prompt_timestep, _ = self._prepare_timestep(
|
||||
modality.sigma, self.prompt_adaln, batch_size, modality.latent.dtype
|
||||
)
|
||||
context = self._prepare_context(modality.context, x)
|
||||
attention_mask = self._prepare_attention_mask(modality.context_mask, modality.latent.dtype)
|
||||
pe = self._prepare_positional_embeddings(
|
||||
positions=modality.positions,
|
||||
inner_dim=self.inner_dim,
|
||||
@@ -666,6 +736,7 @@ class TransformerArgsPreprocessor:
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
x_dtype=modality.latent.dtype,
|
||||
)
|
||||
self_attention_mask = self._prepare_self_attention_mask(modality.attention_mask, modality.latent.dtype)
|
||||
return TransformerArgs(
|
||||
x=x,
|
||||
context=context,
|
||||
@@ -677,6 +748,8 @@ class TransformerArgsPreprocessor:
|
||||
cross_scale_shift_timestep=None,
|
||||
cross_gate_timestep=None,
|
||||
enabled=modality.enabled,
|
||||
prompt_timestep=prompt_timestep,
|
||||
self_attention_mask=self_attention_mask,
|
||||
)
|
||||
|
||||
|
||||
@@ -685,7 +758,6 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
self,
|
||||
patchify_proj: torch.nn.Linear,
|
||||
adaln: AdaLayerNormSingle,
|
||||
caption_projection: PixArtAlphaTextProjection,
|
||||
cross_scale_shift_adaln: AdaLayerNormSingle,
|
||||
cross_gate_adaln: AdaLayerNormSingle,
|
||||
inner_dim: int,
|
||||
@@ -699,11 +771,12 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
positional_embedding_theta: float,
|
||||
rope_type: LTXRopeType,
|
||||
av_ca_timestep_scale_multiplier: int,
|
||||
caption_projection: torch.nn.Module | None = None,
|
||||
prompt_adaln: AdaLayerNormSingle | None = None,
|
||||
) -> None:
|
||||
self.simple_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=patchify_proj,
|
||||
adaln=adaln,
|
||||
caption_projection=caption_projection,
|
||||
inner_dim=inner_dim,
|
||||
max_pos=max_pos,
|
||||
num_attention_heads=num_attention_heads,
|
||||
@@ -712,6 +785,8 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
double_precision_rope=double_precision_rope,
|
||||
positional_embedding_theta=positional_embedding_theta,
|
||||
rope_type=rope_type,
|
||||
caption_projection=caption_projection,
|
||||
prompt_adaln=prompt_adaln,
|
||||
)
|
||||
self.cross_scale_shift_adaln = cross_scale_shift_adaln
|
||||
self.cross_gate_adaln = cross_gate_adaln
|
||||
@@ -722,8 +797,22 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
def prepare(
|
||||
self,
|
||||
modality: Modality,
|
||||
cross_modality: Modality | None = None,
|
||||
) -> TransformerArgs:
|
||||
transformer_args = self.simple_preprocessor.prepare(modality)
|
||||
if cross_modality is None:
|
||||
return transformer_args
|
||||
|
||||
if cross_modality.sigma.numel() > 1:
|
||||
if cross_modality.sigma.shape[0] != modality.timesteps.shape[0]:
|
||||
raise ValueError("Cross modality sigma must have the same batch size as the modality")
|
||||
if cross_modality.sigma.ndim != 1:
|
||||
raise ValueError("Cross modality sigma must be a 1D tensor")
|
||||
|
||||
cross_timestep = cross_modality.sigma.view(
|
||||
modality.timesteps.shape[0], 1, *[1] * len(modality.timesteps.shape[2:])
|
||||
)
|
||||
|
||||
cross_pe = self.simple_preprocessor._prepare_positional_embeddings(
|
||||
positions=modality.positions[:, 0:1, :],
|
||||
inner_dim=self.audio_cross_attention_dim,
|
||||
@@ -734,7 +823,7 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
)
|
||||
|
||||
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
|
||||
timestep=modality.timesteps,
|
||||
timestep=cross_timestep,
|
||||
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
||||
batch_size=transformer_args.x.shape[0],
|
||||
hidden_dtype=modality.latent.dtype,
|
||||
@@ -749,7 +838,7 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
|
||||
def _prepare_cross_attention_timestep(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
timestep: torch.Tensor | None,
|
||||
timestep_scale_multiplier: int,
|
||||
batch_size: int,
|
||||
hidden_dtype: torch.dtype,
|
||||
@@ -779,6 +868,8 @@ class TransformerConfig:
|
||||
heads: int
|
||||
d_head: int
|
||||
context_dim: int
|
||||
apply_gated_attention: bool = False
|
||||
cross_attention_adaln: bool = False
|
||||
|
||||
|
||||
class BasicAVTransformerBlock(torch.nn.Module):
|
||||
@@ -801,6 +892,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
context_dim=None,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
apply_gated_attention=video.apply_gated_attention,
|
||||
)
|
||||
self.attn2 = Attention(
|
||||
query_dim=video.dim,
|
||||
@@ -809,9 +901,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
dim_head=video.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
apply_gated_attention=video.apply_gated_attention,
|
||||
)
|
||||
self.ff = FeedForward(video.dim, dim_out=video.dim)
|
||||
self.scale_shift_table = torch.nn.Parameter(torch.empty(6, video.dim))
|
||||
video_sst_size = adaln_embedding_coefficient(video.cross_attention_adaln)
|
||||
self.scale_shift_table = torch.nn.Parameter(torch.empty(video_sst_size, video.dim))
|
||||
|
||||
if audio is not None:
|
||||
self.audio_attn1 = Attention(
|
||||
@@ -821,6 +915,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
context_dim=None,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
apply_gated_attention=audio.apply_gated_attention,
|
||||
)
|
||||
self.audio_attn2 = Attention(
|
||||
query_dim=audio.dim,
|
||||
@@ -829,9 +924,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
dim_head=audio.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
apply_gated_attention=audio.apply_gated_attention,
|
||||
)
|
||||
self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
|
||||
self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(6, audio.dim))
|
||||
audio_sst_size = adaln_embedding_coefficient(audio.cross_attention_adaln)
|
||||
self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(audio_sst_size, audio.dim))
|
||||
|
||||
if audio is not None and video is not None:
|
||||
# Q: Video, K,V: Audio
|
||||
@@ -842,6 +939,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
dim_head=audio.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
apply_gated_attention=video.apply_gated_attention,
|
||||
)
|
||||
|
||||
# Q: Audio, K,V: Video
|
||||
@@ -852,11 +950,21 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
dim_head=audio.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
apply_gated_attention=audio.apply_gated_attention,
|
||||
)
|
||||
|
||||
self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim))
|
||||
self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim))
|
||||
|
||||
self.cross_attention_adaln = (video is not None and video.cross_attention_adaln) or (
|
||||
audio is not None and audio.cross_attention_adaln
|
||||
)
|
||||
|
||||
if self.cross_attention_adaln and video is not None:
|
||||
self.prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, video.dim))
|
||||
if self.cross_attention_adaln and audio is not None:
|
||||
self.audio_prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, audio.dim))
|
||||
|
||||
self.norm_eps = norm_eps
|
||||
|
||||
def get_ada_values(
|
||||
@@ -876,19 +984,49 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
batch_size: int,
|
||||
scale_shift_timestep: torch.Tensor,
|
||||
gate_timestep: torch.Tensor,
|
||||
scale_shift_indices: slice,
|
||||
num_scale_shift_values: int = 4,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
scale_shift_ada_values = self.get_ada_values(
|
||||
scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, slice(None, None)
|
||||
scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, scale_shift_indices
|
||||
)
|
||||
gate_ada_values = self.get_ada_values(
|
||||
scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None)
|
||||
)
|
||||
|
||||
scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
|
||||
gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
|
||||
scale, shift = (t.squeeze(2) for t in scale_shift_ada_values)
|
||||
(gate,) = (t.squeeze(2) for t in gate_ada_values)
|
||||
|
||||
return (*scale_shift_chunks, *gate_ada_values)
|
||||
return scale, shift, gate
|
||||
|
||||
def _apply_text_cross_attention(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attn: Attention,
|
||||
scale_shift_table: torch.Tensor,
|
||||
prompt_scale_shift_table: torch.Tensor | None,
|
||||
timestep: torch.Tensor,
|
||||
prompt_timestep: torch.Tensor | None,
|
||||
context_mask: torch.Tensor | None,
|
||||
cross_attention_adaln: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Apply text cross-attention, with optional AdaLN modulation."""
|
||||
if cross_attention_adaln:
|
||||
shift_q, scale_q, gate = self.get_ada_values(scale_shift_table, x.shape[0], timestep, slice(6, 9))
|
||||
return apply_cross_attention_adaln(
|
||||
x,
|
||||
context,
|
||||
attn,
|
||||
shift_q,
|
||||
scale_q,
|
||||
gate,
|
||||
prompt_scale_shift_table,
|
||||
prompt_timestep,
|
||||
context_mask,
|
||||
self.norm_eps,
|
||||
)
|
||||
return attn(rms_norm(x, eps=self.norm_eps), context=context, mask=context_mask)
|
||||
|
||||
def forward( # noqa: PLR0915
|
||||
self,
|
||||
@@ -896,7 +1034,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
audio: TransformerArgs | None,
|
||||
perturbations: BatchedPerturbationConfig | None = None,
|
||||
) -> tuple[TransformerArgs | None, TransformerArgs | None]:
|
||||
batch_size = video.x.shape[0]
|
||||
if video is None and audio is None:
|
||||
raise ValueError("At least one of video or audio must be provided")
|
||||
|
||||
batch_size = (video or audio).x.shape[0]
|
||||
|
||||
if perturbations is None:
|
||||
perturbations = BatchedPerturbationConfig.empty(batch_size)
|
||||
|
||||
@@ -913,63 +1055,103 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
|
||||
)
|
||||
if not perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx):
|
||||
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
||||
v_mask = perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx)
|
||||
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa * v_mask
|
||||
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
||||
del vshift_msa, vscale_msa
|
||||
|
||||
vx = vx + self.attn2(rms_norm(vx, eps=self.norm_eps), context=video.context, mask=video.context_mask)
|
||||
|
||||
del vshift_msa, vscale_msa, vgate_msa
|
||||
all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)
|
||||
none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)
|
||||
v_mask = (
|
||||
perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx)
|
||||
if not all_perturbed and not none_perturbed
|
||||
else None
|
||||
)
|
||||
vx = (
|
||||
vx
|
||||
+ self.attn1(
|
||||
norm_vx,
|
||||
pe=video.positional_embeddings,
|
||||
mask=video.self_attention_mask,
|
||||
perturbation_mask=v_mask,
|
||||
all_perturbed=all_perturbed,
|
||||
)
|
||||
* vgate_msa
|
||||
)
|
||||
del vgate_msa, norm_vx, v_mask
|
||||
vx = vx + self._apply_text_cross_attention(
|
||||
vx,
|
||||
video.context,
|
||||
self.attn2,
|
||||
self.scale_shift_table,
|
||||
getattr(self, "prompt_scale_shift_table", None),
|
||||
video.timesteps,
|
||||
video.prompt_timestep,
|
||||
video.context_mask,
|
||||
cross_attention_adaln=self.cross_attention_adaln,
|
||||
)
|
||||
|
||||
if run_ax:
|
||||
ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
|
||||
)
|
||||
|
||||
if not perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx):
|
||||
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
|
||||
a_mask = perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax)
|
||||
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa * a_mask
|
||||
|
||||
ax = ax + self.audio_attn2(rms_norm(ax, eps=self.norm_eps), context=audio.context, mask=audio.context_mask)
|
||||
|
||||
del ashift_msa, ascale_msa, agate_msa
|
||||
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
|
||||
del ashift_msa, ascale_msa
|
||||
all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)
|
||||
none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)
|
||||
a_mask = (
|
||||
perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax)
|
||||
if not all_perturbed and not none_perturbed
|
||||
else None
|
||||
)
|
||||
ax = (
|
||||
ax
|
||||
+ self.audio_attn1(
|
||||
norm_ax,
|
||||
pe=audio.positional_embeddings,
|
||||
mask=audio.self_attention_mask,
|
||||
perturbation_mask=a_mask,
|
||||
all_perturbed=all_perturbed,
|
||||
)
|
||||
* agate_msa
|
||||
)
|
||||
del agate_msa, norm_ax, a_mask
|
||||
ax = ax + self._apply_text_cross_attention(
|
||||
ax,
|
||||
audio.context,
|
||||
self.audio_attn2,
|
||||
self.audio_scale_shift_table,
|
||||
getattr(self, "audio_prompt_scale_shift_table", None),
|
||||
audio.timesteps,
|
||||
audio.prompt_timestep,
|
||||
audio.context_mask,
|
||||
cross_attention_adaln=self.cross_attention_adaln,
|
||||
)
|
||||
|
||||
# Audio - Video cross attention.
|
||||
if run_a2v or run_v2a:
|
||||
vx_norm3 = rms_norm(vx, eps=self.norm_eps)
|
||||
ax_norm3 = rms_norm(ax, eps=self.norm_eps)
|
||||
|
||||
(
|
||||
scale_ca_audio_hidden_states_a2v,
|
||||
shift_ca_audio_hidden_states_a2v,
|
||||
scale_ca_audio_hidden_states_v2a,
|
||||
shift_ca_audio_hidden_states_v2a,
|
||||
gate_out_v2a,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
audio.cross_scale_shift_timestep,
|
||||
audio.cross_gate_timestep,
|
||||
)
|
||||
if run_a2v and not perturbations.all_in_batch(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx):
|
||||
scale_ca_video_a2v, shift_ca_video_a2v, gate_out_a2v = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
video.cross_scale_shift_timestep,
|
||||
video.cross_gate_timestep,
|
||||
slice(0, 2),
|
||||
)
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v
|
||||
del scale_ca_video_a2v, shift_ca_video_a2v
|
||||
|
||||
(
|
||||
scale_ca_video_hidden_states_a2v,
|
||||
shift_ca_video_hidden_states_a2v,
|
||||
scale_ca_video_hidden_states_v2a,
|
||||
shift_ca_video_hidden_states_v2a,
|
||||
gate_out_a2v,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
video.cross_scale_shift_timestep,
|
||||
video.cross_gate_timestep,
|
||||
)
|
||||
|
||||
if run_a2v:
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) + shift_ca_video_hidden_states_a2v
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
|
||||
scale_ca_audio_a2v, shift_ca_audio_a2v, _ = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
audio.cross_scale_shift_timestep,
|
||||
audio.cross_gate_timestep,
|
||||
slice(0, 2),
|
||||
)
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v
|
||||
del scale_ca_audio_a2v, shift_ca_audio_a2v
|
||||
a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx)
|
||||
vx = vx + (
|
||||
self.audio_to_video_attn(
|
||||
@@ -981,10 +1163,27 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
* gate_out_a2v
|
||||
* a2v_mask
|
||||
)
|
||||
del gate_out_a2v, a2v_mask, vx_scaled, ax_scaled
|
||||
|
||||
if run_v2a:
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
|
||||
if run_v2a and not perturbations.all_in_batch(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx):
|
||||
scale_ca_audio_v2a, shift_ca_audio_v2a, gate_out_v2a = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
audio.cross_scale_shift_timestep,
|
||||
audio.cross_gate_timestep,
|
||||
slice(2, 4),
|
||||
)
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a
|
||||
del scale_ca_audio_v2a, shift_ca_audio_v2a
|
||||
scale_ca_video_v2a, shift_ca_video_v2a, _ = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
video.cross_scale_shift_timestep,
|
||||
video.cross_gate_timestep,
|
||||
slice(2, 4),
|
||||
)
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a
|
||||
del scale_ca_video_v2a, shift_ca_video_v2a
|
||||
v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax)
|
||||
ax = ax + (
|
||||
self.video_to_audio_attn(
|
||||
@@ -996,40 +1195,53 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
* gate_out_v2a
|
||||
* v2a_mask
|
||||
)
|
||||
del gate_out_v2a, v2a_mask, ax_scaled, vx_scaled
|
||||
|
||||
del gate_out_a2v, gate_out_v2a
|
||||
del (
|
||||
scale_ca_video_hidden_states_a2v,
|
||||
shift_ca_video_hidden_states_a2v,
|
||||
scale_ca_audio_hidden_states_a2v,
|
||||
shift_ca_audio_hidden_states_a2v,
|
||||
scale_ca_video_hidden_states_v2a,
|
||||
shift_ca_video_hidden_states_v2a,
|
||||
scale_ca_audio_hidden_states_v2a,
|
||||
shift_ca_audio_hidden_states_v2a,
|
||||
)
|
||||
del vx_norm3, ax_norm3
|
||||
|
||||
if run_vx:
|
||||
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None)
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6)
|
||||
)
|
||||
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
|
||||
vx = vx + self.ff(vx_scaled) * vgate_mlp
|
||||
|
||||
del vshift_mlp, vscale_mlp, vgate_mlp
|
||||
del vshift_mlp, vscale_mlp, vgate_mlp, vx_scaled
|
||||
|
||||
if run_ax:
|
||||
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None)
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6)
|
||||
)
|
||||
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
|
||||
ax = ax + self.audio_ff(ax_scaled) * agate_mlp
|
||||
|
||||
del ashift_mlp, ascale_mlp, agate_mlp
|
||||
del ashift_mlp, ascale_mlp, agate_mlp, ax_scaled
|
||||
|
||||
return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None
|
||||
|
||||
|
||||
def apply_cross_attention_adaln(
|
||||
x: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attn: Attention,
|
||||
q_shift: torch.Tensor,
|
||||
q_scale: torch.Tensor,
|
||||
q_gate: torch.Tensor,
|
||||
prompt_scale_shift_table: torch.Tensor,
|
||||
prompt_timestep: torch.Tensor,
|
||||
context_mask: torch.Tensor | None = None,
|
||||
norm_eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
batch_size = x.shape[0]
|
||||
shift_kv, scale_kv = (
|
||||
prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
|
||||
+ prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
|
||||
).unbind(dim=2)
|
||||
attn_input = rms_norm(x, eps=norm_eps) * (1 + q_scale) + q_shift
|
||||
encoder_hidden_states = context * (1 + scale_kv) + shift_kv
|
||||
return attn(attn_input, context=encoder_hidden_states, mask=context_mask) * q_gate
|
||||
|
||||
|
||||
class GELUApprox(torch.nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int) -> None:
|
||||
super().__init__()
|
||||
@@ -1094,6 +1306,8 @@ class LTXModel(torch.nn.Module):
|
||||
av_ca_timestep_scale_multiplier: int = 1000,
|
||||
rope_type: LTXRopeType = LTXRopeType.SPLIT,
|
||||
double_precision_rope: bool = True,
|
||||
apply_gated_attention: bool = False,
|
||||
cross_attention_adaln: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self._enable_gradient_checkpointing = False
|
||||
@@ -1103,6 +1317,7 @@ class LTXModel(torch.nn.Module):
|
||||
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.model_type = model_type
|
||||
self.cross_attention_adaln = cross_attention_adaln
|
||||
cross_pe_max_pos = None
|
||||
if model_type.is_video_enabled():
|
||||
if positional_embedding_max_pos is None:
|
||||
@@ -1145,8 +1360,13 @@ class LTXModel(torch.nn.Module):
|
||||
audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0,
|
||||
audio_cross_attention_dim=audio_cross_attention_dim,
|
||||
norm_eps=norm_eps,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
)
|
||||
|
||||
@property
|
||||
def _adaln_embedding_coefficient(self) -> int:
|
||||
return adaln_embedding_coefficient(self.cross_attention_adaln)
|
||||
|
||||
def _init_video(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -1157,14 +1377,15 @@ class LTXModel(torch.nn.Module):
|
||||
"""Initialize video-specific components."""
|
||||
# Video input components
|
||||
self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True)
|
||||
|
||||
self.adaln_single = AdaLayerNormSingle(self.inner_dim)
|
||||
self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=self._adaln_embedding_coefficient)
|
||||
self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None
|
||||
|
||||
# Video caption projection
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
)
|
||||
if caption_channels is not None:
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
)
|
||||
|
||||
# Video output components
|
||||
self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim))
|
||||
@@ -1183,15 +1404,15 @@ class LTXModel(torch.nn.Module):
|
||||
# Audio input components
|
||||
self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True)
|
||||
|
||||
self.audio_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
)
|
||||
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=self._adaln_embedding_coefficient)
|
||||
self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None
|
||||
|
||||
# Audio caption projection
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
)
|
||||
if caption_channels is not None:
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
)
|
||||
|
||||
# Audio output components
|
||||
self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim))
|
||||
@@ -1233,7 +1454,6 @@ class LTXModel(torch.nn.Module):
|
||||
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||
patchify_proj=self.patchify_proj,
|
||||
adaln=self.adaln_single,
|
||||
caption_projection=self.caption_projection,
|
||||
cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
|
||||
cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
|
||||
inner_dim=self.inner_dim,
|
||||
@@ -1247,11 +1467,12 @@ class LTXModel(torch.nn.Module):
|
||||
positional_embedding_theta=self.positional_embedding_theta,
|
||||
rope_type=self.rope_type,
|
||||
av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
|
||||
caption_projection=getattr(self, "caption_projection", None),
|
||||
prompt_adaln=getattr(self, "prompt_adaln_single", None),
|
||||
)
|
||||
self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||
patchify_proj=self.audio_patchify_proj,
|
||||
adaln=self.audio_adaln_single,
|
||||
caption_projection=self.audio_caption_projection,
|
||||
cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
|
||||
cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
|
||||
inner_dim=self.audio_inner_dim,
|
||||
@@ -1265,12 +1486,13 @@ class LTXModel(torch.nn.Module):
|
||||
positional_embedding_theta=self.positional_embedding_theta,
|
||||
rope_type=self.rope_type,
|
||||
av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
|
||||
caption_projection=getattr(self, "audio_caption_projection", None),
|
||||
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
|
||||
)
|
||||
elif self.model_type.is_video_enabled():
|
||||
self.video_args_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=self.patchify_proj,
|
||||
adaln=self.adaln_single,
|
||||
caption_projection=self.caption_projection,
|
||||
inner_dim=self.inner_dim,
|
||||
max_pos=self.positional_embedding_max_pos,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
@@ -1279,12 +1501,13 @@ class LTXModel(torch.nn.Module):
|
||||
double_precision_rope=self.double_precision_rope,
|
||||
positional_embedding_theta=self.positional_embedding_theta,
|
||||
rope_type=self.rope_type,
|
||||
caption_projection=getattr(self, "caption_projection", None),
|
||||
prompt_adaln=getattr(self, "prompt_adaln_single", None),
|
||||
)
|
||||
elif self.model_type.is_audio_enabled():
|
||||
self.audio_args_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=self.audio_patchify_proj,
|
||||
adaln=self.audio_adaln_single,
|
||||
caption_projection=self.audio_caption_projection,
|
||||
inner_dim=self.audio_inner_dim,
|
||||
max_pos=self.audio_positional_embedding_max_pos,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
@@ -1293,6 +1516,8 @@ class LTXModel(torch.nn.Module):
|
||||
double_precision_rope=self.double_precision_rope,
|
||||
positional_embedding_theta=self.positional_embedding_theta,
|
||||
rope_type=self.rope_type,
|
||||
caption_projection=getattr(self, "audio_caption_projection", None),
|
||||
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
|
||||
)
|
||||
|
||||
def _init_transformer_blocks(
|
||||
@@ -1303,6 +1528,7 @@ class LTXModel(torch.nn.Module):
|
||||
audio_attention_head_dim: int,
|
||||
audio_cross_attention_dim: int,
|
||||
norm_eps: float,
|
||||
apply_gated_attention: bool,
|
||||
) -> None:
|
||||
"""Initialize transformer blocks for LTX."""
|
||||
video_config = (
|
||||
@@ -1311,6 +1537,8 @@ class LTXModel(torch.nn.Module):
|
||||
heads=self.num_attention_heads,
|
||||
d_head=attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
cross_attention_adaln=self.cross_attention_adaln,
|
||||
)
|
||||
if self.model_type.is_video_enabled()
|
||||
else None
|
||||
@@ -1321,6 +1549,8 @@ class LTXModel(torch.nn.Module):
|
||||
heads=self.audio_num_attention_heads,
|
||||
d_head=audio_attention_head_dim,
|
||||
context_dim=audio_cross_attention_dim,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
cross_attention_adaln=self.cross_attention_adaln,
|
||||
)
|
||||
if self.model_type.is_audio_enabled()
|
||||
else None
|
||||
@@ -1409,8 +1639,8 @@ class LTXModel(torch.nn.Module):
|
||||
if not self.model_type.is_audio_enabled() and audio is not None:
|
||||
raise ValueError("Audio is not enabled for this model")
|
||||
|
||||
video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
|
||||
audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
|
||||
video_args = self.video_args_preprocessor.prepare(video, audio) if video is not None else None
|
||||
audio_args = self.audio_args_preprocessor.prepare(audio, video) if audio is not None else None
|
||||
# Process transformer blocks
|
||||
video_out, audio_out = self._process_transformer_blocks(
|
||||
video=video_args,
|
||||
@@ -1441,12 +1671,12 @@ class LTXModel(torch.nn.Module):
|
||||
)
|
||||
return vx, ax
|
||||
|
||||
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
|
||||
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps, sigma, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
|
||||
cross_pe_max_pos = None
|
||||
if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():
|
||||
cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
|
||||
self._init_preprocessors(cross_pe_max_pos)
|
||||
video = Modality(video_latents, video_timesteps, video_positions, video_context)
|
||||
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None
|
||||
video = Modality(video_latents, sigma, video_timesteps, video_positions, video_context)
|
||||
audio = Modality(audio_latents, sigma, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None
|
||||
vx, ax = self._forward(video=video, audio=audio, perturbations=None, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload)
|
||||
return vx, ax
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer
|
||||
from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention,
|
||||
FeedForward)
|
||||
@@ -147,14 +150,14 @@ class LTXVGemmaTokenizer:
|
||||
return out
|
||||
|
||||
|
||||
class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
|
||||
class GemmaFeaturesExtractorProjLinear(nn.Module):
|
||||
"""
|
||||
Feature extractor module for Gemma models.
|
||||
This module applies a single linear projection to the input tensor.
|
||||
It expects a flattened feature tensor of shape (batch_size, 3840*49).
|
||||
The linear layer maps this to a (batch_size, 3840) embedding.
|
||||
Attributes:
|
||||
aggregate_embed (torch.nn.Linear): Linear projection layer.
|
||||
aggregate_embed (nn.Linear): Linear projection layer.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
@@ -163,26 +166,65 @@ class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
|
||||
The input dimension is expected to be 3840 * 49, and the output is 3840.
|
||||
"""
|
||||
super().__init__()
|
||||
self.aggregate_embed = torch.nn.Linear(3840 * 49, 3840, bias=False)
|
||||
self.aggregate_embed = nn.Linear(3840 * 49, 3840, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for the feature extractor.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor of shape (batch_size, 3840 * 49).
|
||||
Returns:
|
||||
torch.Tensor: Output tensor of shape (batch_size, 3840).
|
||||
"""
|
||||
return self.aggregate_embed(x)
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_side: str = "left",
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
encoded = torch.stack(hidden_states, dim=-1) if isinstance(hidden_states, (list, tuple)) else hidden_states
|
||||
dtype = encoded.dtype
|
||||
sequence_lengths = attention_mask.sum(dim=-1)
|
||||
normed = _norm_and_concat_padded_batch(encoded, sequence_lengths, padding_side)
|
||||
features = self.aggregate_embed(normed.to(dtype))
|
||||
return features, features
|
||||
|
||||
|
||||
class _BasicTransformerBlock1D(torch.nn.Module):
|
||||
class GemmaSeperatedFeaturesExtractorProjLinear(nn.Module):
|
||||
"""22B: per-token RMS norm → rescale → dual aggregate embeds"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int,
|
||||
embedding_dim: int,
|
||||
video_inner_dim: int,
|
||||
audio_inner_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
in_dim = embedding_dim * num_layers
|
||||
self.video_aggregate_embed = torch.nn.Linear(in_dim, video_inner_dim, bias=True)
|
||||
self.audio_aggregate_embed = torch.nn.Linear(in_dim, audio_inner_dim, bias=True)
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_side: str = "left", # noqa: ARG002
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
encoded = torch.stack(hidden_states, dim=-1) if isinstance(hidden_states, (list, tuple)) else hidden_states
|
||||
normed = norm_and_concat_per_token_rms(encoded, attention_mask)
|
||||
normed = normed.to(encoded.dtype)
|
||||
v_dim = self.video_aggregate_embed.out_features
|
||||
video = self.video_aggregate_embed(_rescale_norm(normed, v_dim, self.embedding_dim))
|
||||
audio = None
|
||||
if self.audio_aggregate_embed is not None:
|
||||
a_dim = self.audio_aggregate_embed.out_features
|
||||
audio = self.audio_aggregate_embed(_rescale_norm(normed, a_dim, self.embedding_dim))
|
||||
return video, audio
|
||||
|
||||
|
||||
|
||||
class _BasicTransformerBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
heads: int,
|
||||
dim_head: int,
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
apply_gated_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -191,6 +233,7 @@ class _BasicTransformerBlock1D(torch.nn.Module):
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
)
|
||||
|
||||
self.ff = FeedForward(
|
||||
@@ -231,7 +274,7 @@ class _BasicTransformerBlock1D(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Embeddings1DConnector(torch.nn.Module):
|
||||
class Embeddings1DConnector(nn.Module):
|
||||
"""
|
||||
Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
|
||||
other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can
|
||||
@@ -263,6 +306,7 @@ class Embeddings1DConnector(torch.nn.Module):
|
||||
num_learnable_registers: int | None = 128,
|
||||
rope_type: LTXRopeType = LTXRopeType.SPLIT,
|
||||
double_precision_rope: bool = True,
|
||||
apply_gated_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -274,13 +318,14 @@ class Embeddings1DConnector(torch.nn.Module):
|
||||
)
|
||||
self.rope_type = rope_type
|
||||
self.double_precision_rope = double_precision_rope
|
||||
self.transformer_1d_blocks = torch.nn.ModuleList(
|
||||
self.transformer_1d_blocks = nn.ModuleList(
|
||||
[
|
||||
_BasicTransformerBlock1D(
|
||||
dim=self.inner_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
@@ -288,7 +333,7 @@ class Embeddings1DConnector(torch.nn.Module):
|
||||
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
if self.num_learnable_registers:
|
||||
self.learnable_registers = torch.nn.Parameter(
|
||||
self.learnable_registers = nn.Parameter(
|
||||
torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
|
||||
)
|
||||
|
||||
@@ -307,7 +352,7 @@ class Embeddings1DConnector(torch.nn.Module):
|
||||
non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
|
||||
non_zero_nums = non_zero_hidden_states.shape[1]
|
||||
pad_length = hidden_states.shape[1] - non_zero_nums
|
||||
adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
|
||||
adjusted_hidden_states = nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
|
||||
flipped_mask = torch.flip(attention_mask_binary, dims=[1])
|
||||
hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers
|
||||
|
||||
@@ -358,9 +403,147 @@ class Embeddings1DConnector(torch.nn.Module):
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
class LTX2TextEncoderPostModules(torch.nn.Module):
|
||||
def __init__(self,):
|
||||
class LTX2TextEncoderPostModules(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
seperated_audio_video: bool = False,
|
||||
embedding_dim_gemma: int = 3840,
|
||||
num_layers_gemma: int = 49,
|
||||
video_attetion_heads: int = 32,
|
||||
video_attention_head_dim: int = 128,
|
||||
audio_attention_heads: int = 32,
|
||||
audio_attention_head_dim: int = 64,
|
||||
num_connetor_layers: int = 2,
|
||||
apply_gated_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
|
||||
self.embeddings_connector = Embeddings1DConnector()
|
||||
self.audio_embeddings_connector = Embeddings1DConnector()
|
||||
if not seperated_audio_video:
|
||||
self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
|
||||
self.embeddings_connector = Embeddings1DConnector()
|
||||
self.audio_embeddings_connector = Embeddings1DConnector()
|
||||
else:
|
||||
# LTX-2.3
|
||||
self.feature_extractor_linear = GemmaSeperatedFeaturesExtractorProjLinear(
|
||||
num_layers_gemma, embedding_dim_gemma, video_attetion_heads * video_attention_head_dim,
|
||||
audio_attention_heads * audio_attention_head_dim)
|
||||
self.embeddings_connector = Embeddings1DConnector(
|
||||
attention_head_dim=video_attention_head_dim,
|
||||
num_attention_heads=video_attetion_heads,
|
||||
num_layers=num_connetor_layers,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
)
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
attention_head_dim=audio_attention_head_dim,
|
||||
num_attention_heads=audio_attention_heads,
|
||||
num_layers=num_connetor_layers,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
)
|
||||
|
||||
def create_embeddings(
|
||||
self,
|
||||
video_features: torch.Tensor,
|
||||
audio_features: torch.Tensor | None,
|
||||
additive_attention_mask: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
|
||||
video_encoded, video_mask = self.embeddings_connector(video_features, additive_attention_mask)
|
||||
video_encoded, binary_mask = _to_binary_mask(video_encoded, video_mask)
|
||||
audio_encoded, _ = self.audio_embeddings_connector(audio_features, additive_attention_mask)
|
||||
|
||||
return video_encoded, audio_encoded, binary_mask
|
||||
|
||||
def process_hidden_states(
|
||||
self,
|
||||
hidden_states: tuple[torch.Tensor, ...],
|
||||
attention_mask: torch.Tensor,
|
||||
padding_side: str = "left",
|
||||
):
|
||||
video_feats, audio_feats = self.feature_extractor_linear(hidden_states, attention_mask, padding_side)
|
||||
additive_mask = _convert_to_additive_mask(attention_mask, video_feats.dtype)
|
||||
video_enc, audio_enc, binary_mask = self.create_embeddings(video_feats, audio_feats, additive_mask)
|
||||
return video_enc, audio_enc, binary_mask
|
||||
|
||||
|
||||
def _norm_and_concat_padded_batch(
|
||||
encoded_text: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
padding_side: str = "right",
|
||||
) -> torch.Tensor:
|
||||
"""Normalize and flatten multi-layer hidden states, respecting padding.
|
||||
Performs per-batch, per-layer normalization using masked mean and range,
|
||||
then concatenates across the layer dimension.
|
||||
Args:
|
||||
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
|
||||
sequence_lengths: Number of valid (non-padded) tokens per batch item.
|
||||
padding_side: Whether padding is on "left" or "right".
|
||||
Returns:
|
||||
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
|
||||
with padded positions zeroed out.
|
||||
"""
|
||||
b, t, d, l = encoded_text.shape # noqa: E741
|
||||
device = encoded_text.device
|
||||
# Build mask: [B, T, 1, 1]
|
||||
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [B, T]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = t - sequence_lengths[:, None] # [B, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = rearrange(mask, "b t -> b t 1 1")
|
||||
eps = 1e-6
|
||||
# Compute masked mean: [B, 1, 1, L]
|
||||
masked = encoded_text.masked_fill(~mask, 0.0)
|
||||
denom = (sequence_lengths * d).view(b, 1, 1, 1)
|
||||
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
|
||||
# Compute masked min/max: [B, 1, 1, L]
|
||||
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
range_ = x_max - x_min
|
||||
# Normalize only the valid tokens
|
||||
normed = 8 * (encoded_text - mean) / (range_ + eps)
|
||||
# concat to be [Batch, T, D * L] - this preserves the original structure
|
||||
normed = normed.reshape(b, t, -1) # [B, T, D * L]
|
||||
# Apply mask to preserve original padding (set padded positions to 0)
|
||||
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
|
||||
normed = normed.masked_fill(~mask_flattened, 0.0)
|
||||
|
||||
return normed
|
||||
|
||||
|
||||
def _convert_to_additive_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
return (attention_mask - 1).to(dtype).reshape(
|
||||
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
|
||||
|
||||
def _to_binary_mask(encoded: torch.Tensor, encoded_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert connector output mask to binary mask and apply to encoded tensor."""
|
||||
binary_mask = (encoded_mask < 0.000001).to(torch.int64)
|
||||
binary_mask = binary_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
|
||||
encoded = encoded * binary_mask
|
||||
return encoded, binary_mask
|
||||
|
||||
|
||||
def norm_and_concat_per_token_rms(
|
||||
encoded_text: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Per-token RMSNorm normalization for V2 models.
|
||||
Args:
|
||||
encoded_text: [B, T, D, L]
|
||||
attention_mask: [B, T] binary mask
|
||||
Returns:
|
||||
[B, T, D*L] normalized tensor with padding zeroed out.
|
||||
"""
|
||||
B, T, D, L = encoded_text.shape # noqa: N806
|
||||
variance = torch.mean(encoded_text**2, dim=2, keepdim=True) # [B,T,1,L]
|
||||
normed = encoded_text * torch.rsqrt(variance + 1e-6)
|
||||
normed = normed.reshape(B, T, D * L)
|
||||
mask_3d = attention_mask.bool().unsqueeze(-1) # [B, T, 1]
|
||||
return torch.where(mask_3d, normed, torch.zeros_like(normed))
|
||||
|
||||
|
||||
def _rescale_norm(x: torch.Tensor, target_dim: int, source_dim: int) -> torch.Tensor:
|
||||
"""Rescale normalization: x * sqrt(target_dim / source_dim)."""
|
||||
return x * math.sqrt(target_dim / source_dim)
|
||||
|
||||
@@ -555,9 +555,6 @@ class PerChannelStatistics(nn.Module):
|
||||
super().__init__()
|
||||
self.register_buffer("std-of-means", torch.empty(latent_channels))
|
||||
self.register_buffer("mean-of-means", torch.empty(latent_channels))
|
||||
self.register_buffer("mean-of-stds", torch.empty(latent_channels))
|
||||
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(latent_channels))
|
||||
self.register_buffer("channel", torch.empty(latent_channels))
|
||||
|
||||
def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(
|
||||
@@ -1335,27 +1332,49 @@ class LTX2VideoEncoder(nn.Module):
|
||||
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||
latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
|
||||
encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
encoder_version: str = "ltx-2",
|
||||
):
|
||||
super().__init__()
|
||||
encoder_blocks = [['res_x', {
|
||||
'num_layers': 4
|
||||
}], ['compress_space_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 6
|
||||
}], ['compress_time_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 6
|
||||
}], ['compress_all_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 2
|
||||
}], ['compress_all_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 2
|
||||
}]]
|
||||
if encoder_version == "ltx-2":
|
||||
encoder_blocks = [['res_x', {
|
||||
'num_layers': 4
|
||||
}], ['compress_space_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 6
|
||||
}], ['compress_time_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 6
|
||||
}], ['compress_all_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 2
|
||||
}], ['compress_all_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 2
|
||||
}]]
|
||||
else:
|
||||
encoder_blocks = [["res_x", {
|
||||
"num_layers": 4
|
||||
}], ["compress_space_res", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 6
|
||||
}], ["compress_time_res", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 4
|
||||
}], ["compress_all_res", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 2
|
||||
}], ["compress_all_res", {
|
||||
"multiplier": 1
|
||||
}], ["res_x", {
|
||||
"num_layers": 2
|
||||
}]]
|
||||
self.patch_size = patch_size
|
||||
self.norm_layer = norm_layer
|
||||
self.latent_channels = out_channels
|
||||
@@ -1435,8 +1454,8 @@ class LTX2VideoEncoder(nn.Module):
|
||||
# Validate frame count
|
||||
frames_count = sample.shape[2]
|
||||
if ((frames_count - 1) % 8) != 0:
|
||||
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames "
|
||||
"(e.g., 1, 9, 17, ...). Please check your input.")
|
||||
frames_to_crop = (frames_count - 1) % 8
|
||||
sample = sample[:, :, :-frames_to_crop, ...]
|
||||
|
||||
# Initial spatial compression: trade spatial resolution for channel depth
|
||||
# This reduces H,W by patch_size and increases channels, making convolutions more efficient
|
||||
@@ -1712,17 +1731,21 @@ def _make_decoder_block(
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
out_channels = in_channels // block_config.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
stride=(2, 1, 1),
|
||||
out_channels_reduction_factor=block_config.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
out_channels = in_channels // block_config.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
stride=(1, 2, 2),
|
||||
out_channels_reduction_factor=block_config.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
@@ -1782,6 +1805,8 @@ class LTX2VideoDecoder(nn.Module):
|
||||
causal: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||
decoder_version: str = "ltx-2",
|
||||
base_channels: int = 128,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -1790,28 +1815,49 @@ class LTX2VideoDecoder(nn.Module):
|
||||
# video inputs by a factor of 8 in the temporal dimension and 32 in
|
||||
# each spatial dimension (height and width). This parameter determines how
|
||||
# many video frames and pixels correspond to a single latent cell.
|
||||
decoder_blocks = [['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}]]
|
||||
if decoder_version == "ltx-2":
|
||||
decoder_blocks = [['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}]]
|
||||
else:
|
||||
decoder_blocks = [["res_x", {
|
||||
"num_layers": 4
|
||||
}], ["compress_space", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 6
|
||||
}], ["compress_time", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 4
|
||||
}], ["compress_all", {
|
||||
"multiplier": 1
|
||||
}], ["res_x", {
|
||||
"num_layers": 2
|
||||
}], ["compress_all", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 2
|
||||
}]]
|
||||
self.video_downscale_factors = SpatioTemporalScaleFactors(
|
||||
time=8,
|
||||
width=32,
|
||||
@@ -1833,13 +1879,14 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
# Compute initial feature_channels by going through blocks in reverse
|
||||
# This determines the channel width at the start of the decoder
|
||||
feature_channels = in_channels
|
||||
for block_name, block_params in list(reversed(decoder_blocks)):
|
||||
block_config = block_params if isinstance(block_params, dict) else {}
|
||||
if block_name == "res_x_y":
|
||||
feature_channels = feature_channels * block_config.get("multiplier", 2)
|
||||
if block_name == "compress_all":
|
||||
feature_channels = feature_channels * block_config.get("multiplier", 1)
|
||||
# feature_channels = in_channels
|
||||
# for block_name, block_params in list(reversed(decoder_blocks)):
|
||||
# block_config = block_params if isinstance(block_params, dict) else {}
|
||||
# if block_name == "res_x_y":
|
||||
# feature_channels = feature_channels * block_config.get("multiplier", 2)
|
||||
# if block_name == "compress_all":
|
||||
# feature_channels = feature_channels * block_config.get("multiplier", 1)
|
||||
feature_channels = base_channels * 8
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims=convolution_dimensions,
|
||||
|
||||
@@ -92,13 +92,13 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
|
||||
pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
|
||||
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
|
||||
pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
|
||||
|
||||
# Stage 2
|
||||
if stage2_lora_config is not None:
|
||||
stage2_lora_config.download_if_necessary()
|
||||
pipe.stage2_lora_path = stage2_lora_config.path
|
||||
# Optional, currently not used
|
||||
pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
@@ -168,8 +168,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
# Shape
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 768,
|
||||
num_frames=121,
|
||||
frame_rate=24,
|
||||
num_frames: Optional[int] = 121,
|
||||
frame_rate: Optional[int] = 24,
|
||||
# Classifier-free guidance
|
||||
cfg_scale: Optional[float] = 3.0,
|
||||
# Scheduler
|
||||
@@ -238,7 +238,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
|
||||
return video, decoded_audio
|
||||
|
||||
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None, num_frames=121):
|
||||
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None):
|
||||
b, _, f, h, w = latents.shape
|
||||
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device)
|
||||
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
|
||||
@@ -306,121 +306,20 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
|
||||
output_params=("video_context", "audio_context"),
|
||||
onload_model_names=("text_encoder", "text_encoder_post_modules"),
|
||||
)
|
||||
|
||||
def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
return (attention_mask - 1).to(dtype).reshape(
|
||||
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
|
||||
|
||||
def _run_connectors(self, pipe, encoded_input: torch.Tensor,
|
||||
attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype)
|
||||
|
||||
encoded, encoded_connector_attention_mask = pipe.text_encoder_post_modules.embeddings_connector(
|
||||
encoded_input,
|
||||
connector_attention_mask,
|
||||
)
|
||||
|
||||
# restore the mask values to int64
|
||||
attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64)
|
||||
attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
|
||||
encoded = encoded * attention_mask
|
||||
|
||||
encoded_for_audio, _ = pipe.text_encoder_post_modules.audio_embeddings_connector(
|
||||
encoded_input, connector_attention_mask)
|
||||
|
||||
return encoded, encoded_for_audio, attention_mask.squeeze(-1)
|
||||
|
||||
def _norm_and_concat_padded_batch(
|
||||
self,
|
||||
encoded_text: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
padding_side: str = "right",
|
||||
) -> torch.Tensor:
|
||||
"""Normalize and flatten multi-layer hidden states, respecting padding.
|
||||
Performs per-batch, per-layer normalization using masked mean and range,
|
||||
then concatenates across the layer dimension.
|
||||
Args:
|
||||
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
|
||||
sequence_lengths: Number of valid (non-padded) tokens per batch item.
|
||||
padding_side: Whether padding is on "left" or "right".
|
||||
Returns:
|
||||
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
|
||||
with padded positions zeroed out.
|
||||
"""
|
||||
b, t, d, l = encoded_text.shape # noqa: E741
|
||||
device = encoded_text.device
|
||||
# Build mask: [B, T, 1, 1]
|
||||
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [B, T]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = t - sequence_lengths[:, None] # [B, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = rearrange(mask, "b t -> b t 1 1")
|
||||
eps = 1e-6
|
||||
# Compute masked mean: [B, 1, 1, L]
|
||||
masked = encoded_text.masked_fill(~mask, 0.0)
|
||||
denom = (sequence_lengths * d).view(b, 1, 1, 1)
|
||||
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
|
||||
# Compute masked min/max: [B, 1, 1, L]
|
||||
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
range_ = x_max - x_min
|
||||
# Normalize only the valid tokens
|
||||
normed = 8 * (encoded_text - mean) / (range_ + eps)
|
||||
# concat to be [Batch, T, D * L] - this preserves the original structure
|
||||
normed = normed.reshape(b, t, -1) # [B, T, D * L]
|
||||
# Apply mask to preserve original padding (set padded positions to 0)
|
||||
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
|
||||
normed = normed.masked_fill(~mask_flattened, 0.0)
|
||||
|
||||
return normed
|
||||
|
||||
def _run_feature_extractor(self,
|
||||
pipe,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_side: str = "right") -> torch.Tensor:
|
||||
encoded_text_features = torch.stack(hidden_states, dim=-1)
|
||||
encoded_text_features_dtype = encoded_text_features.dtype
|
||||
sequence_lengths = attention_mask.sum(dim=-1)
|
||||
normed_concated_encoded_text_features = self._norm_and_concat_padded_batch(encoded_text_features,
|
||||
sequence_lengths,
|
||||
padding_side=padding_side)
|
||||
|
||||
return pipe.text_encoder_post_modules.feature_extractor_linear(
|
||||
normed_concated_encoded_text_features.to(encoded_text_features_dtype))
|
||||
|
||||
def _preprocess_text(
|
||||
self,
|
||||
pipe,
|
||||
text: str,
|
||||
padding_side: str = "left",
|
||||
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Encode a given string into feature tensors suitable for downstream tasks.
|
||||
Args:
|
||||
text (str): Input string to encode.
|
||||
Returns:
|
||||
tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask.
|
||||
"""
|
||||
token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
|
||||
input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device)
|
||||
attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device)
|
||||
outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
||||
projected = self._run_feature_extractor(pipe,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
padding_side=padding_side)
|
||||
return projected, attention_mask
|
||||
|
||||
return outputs.hidden_states, attention_mask
|
||||
def encode_prompt(self, pipe, text, padding_side="left"):
|
||||
encoded_inputs, attention_mask = self._preprocess_text(pipe, text, padding_side)
|
||||
video_encoding, audio_encoding, attention_mask = self._run_connectors(pipe, encoded_inputs, attention_mask)
|
||||
hidden_states, attention_mask = self._preprocess_text(pipe, text)
|
||||
video_encoding, audio_encoding, attention_mask = pipe.text_encoder_post_modules.process_hidden_states(
|
||||
hidden_states, attention_mask, padding_side)
|
||||
return video_encoding, audio_encoding, attention_mask
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
|
||||
@@ -539,7 +438,7 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||
self.get_image_latent(pipe, img, stage1_height, stage1_width, tiled, tile_size_in_pixels,
|
||||
tile_overlap_in_pixels) for img in input_images
|
||||
]
|
||||
video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength, num_frames=num_frames)
|
||||
video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength)
|
||||
output_dicts.update({"video_latents": video_latents, "denoise_mask_video": denoise_mask_video, "input_latents_video": initial_latents})
|
||||
if use_two_stage_pipeline:
|
||||
stage2_latents = [
|
||||
@@ -649,6 +548,7 @@ def model_fn_ltx2(
|
||||
audio_positions=audio_positions,
|
||||
audio_context=audio_context,
|
||||
audio_timesteps=audio_timesteps,
|
||||
sigma=timestep,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
|
||||
@@ -27,6 +27,7 @@ def LTX2VocoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith("vocoder."):
|
||||
new_name = name.replace("vocoder.", "")
|
||||
# new_name = name.replace("vocoder.", "")
|
||||
new_name = name[len("vocoder."):]
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
@@ -6,7 +6,8 @@ def LTX2VideoEncoderStateDictConverter(state_dict):
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
elif name.startswith("vae.per_channel_statistics."):
|
||||
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
if new_name not in ["per_channel_statistics.channel", "per_channel_statistics.mean-of-stds", "per_channel_statistics.mean-of-stds_over_std-of-means"]:
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
|
||||
@@ -18,5 +19,6 @@ def LTX2VideoDecoderStateDictConverter(state_dict):
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
elif name.startswith("vae.per_channel_statistics."):
|
||||
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
if new_name not in ["per_channel_statistics.channel", "per_channel_statistics.mean-of-stds", "per_channel_statistics.mean-of-stds_over_std-of-means"]:
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
@@ -30,7 +30,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
)
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
height, width, num_frames = 512, 768, 121
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
)
|
||||
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
|
||||
)
|
||||
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
|
||||
# first frame
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_distilled_pipeline=True,
|
||||
input_images=[image],
|
||||
input_images_indexes=[0],
|
||||
input_images_strength=1.0,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2.3_distilled_i2av_first.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,
|
||||
)
|
||||
56
examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py
Normal file
56
examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
)
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
height, width, num_frames = 512, 768, 121
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
|
||||
)
|
||||
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
|
||||
# first frame
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=False,
|
||||
input_images=[image],
|
||||
input_images_indexes=[0],
|
||||
input_images_strength=1.0,
|
||||
num_inference_steps=40,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2.3_onestage_i2av_first.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,
|
||||
)
|
||||
71
examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py
Normal file
71
examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"),
|
||||
)
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
|
||||
)
|
||||
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
|
||||
# first frame
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=42,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_two_stage_pipeline=True,
|
||||
num_inference_steps=40,
|
||||
input_images=[image],
|
||||
input_images_indexes=[0],
|
||||
input_images_strength=1.0,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2.3_twostage_i2av_first.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,
|
||||
)
|
||||
@@ -0,0 +1,57 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
)
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_distilled_pipeline=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2.3_distilled.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,
|
||||
)
|
||||
42
examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py
Normal file
42
examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
)
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
height, width, num_frames = 512, 768, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2.3_onestage.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,
|
||||
)
|
||||
57
examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py
Normal file
57
examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"),
|
||||
)
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_two_stage_pipeline=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2.3_twostage.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,
|
||||
)
|
||||
@@ -31,7 +31,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
)
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
height, width, num_frames = 512, 768, 121
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
|
||||
Reference in New Issue
Block a user