support ltx2.3 inference (#1332)

This commit is contained in:
Zhongjie Duan
2026-03-06 16:24:53 +08:00
committed by GitHub
17 changed files with 1608 additions and 351 deletions

View File

@@ -718,6 +718,66 @@ ltx2_series = [
"model_name": "ltx2_latent_upsampler",
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
"model_name": "ltx2_dit",
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
"extra_kwargs": {"apply_gated_attention": True, "cross_attention_adaln": True, "caption_channels": None},
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
"model_name": "ltx2_video_vae_encoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
"extra_kwargs": {"encoder_version": "ltx-2.3"},
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
"model_name": "ltx2_video_vae_decoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
"extra_kwargs": {"decoder_version": "ltx-2.3"},
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
"model_name": "ltx2_audio_vae_decoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
"model_name": "ltx2_audio_vocoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2VocoderWithBWE",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
"model_name": "ltx2_audio_vae_encoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
"model_name": "ltx2_text_encoder_post_modules",
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
"extra_kwargs": {"seperated_audio_video": True, "embedding_dim_gemma": 3840, "num_layers_gemma": 49, "video_attetion_heads": 32, "video_attention_head_dim": 128, "audio_attention_heads": 32, "audio_attention_head_dim": 64, "num_connetor_layers": 8, "apply_gated_attention": True},
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
"model_hash": "aed408774d694a2452f69936c32febb5",
"model_name": "ltx2_latent_upsampler",
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
"extra_kwargs": {"rational_resampler": False},
},
]
anima_series = [
{

View File

@@ -1279,9 +1279,268 @@ class LTX2AudioDecoder(torch.nn.Module):
return torch.tanh(h) if self.tanh_out else h
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return the 'same' padding for a dilated 1-D convolution."""
    effective_span = dilation * (kernel_size - 1)
    return effective_span // 2
# ---------------------------------------------------------------------------
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
# Adopted from https://github.com/NVIDIA/BigVGAN
# ---------------------------------------------------------------------------
def _sinc(x: torch.Tensor) -> torch.Tensor:
return torch.where(
x == 0,
torch.tensor(1.0, device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x,
)
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
    """Design a low-pass FIR filter: Kaiser window applied to an ideal sinc.

    Args:
        cutoff: Normalized cutoff frequency (Nyquist = 0.5).
        half_width: Transition-band half width (normalized frequency).
        kernel_size: Number of filter taps.

    Returns:
        Filter taps of shape (1, 1, kernel_size), normalized to unit DC gain
        (all zeros when cutoff == 0).
    """
    is_even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # Kaiser beta from the classic stop-band attenuation estimate.
    delta_f = 4 * half_width
    attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if attenuation > 50.0:
        beta = 0.1102 * (attenuation - 8.7)
    elif attenuation >= 21.0:
        beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # Tap positions centered on zero (half-sample offset for even sizes).
    if is_even:
        taps = torch.arange(-half_size, half_size) + 0.5
    else:
        taps = torch.arange(kernel_size) - half_size

    if cutoff == 0:
        filter_ = torch.zeros_like(taps)
    else:
        filter_ = 2 * cutoff * window * _sinc(2 * cutoff * taps)
        filter_ = filter_ / filter_.sum()  # unit DC gain
    return filter_.view(1, 1, kernel_size)
class LowPassFilter1d(nn.Module):
    """Strided 1-D low-pass filtering with a Kaiser-windowed sinc kernel.

    The same kernel is applied to every channel via a grouped convolution.
    """

    def __init__(
        self,
        cutoff: float = 0.5,
        half_width: float = 0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ) -> None:
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        # Asymmetric padding keeps even-sized kernels aligned with the input grid.
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, n_channels, _ = x.shape
        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        # Share one kernel across channels (depthwise conv).
        kernel = self.filter.expand(n_channels, -1, -1)
        return F.conv1d(x, kernel, stride=self.stride, groups=n_channels)
class UpSample1d(nn.Module):
def __init__(
self,
ratio: int = 2,
kernel_size: int | None = None,
persistent: bool = True,
window_type: str = "kaiser",
) -> None:
super().__init__()
self.ratio = ratio
self.stride = ratio
if window_type == "hann":
# Hann-windowed sinc filter equivalent to torchaudio.functional.resample
rolloff = 0.99
lowpass_filter_width = 6
width = math.ceil(lowpass_filter_width / rolloff)
self.kernel_size = 2 * width * ratio + 1
self.pad = width
self.pad_left = 2 * width * ratio
self.pad_right = self.kernel_size - ratio
time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
else:
# Kaiser-windowed sinc filter (BigVGAN default).
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
sinc_filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,
kernel_size=self.kernel_size,
)
self.register_buffer("filter", sinc_filter, persistent=persistent)
def forward(self, x: torch.Tensor) -> torch.Tensor:
_, n_channels, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode="replicate")
filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1)
x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels)
return x[..., self.pad_left : -self.pad_right]
class DownSample1d(nn.Module):
    """Anti-aliased 1-D downsampling: low-pass filter, then decimate by `ratio`."""

    def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None:
        super().__init__()
        self.ratio = ratio
        self.kernel_size = kernel_size if kernel_size is not None else int(6 * ratio // 2) * 2
        # A strided low-pass performs filtering and decimation in a single pass.
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lowpass(x)
class Activation1d(nn.Module):
    """Alias-free activation: upsample, apply the nonlinearity, downsample.

    Running the pointwise nonlinearity at a higher sample rate suppresses the
    aliasing it would otherwise introduce (BigVGAN v2).
    """

    def __init__(
        self,
        activation: nn.Module,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ) -> None:
        super().__init__()
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.downsample(self.act(self.upsample(x)))
class Snake(nn.Module):
    """Snake activation: x + (1/alpha) * sin^2(alpha * x), with per-channel alpha."""

    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = True,
    ) -> None:
        super().__init__()
        self.alpha_logscale = alpha_logscale
        # Log scale: zeros give exp(0) = 1 at init; linear scale: store alpha directly.
        if alpha_logscale:
            init = torch.zeros(in_features)
        else:
            init = torch.ones(in_features) * alpha
        self.alpha = nn.Parameter(init)
        self.alpha.requires_grad = alpha_trainable
        self.eps = 1e-9

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcast the per-channel alpha over (B, C, T).
        alpha = self.alpha.view(1, -1, 1)
        if self.alpha_logscale:
            alpha = alpha.exp()
        # eps guards against division by zero as alpha -> 0.
        return x + torch.sin(alpha * x).square() / (alpha + self.eps)
class SnakeBeta(nn.Module):
    """SnakeBeta activation: x + (1/beta) * sin^2(alpha * x).

    Like Snake, but the magnitude of the periodic term (beta) is learned
    separately from its frequency (alpha). Both are per-channel.
    """

    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = True,
    ) -> None:
        super().__init__()
        self.alpha_logscale = alpha_logscale

        def _init() -> torch.Tensor:
            # Log scale: zeros give exp(0) = 1 at init; linear scale: the raw alpha value.
            return torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha

        self.alpha = nn.Parameter(_init())
        self.alpha.requires_grad = alpha_trainable
        self.beta = nn.Parameter(_init())
        self.beta.requires_grad = alpha_trainable
        self.eps = 1e-9

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        alpha = self.alpha.view(1, -1, 1)
        beta = self.beta.view(1, -1, 1)
        if self.alpha_logscale:
            alpha = alpha.exp()
            beta = beta.exp()
        # eps guards against division by zero as beta -> 0.
        return x + torch.sin(alpha * x).square() / (beta + self.eps)
class AMPBlock1(nn.Module):
    """BigVGAN AMP residual block: alias-free activations around dilated convolutions.

    Each of the three stages is act -> dilated conv -> act -> undilated conv,
    added back to the input as a residual.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple[int, int, int] = (1, 3, 5),
        activation: str = "snake",
    ) -> None:
        super().__init__()
        act_cls = SnakeBeta if activation == "snakebeta" else Snake
        # First conv stack uses the increasing dilations; second stack is undilated.
        self.convs1 = nn.ModuleList(
            nn.Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=d,
                padding=get_padding(kernel_size, d),
            )
            for d in dilation
        )
        self.convs2 = nn.ModuleList(
            nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
            for _ in dilation
        )
        self.acts1 = nn.ModuleList(Activation1d(act_cls(channels)) for _ in dilation)
        self.acts2 = nn.ModuleList(Activation1d(act_cls(channels)) for _ in dilation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv_a, conv_b, act_a, act_b in zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True):
            inner = act_b(conv_a(act_a(x)))
            x = x + conv_b(inner)
        return x
class LTX2Vocoder(torch.nn.Module):
"""
Vocoder model for synthesizing audio from Mel spectrograms.
LTX2Vocoder model for synthesizing audio from Mel spectrograms.
Args:
resblock_kernel_sizes: List of kernel sizes for the residual blocks.
This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`.
@@ -1293,28 +1552,33 @@ class LTX2Vocoder(torch.nn.Module):
This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`.
upsample_initial_channel: Initial number of channels for the upsampling layers.
This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`.
stereo: Whether to use stereo output.
This value is read from the checkpoint at `config.vocoder.stereo`.
resblock: Type of residual block to use.
resblock: Type of residual block to use ("1", "2", or "AMP1").
This value is read from the checkpoint at `config.vocoder.resblock`.
output_sample_rate: Waveform sample rate.
This value is read from the checkpoint at `config.vocoder.output_sample_rate`.
output_sampling_rate: Waveform sample rate.
This value is read from the checkpoint at `config.vocoder.output_sampling_rate`.
activation: Activation type for BigVGAN v2 ("snake" or "snakebeta"). Only used when resblock="AMP1".
use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True).
apply_final_activation: Whether to apply the final tanh/clamp activation.
use_bias_at_final: Whether to use bias in the final conv layer.
"""
def __init__(
def __init__( # noqa: PLR0913
self,
resblock_kernel_sizes: List[int] | None = [3, 7, 11],
upsample_rates: List[int] | None = [6, 5, 2, 2, 2],
upsample_kernel_sizes: List[int] | None = [16, 15, 8, 4, 4],
resblock_dilation_sizes: List[List[int]] | None = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
upsample_initial_channel: int = 1024,
stereo: bool = True,
resblock: str = "1",
output_sample_rate: int = 24000,
):
output_sampling_rate: int = 24000,
activation: str = "snake",
use_tanh_at_final: bool = True,
apply_final_activation: bool = True,
use_bias_at_final: bool = True,
) -> None:
super().__init__()
# Initialize default values if not provided. Note that mutable default values are not supported.
# Mutable default values are not supported as default arguments.
if resblock_kernel_sizes is None:
resblock_kernel_sizes = [3, 7, 11]
if upsample_rates is None:
@@ -1324,36 +1588,60 @@ class LTX2Vocoder(torch.nn.Module):
if resblock_dilation_sizes is None:
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
self.output_sample_rate = output_sample_rate
self.output_sampling_rate = output_sampling_rate
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
in_channels = 128 if stereo else 64
self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
self.use_tanh_at_final = use_tanh_at_final
self.apply_final_activation = apply_final_activation
self.is_amp = resblock == "AMP1"
self.ups = nn.ModuleList()
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True)):
self.ups.append(
nn.ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
kernel_size,
stride,
padding=(kernel_size - stride) // 2,
)
# All production checkpoints are stereo: 128 input channels (2 stereo channels x 64 mel
# bins each), 2 output channels.
self.conv_pre = nn.Conv1d(
in_channels=128,
out_channels=upsample_initial_channel,
kernel_size=7,
stride=1,
padding=3,
)
resblock_cls = ResBlock1 if resblock == "1" else AMPBlock1
self.ups = nn.ModuleList(
nn.ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
kernel_size,
stride,
padding=(kernel_size - stride) // 2,
)
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True))
)
final_channels = upsample_initial_channel // (2 ** len(upsample_rates))
self.resblocks = nn.ModuleList()
for i, _ in enumerate(self.ups):
for i in range(len(upsample_rates)):
ch = upsample_initial_channel // (2 ** (i + 1))
for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):
self.resblocks.append(resblock_class(ch, kernel_size, dilations))
if self.is_amp:
self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation))
else:
self.resblocks.append(resblock_cls(ch, kernel_size, dilations))
out_channels = 2 if stereo else 1
final_channels = upsample_initial_channel // (2**self.num_upsamples)
self.conv_post = nn.Conv1d(final_channels, out_channels, 7, 1, padding=3)
if self.is_amp:
self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels))
else:
self.act_post = nn.LeakyReLU()
self.upsample_factor = math.prod(layer.stride[0] for layer in self.ups)
# All production checkpoints are stereo: this final conv maps `final_channels` to 2 output channels (stereo).
self.conv_post = nn.Conv1d(
in_channels=final_channels,
out_channels=2,
kernel_size=7,
stride=1,
padding=3,
bias=use_bias_at_final,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
@@ -1374,7 +1662,8 @@ class LTX2Vocoder(torch.nn.Module):
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
if not self.is_amp:
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
start = i * self.num_kernels
end = start + self.num_kernels
@@ -1386,23 +1675,198 @@ class LTX2Vocoder(torch.nn.Module):
[self.resblocks[idx](x) for idx in range(start, end)],
dim=0,
)
x = block_outputs.mean(dim=0)
x = self.conv_post(F.leaky_relu(x))
return torch.tanh(x)
x = self.act_post(x)
x = self.conv_post(x)
if self.apply_final_activation:
x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1)
return x
def decode_audio(latent: torch.Tensor, audio_decoder: "LTX2AudioDecoder", vocoder: "LTX2Vocoder") -> torch.Tensor:
class _STFTFn(nn.Module):
"""Implements STFT as a convolution with precomputed DFT x Hann-window bases.
The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
Hann window are stored as buffers and loaded from the checkpoint. Using the exact
bfloat16 bases from training ensures the mel values fed to the BWE generator are
bit-identical to what it was trained on.
"""
Decode an audio latent representation using the provided audio decoder and vocoder.
Args:
latent: Input audio latent tensor.
audio_decoder: Model to decode the latent to waveform features.
vocoder: Model to convert decoded features to audio waveform.
Returns:
Decoded audio as a float tensor.
def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
    """Allocate zero-initialized STFT bases; real values come from the checkpoint.

    Args:
        filter_length: DFT size; yields filter_length // 2 + 1 frequency bins.
        hop_length: Frame step in samples (used as the conv stride).
        win_length: Analysis window length in samples.
    """
    super().__init__()
    self.hop_length = hop_length
    self.win_length = win_length
    n_freqs = filter_length // 2 + 1
    # Each buffer stacks real and imaginary DFT rows along dim 0 (2 * n_freqs rows).
    # Zeros are placeholders; load_state_dict overwrites them with checkpoint values.
    self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
    # inverse_basis is kept for state-dict compatibility; forward() does not read it.
    self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute magnitude and phase spectrogram from a batch of waveforms.

    Applies causal (left-only) padding of win_length - hop_length samples so that
    each output frame depends only on past and present input — no lookahead.

    Args:
        y: Waveform tensor of shape (B, T).

    Returns:
        magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
        phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
    """
    if y.dim() == 2:
        y = y.unsqueeze(1)  # (B, 1, T)
    left_pad = max(0, self.win_length - self.hop_length)  # causal: left-only
    y = F.pad(y, (left_pad, 0))
    # STFT as a strided conv against the precomputed DFT-times-window rows.
    spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
    # First half of the output channels are real parts, second half imaginary.
    n_freqs = spec.shape[1] // 2
    real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
    magnitude = torch.sqrt(real**2 + imag**2)
    # atan2 in float32 for accuracy, then cast back to the working dtype.
    phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
    return magnitude, phase
class MelSTFT(nn.Module):
"""Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
waveform and projecting the linear magnitude spectrum onto the mel filterbank.
The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
(mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
"""
decoded_audio = audio_decoder(latent)
decoded_audio = vocoder(decoded_audio).squeeze(0).float()
return decoded_audio
def __init__(
    self,
    filter_length: int,
    hop_length: int,
    win_length: int,
    n_mel_channels: int,
) -> None:
    """Build the causal STFT submodule and the checkpoint-loaded mel filterbank.

    Args:
        filter_length: DFT size used by the STFT.
        hop_length: Frame step in samples.
        win_length: Analysis window length in samples.
        n_mel_channels: Number of mel bands in the filterbank.
    """
    super().__init__()
    self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
    # Initialized to zeros; load_state_dict overwrites with the checkpoint's
    # exact bfloat16 filterbank (vocoder.mel_stft.mel_basis, shape [n_mels, n_freqs]).
    n_freqs = filter_length // 2 + 1
    self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute log-mel spectrogram and auxiliary spectral quantities.

    Args:
        y: Waveform tensor of shape (B, T).

    Returns:
        log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
        magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
        phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
        energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
    """
    magnitude, phase = self.stft_fn(y)
    energy = torch.norm(magnitude, dim=1)
    # Cast the filterbank to the magnitude dtype before the mel projection.
    mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
    # Clamp before log so silent frames don't produce -inf.
    log_mel = torch.log(torch.clamp(mel, min=1e-5))
    return log_mel, magnitude, phase, energy
class LTX2VocoderWithBWE(nn.Module):
    """LTX2Vocoder with bandwidth extension (BWE) upsampling.

    Chains a mel-to-wav vocoder with a BWE module that upsamples the output
    to a higher sample rate. The BWE computes a mel spectrogram from the
    vocoder output, runs it through a second generator to predict a residual,
    and adds it to a sinc-resampled skip connection.
    """

    def __init__(
        self,
        input_sampling_rate: int = 16000,
        output_sampling_rate: int = 48000,
        hop_length: int = 80,
    ) -> None:
        """Build the base vocoder, the BWE generator, and the resampling path.

        Args:
            input_sampling_rate: Sample rate (Hz) of the base vocoder output.
            output_sampling_rate: Final sample rate (Hz) after BWE. Assumed to be
                an integer multiple of input_sampling_rate (the resampler ratio
                uses integer division) — TODO confirm enforcement upstream.
            hop_length: STFT hop used when re-analyzing the vocoder output.
        """
        super().__init__()
        # Base mel-to-waveform vocoder operating at the low (input) sample rate.
        # Hyperparameters mirror the checkpoint's config and must not be changed.
        self.vocoder = LTX2Vocoder(
            resblock_kernel_sizes=[3, 7, 11],
            upsample_rates=[5, 2, 2, 2, 2, 2],
            upsample_kernel_sizes=[11, 4, 4, 4, 4, 4],
            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            upsample_initial_channel=1536,
            resblock="AMP1",
            activation="snakebeta",
            use_tanh_at_final=False,
            apply_final_activation=True,
            use_bias_at_final=False,
            output_sampling_rate=input_sampling_rate,
        )
        # Second generator that predicts the high-band residual from the
        # re-analyzed mel of the vocoder output (no final activation: raw residual).
        self.bwe_generator = LTX2Vocoder(
            resblock_kernel_sizes=[3, 7, 11],
            upsample_rates=[6, 5, 2, 2, 2],
            upsample_kernel_sizes=[12, 11, 4, 4, 4],
            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            upsample_initial_channel=512,
            resblock="AMP1",
            activation="snakebeta",
            use_tanh_at_final=False,
            apply_final_activation=False,
            use_bias_at_final=False,
            output_sampling_rate=output_sampling_rate,
        )
        # Causal mel analysis whose STFT bases/filterbank load from the checkpoint.
        self.mel_stft = MelSTFT(
            filter_length=512,
            hop_length=hop_length,
            win_length=512,
            n_mel_channels=64,
        )
        self.input_sampling_rate = input_sampling_rate
        self.output_sampling_rate = output_sampling_rate
        self.hop_length = hop_length
        # Compute the resampler on CPU so the sinc filter is materialized even when
        # the model is constructed on meta device (SingleGPUModelBuilder pattern).
        # The filter is not stored in the checkpoint (persistent=False).
        with torch.device("cpu"):
            self.resampler = UpSample1d(
                ratio=output_sampling_rate // input_sampling_rate, persistent=False, window_type="hann"
            )

    @property
    def conv_pre(self) -> nn.Conv1d:
        # NOTE(review): delegates to the inner vocoder — presumably so callers can
        # treat this module like a plain LTX2Vocoder; confirm against call sites.
        return self.vocoder.conv_pre

    @property
    def conv_post(self) -> nn.Conv1d:
        # Same delegation as conv_pre, for the final convolution.
        return self.vocoder.conv_post

    def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor:
        """Compute log-mel spectrogram from waveform using causal STFT bases.

        Args:
            audio: Waveform tensor of shape (B, C, T).

        Returns:
            mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames).
        """
        batch, n_channels, _ = audio.shape
        # Fold channels into the batch so MelSTFT sees (B*C, T).
        flat = audio.reshape(batch * n_channels, -1)  # (B*C, T)
        mel, _, _, _ = self.mel_stft.mel_spectrogram(flat)  # (B*C, n_mels, T_frames)
        return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2])  # (B, C, n_mels, T_frames)

    def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
        """Run the full vocoder + BWE forward pass.

        Args:
            mel_spec: Mel spectrogram of shape (B, 2, T, mel_bins) for stereo
                or (B, T, mel_bins) for mono. Same format as LTX2Vocoder.forward.

        Returns:
            Waveform tensor of shape (B, out_channels, T_out) clipped to [-1, 1].
        """
        x = self.vocoder(mel_spec)
        _, _, length_low_rate = x.shape
        # Target length at the high rate, measured before any padding below.
        output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate
        # Pad to multiple of hop_length for exact mel frame count
        remainder = length_low_rate % self.hop_length
        if remainder != 0:
            x = F.pad(x, (0, self.hop_length - remainder))
        # Compute mel spectrogram from vocoder output: (B, C, n_mels, T_frames)
        mel = self._compute_mel(x)
        # LTX2Vocoder.forward expects (B, C, T, mel_bins) — transpose before calling bwe_generator
        mel_for_bwe = mel.transpose(2, 3)  # (B, C, T_frames, mel_bins)
        residual = self.bwe_generator(mel_for_bwe)
        # Skip connection: sinc-resample the low-rate waveform up to the output rate.
        skip = self.resampler(x)
        assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
        # Trim the hop-padding tail back to the exact target length.
        return torch.clamp(residual + skip, -1, 1)[..., :output_length]

View File

@@ -251,11 +251,27 @@ class Modality:
Input data for a single modality (video or audio) in the transformer.
Bundles the latent tokens, timestep embeddings, positional information,
and text conditioning context for processing by the diffusion transformer.
Attributes:
latent: Patchified latent tokens, shape ``(B, T, D)`` where *B* is
the batch size, *T* is the total number of tokens (noisy +
conditioning), and *D* is the input dimension.
timesteps: Per-token timestep embeddings, shape ``(B, T)``.
positions: Positional coordinates, shape ``(B, 3, T)`` for video
(time, height, width) or ``(B, 1, T)`` for audio.
context: Text conditioning embeddings from the prompt encoder.
enabled: Whether this modality is active in the current forward pass.
context_mask: Optional mask for the text context tokens.
attention_mask: Optional 2-D self-attention mask, shape ``(B, T, T)``.
Values in ``[0, 1]`` where ``1`` = full attention and ``0`` = no
attention. ``None`` means unrestricted (full) attention between
all tokens. Built incrementally by conditioning items; see
:class:`~ltx_core.conditioning.types.attention_strength_wrapper.ConditioningItemAttentionStrengthWrapper`.
"""
latent: (
torch.Tensor
) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
sigma: torch.Tensor # Shape: (B,). Current sigma value, used for cross-attention timestep calculation.
timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps
positions: (
torch.Tensor
@@ -263,6 +279,7 @@ class Modality:
context: torch.Tensor
enabled: bool = True
context_mask: torch.Tensor | None = None
attention_mask: torch.Tensor | None = None
def to_denoised(

View File

@@ -225,6 +225,17 @@ class BatchedPerturbationConfig:
return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)])
# Base AdaLN modulation parameters per block (scale/shift/gate for two norms).
ADALN_NUM_BASE_PARAMS = 6
# Cross-attention AdaLN adds 3 more (scale, shift, gate) for the CA norm.
ADALN_NUM_CROSS_ATTN_PARAMS = 3


def adaln_embedding_coefficient(cross_attention_adaln: bool) -> int:
    """Total number of AdaLN parameters per block."""
    if cross_attention_adaln:
        return ADALN_NUM_BASE_PARAMS + ADALN_NUM_CROSS_ATTN_PARAMS
    return ADALN_NUM_BASE_PARAMS
class AdaLayerNormSingle(torch.nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
@@ -460,6 +471,7 @@ class Attention(torch.nn.Module):
dim_head: int = 64,
norm_eps: float = 1e-6,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
apply_gated_attention: bool = False,
) -> None:
super().__init__()
self.rope_type = rope_type
@@ -477,6 +489,12 @@ class Attention(torch.nn.Module):
self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True)
self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True)
# Optional per-head gating
if apply_gated_attention:
self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
else:
self.to_gate_logits = None
self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity())
def forward(
@@ -486,6 +504,8 @@ class Attention(torch.nn.Module):
mask: torch.Tensor | None = None,
pe: torch.Tensor | None = None,
k_pe: torch.Tensor | None = None,
perturbation_mask: torch.Tensor | None = None,
all_perturbed: bool = False,
) -> torch.Tensor:
q = self.to_q(x)
context = x if context is None else context
@@ -517,6 +537,19 @@ class Attention(torch.nn.Module):
# Reshape back to original format
out = out.flatten(2, 3)
# Apply per-head gating if enabled
if self.to_gate_logits is not None:
gate_logits = self.to_gate_logits(x) # (B, T, H)
b, t, _ = out.shape
# Reshape to (B, T, H, D) for per-head gating
out = out.view(b, t, self.heads, self.dim_head)
# Apply gating: 2 * sigmoid(x) so that zero-init gives identity (2 * 0.5 = 1.0)
gates = 2.0 * torch.sigmoid(gate_logits) # (B, T, H)
out = out * gates.unsqueeze(-1) # (B, T, H, D) * (B, T, H, 1)
# Reshape back to (B, T, H*D)
out = out.view(b, t, self.heads * self.dim_head)
return self.to_out(out)
@@ -545,7 +578,6 @@ class PixArtAlphaTextProjection(torch.nn.Module):
hidden_states = self.linear_2(hidden_states)
return hidden_states
@dataclass(frozen=True)
class TransformerArgs:
x: torch.Tensor
@@ -558,7 +590,10 @@ class TransformerArgs:
cross_scale_shift_timestep: torch.Tensor | None
cross_gate_timestep: torch.Tensor | None
enabled: bool
prompt_timestep: torch.Tensor | None = None
self_attention_mask: torch.Tensor | None = (
None # Additive log-space self-attention bias (B, 1, T, T), None = full attention
)
class TransformerArgsPreprocessor:
@@ -566,7 +601,6 @@ class TransformerArgsPreprocessor:
self,
patchify_proj: torch.nn.Linear,
adaln: AdaLayerNormSingle,
caption_projection: PixArtAlphaTextProjection,
inner_dim: int,
max_pos: list[int],
num_attention_heads: int,
@@ -575,10 +609,11 @@ class TransformerArgsPreprocessor:
double_precision_rope: bool,
positional_embedding_theta: float,
rope_type: LTXRopeType,
caption_projection: torch.nn.Module | None = None,
prompt_adaln: AdaLayerNormSingle | None = None,
) -> None:
self.patchify_proj = patchify_proj
self.adaln = adaln
self.caption_projection = caption_projection
self.inner_dim = inner_dim
self.max_pos = max_pos
self.num_attention_heads = num_attention_heads
@@ -587,18 +622,18 @@ class TransformerArgsPreprocessor:
self.double_precision_rope = double_precision_rope
self.positional_embedding_theta = positional_embedding_theta
self.rope_type = rope_type
self.caption_projection = caption_projection
self.prompt_adaln = prompt_adaln
def _prepare_timestep(
self, timestep: torch.Tensor, batch_size: int, hidden_dtype: torch.dtype
self, timestep: torch.Tensor, adaln: AdaLayerNormSingle, batch_size: int, hidden_dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare timestep embeddings."""
timestep = timestep * self.timestep_scale_multiplier
timestep, embedded_timestep = self.adaln(
timestep.flatten(),
timestep_scaled = timestep * self.timestep_scale_multiplier
timestep, embedded_timestep = adaln(
timestep_scaled.flatten(),
hidden_dtype=hidden_dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
@@ -608,14 +643,12 @@ class TransformerArgsPreprocessor:
self,
context: torch.Tensor,
x: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
) -> torch.Tensor:
"""Prepare context for transformer blocks."""
if self.caption_projection is not None:
context = self.caption_projection(context)
batch_size = x.shape[0]
context = self.caption_projection(context)
context = context.view(batch_size, -1, x.shape[-1])
return context, attention_mask
return context.view(batch_size, -1, x.shape[-1])
def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None:
"""Prepare attention mask."""
@@ -626,6 +659,34 @@ class TransformerArgsPreprocessor:
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
) * torch.finfo(x_dtype).max
def _prepare_self_attention_mask(
self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype
) -> torch.Tensor | None:
"""Prepare self-attention mask by converting [0,1] values to additive log-space bias.
Input shape: (B, T, T) with values in [0, 1].
Output shape: (B, 1, T, T) with 0.0 for full attention and a large negative value
for masked positions.
Positions with attention_mask <= 0 are fully masked (mapped to the dtype's minimum
representable value). Strictly positive entries are converted via log-space for
smooth attenuation, with small values clamped for numerical stability.
Returns None if input is None (no masking).
"""
if attention_mask is None:
return None
# Convert [0, 1] attention mask to additive log-space bias:
# 1.0 -> log(1.0) = 0.0 (no bias, full attention)
# 0.0 -> finfo.min (fully masked)
finfo = torch.finfo(x_dtype)
eps = finfo.tiny
bias = torch.full_like(attention_mask, finfo.min, dtype=x_dtype)
positive = attention_mask > 0
if positive.any():
bias[positive] = torch.log(attention_mask[positive].clamp(min=eps)).to(x_dtype)
return bias.unsqueeze(1) # (B, 1, T, T) for head broadcast
def _prepare_positional_embeddings(
self,
positions: torch.Tensor,
@@ -653,11 +714,20 @@ class TransformerArgsPreprocessor:
def prepare(
self,
modality: Modality,
cross_modality: Modality | None = None, # noqa: ARG002
) -> TransformerArgs:
x = self.patchify_proj(modality.latent)
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], modality.latent.dtype)
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self._prepare_timestep(
modality.timesteps, self.adaln, batch_size, modality.latent.dtype
)
prompt_timestep = None
if self.prompt_adaln is not None:
prompt_timestep, _ = self._prepare_timestep(
modality.sigma, self.prompt_adaln, batch_size, modality.latent.dtype
)
context = self._prepare_context(modality.context, x)
attention_mask = self._prepare_attention_mask(modality.context_mask, modality.latent.dtype)
pe = self._prepare_positional_embeddings(
positions=modality.positions,
inner_dim=self.inner_dim,
@@ -666,6 +736,7 @@ class TransformerArgsPreprocessor:
num_attention_heads=self.num_attention_heads,
x_dtype=modality.latent.dtype,
)
self_attention_mask = self._prepare_self_attention_mask(modality.attention_mask, modality.latent.dtype)
return TransformerArgs(
x=x,
context=context,
@@ -677,6 +748,8 @@ class TransformerArgsPreprocessor:
cross_scale_shift_timestep=None,
cross_gate_timestep=None,
enabled=modality.enabled,
prompt_timestep=prompt_timestep,
self_attention_mask=self_attention_mask,
)
@@ -685,7 +758,6 @@ class MultiModalTransformerArgsPreprocessor:
self,
patchify_proj: torch.nn.Linear,
adaln: AdaLayerNormSingle,
caption_projection: PixArtAlphaTextProjection,
cross_scale_shift_adaln: AdaLayerNormSingle,
cross_gate_adaln: AdaLayerNormSingle,
inner_dim: int,
@@ -699,11 +771,12 @@ class MultiModalTransformerArgsPreprocessor:
positional_embedding_theta: float,
rope_type: LTXRopeType,
av_ca_timestep_scale_multiplier: int,
caption_projection: torch.nn.Module | None = None,
prompt_adaln: AdaLayerNormSingle | None = None,
) -> None:
self.simple_preprocessor = TransformerArgsPreprocessor(
patchify_proj=patchify_proj,
adaln=adaln,
caption_projection=caption_projection,
inner_dim=inner_dim,
max_pos=max_pos,
num_attention_heads=num_attention_heads,
@@ -712,6 +785,8 @@ class MultiModalTransformerArgsPreprocessor:
double_precision_rope=double_precision_rope,
positional_embedding_theta=positional_embedding_theta,
rope_type=rope_type,
caption_projection=caption_projection,
prompt_adaln=prompt_adaln,
)
self.cross_scale_shift_adaln = cross_scale_shift_adaln
self.cross_gate_adaln = cross_gate_adaln
@@ -722,8 +797,22 @@ class MultiModalTransformerArgsPreprocessor:
def prepare(
self,
modality: Modality,
cross_modality: Modality | None = None,
) -> TransformerArgs:
transformer_args = self.simple_preprocessor.prepare(modality)
if cross_modality is None:
return transformer_args
if cross_modality.sigma.numel() > 1:
if cross_modality.sigma.shape[0] != modality.timesteps.shape[0]:
raise ValueError("Cross modality sigma must have the same batch size as the modality")
if cross_modality.sigma.ndim != 1:
raise ValueError("Cross modality sigma must be a 1D tensor")
cross_timestep = cross_modality.sigma.view(
modality.timesteps.shape[0], 1, *[1] * len(modality.timesteps.shape[2:])
)
cross_pe = self.simple_preprocessor._prepare_positional_embeddings(
positions=modality.positions[:, 0:1, :],
inner_dim=self.audio_cross_attention_dim,
@@ -734,7 +823,7 @@ class MultiModalTransformerArgsPreprocessor:
)
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
timestep=modality.timesteps,
timestep=cross_timestep,
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
batch_size=transformer_args.x.shape[0],
hidden_dtype=modality.latent.dtype,
@@ -749,7 +838,7 @@ class MultiModalTransformerArgsPreprocessor:
def _prepare_cross_attention_timestep(
self,
timestep: torch.Tensor,
timestep: torch.Tensor | None,
timestep_scale_multiplier: int,
batch_size: int,
hidden_dtype: torch.dtype,
@@ -779,6 +868,8 @@ class TransformerConfig:
heads: int
d_head: int
context_dim: int
apply_gated_attention: bool = False
cross_attention_adaln: bool = False
class BasicAVTransformerBlock(torch.nn.Module):
@@ -801,6 +892,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
context_dim=None,
rope_type=rope_type,
norm_eps=norm_eps,
apply_gated_attention=video.apply_gated_attention,
)
self.attn2 = Attention(
query_dim=video.dim,
@@ -809,9 +901,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
dim_head=video.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
apply_gated_attention=video.apply_gated_attention,
)
self.ff = FeedForward(video.dim, dim_out=video.dim)
self.scale_shift_table = torch.nn.Parameter(torch.empty(6, video.dim))
video_sst_size = adaln_embedding_coefficient(video.cross_attention_adaln)
self.scale_shift_table = torch.nn.Parameter(torch.empty(video_sst_size, video.dim))
if audio is not None:
self.audio_attn1 = Attention(
@@ -821,6 +915,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
context_dim=None,
rope_type=rope_type,
norm_eps=norm_eps,
apply_gated_attention=audio.apply_gated_attention,
)
self.audio_attn2 = Attention(
query_dim=audio.dim,
@@ -829,9 +924,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
apply_gated_attention=audio.apply_gated_attention,
)
self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(6, audio.dim))
audio_sst_size = adaln_embedding_coefficient(audio.cross_attention_adaln)
self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(audio_sst_size, audio.dim))
if audio is not None and video is not None:
# Q: Video, K,V: Audio
@@ -842,6 +939,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
apply_gated_attention=video.apply_gated_attention,
)
# Q: Audio, K,V: Video
@@ -852,11 +950,21 @@ class BasicAVTransformerBlock(torch.nn.Module):
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
apply_gated_attention=audio.apply_gated_attention,
)
self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim))
self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim))
self.cross_attention_adaln = (video is not None and video.cross_attention_adaln) or (
audio is not None and audio.cross_attention_adaln
)
if self.cross_attention_adaln and video is not None:
self.prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, video.dim))
if self.cross_attention_adaln and audio is not None:
self.audio_prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, audio.dim))
self.norm_eps = norm_eps
def get_ada_values(
@@ -876,19 +984,49 @@ class BasicAVTransformerBlock(torch.nn.Module):
batch_size: int,
scale_shift_timestep: torch.Tensor,
gate_timestep: torch.Tensor,
scale_shift_indices: slice,
num_scale_shift_values: int = 4,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
scale_shift_ada_values = self.get_ada_values(
scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, slice(None, None)
scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, scale_shift_indices
)
gate_ada_values = self.get_ada_values(
scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None)
)
scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
scale, shift = (t.squeeze(2) for t in scale_shift_ada_values)
(gate,) = (t.squeeze(2) for t in gate_ada_values)
return (*scale_shift_chunks, *gate_ada_values)
return scale, shift, gate
def _apply_text_cross_attention(
self,
x: torch.Tensor,
context: torch.Tensor,
attn: Attention,
scale_shift_table: torch.Tensor,
prompt_scale_shift_table: torch.Tensor | None,
timestep: torch.Tensor,
prompt_timestep: torch.Tensor | None,
context_mask: torch.Tensor | None,
cross_attention_adaln: bool = False,
) -> torch.Tensor:
"""Apply text cross-attention, with optional AdaLN modulation."""
if cross_attention_adaln:
shift_q, scale_q, gate = self.get_ada_values(scale_shift_table, x.shape[0], timestep, slice(6, 9))
return apply_cross_attention_adaln(
x,
context,
attn,
shift_q,
scale_q,
gate,
prompt_scale_shift_table,
prompt_timestep,
context_mask,
self.norm_eps,
)
return attn(rms_norm(x, eps=self.norm_eps), context=context, mask=context_mask)
def forward( # noqa: PLR0915
self,
@@ -896,7 +1034,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
audio: TransformerArgs | None,
perturbations: BatchedPerturbationConfig | None = None,
) -> tuple[TransformerArgs | None, TransformerArgs | None]:
batch_size = video.x.shape[0]
if video is None and audio is None:
raise ValueError("At least one of video or audio must be provided")
batch_size = (video or audio).x.shape[0]
if perturbations is None:
perturbations = BatchedPerturbationConfig.empty(batch_size)
@@ -913,63 +1055,103 @@ class BasicAVTransformerBlock(torch.nn.Module):
vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
)
if not perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx):
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
v_mask = perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx)
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa * v_mask
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
del vshift_msa, vscale_msa
vx = vx + self.attn2(rms_norm(vx, eps=self.norm_eps), context=video.context, mask=video.context_mask)
del vshift_msa, vscale_msa, vgate_msa
all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)
none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)
v_mask = (
perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx)
if not all_perturbed and not none_perturbed
else None
)
vx = (
vx
+ self.attn1(
norm_vx,
pe=video.positional_embeddings,
mask=video.self_attention_mask,
perturbation_mask=v_mask,
all_perturbed=all_perturbed,
)
* vgate_msa
)
del vgate_msa, norm_vx, v_mask
vx = vx + self._apply_text_cross_attention(
vx,
video.context,
self.attn2,
self.scale_shift_table,
getattr(self, "prompt_scale_shift_table", None),
video.timesteps,
video.prompt_timestep,
video.context_mask,
cross_attention_adaln=self.cross_attention_adaln,
)
if run_ax:
ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
)
if not perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx):
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
a_mask = perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax)
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa * a_mask
ax = ax + self.audio_attn2(rms_norm(ax, eps=self.norm_eps), context=audio.context, mask=audio.context_mask)
del ashift_msa, ascale_msa, agate_msa
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
del ashift_msa, ascale_msa
all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)
none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)
a_mask = (
perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax)
if not all_perturbed and not none_perturbed
else None
)
ax = (
ax
+ self.audio_attn1(
norm_ax,
pe=audio.positional_embeddings,
mask=audio.self_attention_mask,
perturbation_mask=a_mask,
all_perturbed=all_perturbed,
)
* agate_msa
)
del agate_msa, norm_ax, a_mask
ax = ax + self._apply_text_cross_attention(
ax,
audio.context,
self.audio_attn2,
self.audio_scale_shift_table,
getattr(self, "audio_prompt_scale_shift_table", None),
audio.timesteps,
audio.prompt_timestep,
audio.context_mask,
cross_attention_adaln=self.cross_attention_adaln,
)
# Audio - Video cross attention.
if run_a2v or run_v2a:
vx_norm3 = rms_norm(vx, eps=self.norm_eps)
ax_norm3 = rms_norm(ax, eps=self.norm_eps)
(
scale_ca_audio_hidden_states_a2v,
shift_ca_audio_hidden_states_a2v,
scale_ca_audio_hidden_states_v2a,
shift_ca_audio_hidden_states_v2a,
gate_out_v2a,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_audio,
ax.shape[0],
audio.cross_scale_shift_timestep,
audio.cross_gate_timestep,
)
if run_a2v and not perturbations.all_in_batch(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx):
scale_ca_video_a2v, shift_ca_video_a2v, gate_out_a2v = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_video,
vx.shape[0],
video.cross_scale_shift_timestep,
video.cross_gate_timestep,
slice(0, 2),
)
vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v
del scale_ca_video_a2v, shift_ca_video_a2v
(
scale_ca_video_hidden_states_a2v,
shift_ca_video_hidden_states_a2v,
scale_ca_video_hidden_states_v2a,
shift_ca_video_hidden_states_v2a,
gate_out_a2v,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_video,
vx.shape[0],
video.cross_scale_shift_timestep,
video.cross_gate_timestep,
)
if run_a2v:
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) + shift_ca_video_hidden_states_a2v
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
scale_ca_audio_a2v, shift_ca_audio_a2v, _ = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_audio,
ax.shape[0],
audio.cross_scale_shift_timestep,
audio.cross_gate_timestep,
slice(0, 2),
)
ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v
del scale_ca_audio_a2v, shift_ca_audio_a2v
a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx)
vx = vx + (
self.audio_to_video_attn(
@@ -981,10 +1163,27 @@ class BasicAVTransformerBlock(torch.nn.Module):
* gate_out_a2v
* a2v_mask
)
del gate_out_a2v, a2v_mask, vx_scaled, ax_scaled
if run_v2a:
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
if run_v2a and not perturbations.all_in_batch(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx):
scale_ca_audio_v2a, shift_ca_audio_v2a, gate_out_v2a = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_audio,
ax.shape[0],
audio.cross_scale_shift_timestep,
audio.cross_gate_timestep,
slice(2, 4),
)
ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a
del scale_ca_audio_v2a, shift_ca_audio_v2a
scale_ca_video_v2a, shift_ca_video_v2a, _ = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_video,
vx.shape[0],
video.cross_scale_shift_timestep,
video.cross_gate_timestep,
slice(2, 4),
)
vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a
del scale_ca_video_v2a, shift_ca_video_v2a
v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax)
ax = ax + (
self.video_to_audio_attn(
@@ -996,40 +1195,53 @@ class BasicAVTransformerBlock(torch.nn.Module):
* gate_out_v2a
* v2a_mask
)
del gate_out_v2a, v2a_mask, ax_scaled, vx_scaled
del gate_out_a2v, gate_out_v2a
del (
scale_ca_video_hidden_states_a2v,
shift_ca_video_hidden_states_a2v,
scale_ca_audio_hidden_states_a2v,
shift_ca_audio_hidden_states_a2v,
scale_ca_video_hidden_states_v2a,
shift_ca_video_hidden_states_v2a,
scale_ca_audio_hidden_states_v2a,
shift_ca_audio_hidden_states_v2a,
)
del vx_norm3, ax_norm3
if run_vx:
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None)
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6)
)
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
vx = vx + self.ff(vx_scaled) * vgate_mlp
del vshift_mlp, vscale_mlp, vgate_mlp
del vshift_mlp, vscale_mlp, vgate_mlp, vx_scaled
if run_ax:
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None)
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6)
)
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
ax = ax + self.audio_ff(ax_scaled) * agate_mlp
del ashift_mlp, ascale_mlp, agate_mlp
del ashift_mlp, ascale_mlp, agate_mlp, ax_scaled
return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None
def apply_cross_attention_adaln(
x: torch.Tensor,
context: torch.Tensor,
attn: Attention,
q_shift: torch.Tensor,
q_scale: torch.Tensor,
q_gate: torch.Tensor,
prompt_scale_shift_table: torch.Tensor,
prompt_timestep: torch.Tensor,
context_mask: torch.Tensor | None = None,
norm_eps: float = 1e-6,
) -> torch.Tensor:
batch_size = x.shape[0]
shift_kv, scale_kv = (
prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
+ prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
).unbind(dim=2)
attn_input = rms_norm(x, eps=norm_eps) * (1 + q_scale) + q_shift
encoder_hidden_states = context * (1 + scale_kv) + shift_kv
return attn(attn_input, context=encoder_hidden_states, mask=context_mask) * q_gate
class GELUApprox(torch.nn.Module):
def __init__(self, dim_in: int, dim_out: int) -> None:
super().__init__()
@@ -1094,6 +1306,8 @@ class LTXModel(torch.nn.Module):
av_ca_timestep_scale_multiplier: int = 1000,
rope_type: LTXRopeType = LTXRopeType.SPLIT,
double_precision_rope: bool = True,
apply_gated_attention: bool = False,
cross_attention_adaln: bool = False,
):
super().__init__()
self._enable_gradient_checkpointing = False
@@ -1103,6 +1317,7 @@ class LTXModel(torch.nn.Module):
self.timestep_scale_multiplier = timestep_scale_multiplier
self.positional_embedding_theta = positional_embedding_theta
self.model_type = model_type
self.cross_attention_adaln = cross_attention_adaln
cross_pe_max_pos = None
if model_type.is_video_enabled():
if positional_embedding_max_pos is None:
@@ -1145,8 +1360,13 @@ class LTXModel(torch.nn.Module):
audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0,
audio_cross_attention_dim=audio_cross_attention_dim,
norm_eps=norm_eps,
apply_gated_attention=apply_gated_attention,
)
@property
def _adaln_embedding_coefficient(self) -> int:
return adaln_embedding_coefficient(self.cross_attention_adaln)
def _init_video(
self,
in_channels: int,
@@ -1157,14 +1377,15 @@ class LTXModel(torch.nn.Module):
"""Initialize video-specific components."""
# Video input components
self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True)
self.adaln_single = AdaLayerNormSingle(self.inner_dim)
self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=self._adaln_embedding_coefficient)
self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None
# Video caption projection
self.caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels,
hidden_size=self.inner_dim,
)
if caption_channels is not None:
self.caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels,
hidden_size=self.inner_dim,
)
# Video output components
self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim))
@@ -1183,15 +1404,15 @@ class LTXModel(torch.nn.Module):
# Audio input components
self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True)
self.audio_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
)
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=self._adaln_embedding_coefficient)
self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None
# Audio caption projection
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels,
hidden_size=self.audio_inner_dim,
)
if caption_channels is not None:
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels,
hidden_size=self.audio_inner_dim,
)
# Audio output components
self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim))
@@ -1233,7 +1454,6 @@ class LTXModel(torch.nn.Module):
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.patchify_proj,
adaln=self.adaln_single,
caption_projection=self.caption_projection,
cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
inner_dim=self.inner_dim,
@@ -1247,11 +1467,12 @@ class LTXModel(torch.nn.Module):
positional_embedding_theta=self.positional_embedding_theta,
rope_type=self.rope_type,
av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
caption_projection=getattr(self, "caption_projection", None),
prompt_adaln=getattr(self, "prompt_adaln_single", None),
)
self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.audio_patchify_proj,
adaln=self.audio_adaln_single,
caption_projection=self.audio_caption_projection,
cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
inner_dim=self.audio_inner_dim,
@@ -1265,12 +1486,13 @@ class LTXModel(torch.nn.Module):
positional_embedding_theta=self.positional_embedding_theta,
rope_type=self.rope_type,
av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
caption_projection=getattr(self, "audio_caption_projection", None),
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
)
elif self.model_type.is_video_enabled():
self.video_args_preprocessor = TransformerArgsPreprocessor(
patchify_proj=self.patchify_proj,
adaln=self.adaln_single,
caption_projection=self.caption_projection,
inner_dim=self.inner_dim,
max_pos=self.positional_embedding_max_pos,
num_attention_heads=self.num_attention_heads,
@@ -1279,12 +1501,13 @@ class LTXModel(torch.nn.Module):
double_precision_rope=self.double_precision_rope,
positional_embedding_theta=self.positional_embedding_theta,
rope_type=self.rope_type,
caption_projection=getattr(self, "caption_projection", None),
prompt_adaln=getattr(self, "prompt_adaln_single", None),
)
elif self.model_type.is_audio_enabled():
self.audio_args_preprocessor = TransformerArgsPreprocessor(
patchify_proj=self.audio_patchify_proj,
adaln=self.audio_adaln_single,
caption_projection=self.audio_caption_projection,
inner_dim=self.audio_inner_dim,
max_pos=self.audio_positional_embedding_max_pos,
num_attention_heads=self.audio_num_attention_heads,
@@ -1293,6 +1516,8 @@ class LTXModel(torch.nn.Module):
double_precision_rope=self.double_precision_rope,
positional_embedding_theta=self.positional_embedding_theta,
rope_type=self.rope_type,
caption_projection=getattr(self, "audio_caption_projection", None),
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
)
def _init_transformer_blocks(
@@ -1303,6 +1528,7 @@ class LTXModel(torch.nn.Module):
audio_attention_head_dim: int,
audio_cross_attention_dim: int,
norm_eps: float,
apply_gated_attention: bool,
) -> None:
"""Initialize transformer blocks for LTX."""
video_config = (
@@ -1311,6 +1537,8 @@ class LTXModel(torch.nn.Module):
heads=self.num_attention_heads,
d_head=attention_head_dim,
context_dim=cross_attention_dim,
apply_gated_attention=apply_gated_attention,
cross_attention_adaln=self.cross_attention_adaln,
)
if self.model_type.is_video_enabled()
else None
@@ -1321,6 +1549,8 @@ class LTXModel(torch.nn.Module):
heads=self.audio_num_attention_heads,
d_head=audio_attention_head_dim,
context_dim=audio_cross_attention_dim,
apply_gated_attention=apply_gated_attention,
cross_attention_adaln=self.cross_attention_adaln,
)
if self.model_type.is_audio_enabled()
else None
@@ -1409,8 +1639,8 @@ class LTXModel(torch.nn.Module):
if not self.model_type.is_audio_enabled() and audio is not None:
raise ValueError("Audio is not enabled for this model")
video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
video_args = self.video_args_preprocessor.prepare(video, audio) if video is not None else None
audio_args = self.audio_args_preprocessor.prepare(audio, video) if audio is not None else None
# Process transformer blocks
video_out, audio_out = self._process_transformer_blocks(
video=video_args,
@@ -1441,12 +1671,12 @@ class LTXModel(torch.nn.Module):
)
return vx, ax
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps, sigma, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
cross_pe_max_pos = None
if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():
cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
self._init_preprocessors(cross_pe_max_pos)
video = Modality(video_latents, video_timesteps, video_positions, video_context)
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None
video = Modality(video_latents, sigma, video_timesteps, video_positions, video_context)
audio = Modality(audio_latents, sigma, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None
vx, ax = self._forward(video=video, audio=audio, perturbations=None, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload)
return vx, ax

View File

@@ -1,4 +1,7 @@
import math
import torch
import torch.nn as nn
from einops import rearrange
from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer
from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention,
FeedForward)
@@ -147,14 +150,14 @@ class LTXVGemmaTokenizer:
return out
class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
class GemmaFeaturesExtractorProjLinear(nn.Module):
"""
Feature extractor module for Gemma models.
This module applies a single linear projection to the input tensor.
It expects a flattened feature tensor of shape (batch_size, 3840*49).
The linear layer maps this to a (batch_size, 3840) embedding.
Attributes:
aggregate_embed (torch.nn.Linear): Linear projection layer.
aggregate_embed (nn.Linear): Linear projection layer.
"""
def __init__(self) -> None:
@@ -163,26 +166,65 @@ class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
The input dimension is expected to be 3840 * 49, and the output is 3840.
"""
super().__init__()
self.aggregate_embed = torch.nn.Linear(3840 * 49, 3840, bias=False)
self.aggregate_embed = nn.Linear(3840 * 49, 3840, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the feature extractor.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, 3840 * 49).
Returns:
torch.Tensor: Output tensor of shape (batch_size, 3840).
"""
return self.aggregate_embed(x)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
padding_side: str = "left",
) -> tuple[torch.Tensor, torch.Tensor | None]:
encoded = torch.stack(hidden_states, dim=-1) if isinstance(hidden_states, (list, tuple)) else hidden_states
dtype = encoded.dtype
sequence_lengths = attention_mask.sum(dim=-1)
normed = _norm_and_concat_padded_batch(encoded, sequence_lengths, padding_side)
features = self.aggregate_embed(normed.to(dtype))
return features, features
class _BasicTransformerBlock1D(torch.nn.Module):
class GemmaSeperatedFeaturesExtractorProjLinear(nn.Module):
"""22B: per-token RMS norm → rescale → dual aggregate embeds"""
def __init__(
self,
num_layers: int,
embedding_dim: int,
video_inner_dim: int,
audio_inner_dim: int,
):
super().__init__()
in_dim = embedding_dim * num_layers
self.video_aggregate_embed = torch.nn.Linear(in_dim, video_inner_dim, bias=True)
self.audio_aggregate_embed = torch.nn.Linear(in_dim, audio_inner_dim, bias=True)
self.embedding_dim = embedding_dim
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
padding_side: str = "left", # noqa: ARG002
) -> tuple[torch.Tensor, torch.Tensor | None]:
encoded = torch.stack(hidden_states, dim=-1) if isinstance(hidden_states, (list, tuple)) else hidden_states
normed = norm_and_concat_per_token_rms(encoded, attention_mask)
normed = normed.to(encoded.dtype)
v_dim = self.video_aggregate_embed.out_features
video = self.video_aggregate_embed(_rescale_norm(normed, v_dim, self.embedding_dim))
audio = None
if self.audio_aggregate_embed is not None:
a_dim = self.audio_aggregate_embed.out_features
audio = self.audio_aggregate_embed(_rescale_norm(normed, a_dim, self.embedding_dim))
return video, audio
class _BasicTransformerBlock1D(nn.Module):
def __init__(
self,
dim: int,
heads: int,
dim_head: int,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
apply_gated_attention: bool = False,
):
super().__init__()
@@ -191,6 +233,7 @@ class _BasicTransformerBlock1D(torch.nn.Module):
heads=heads,
dim_head=dim_head,
rope_type=rope_type,
apply_gated_attention=apply_gated_attention,
)
self.ff = FeedForward(
@@ -231,7 +274,7 @@ class _BasicTransformerBlock1D(torch.nn.Module):
return hidden_states
class Embeddings1DConnector(torch.nn.Module):
class Embeddings1DConnector(nn.Module):
"""
Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can
@@ -263,6 +306,7 @@ class Embeddings1DConnector(torch.nn.Module):
num_learnable_registers: int | None = 128,
rope_type: LTXRopeType = LTXRopeType.SPLIT,
double_precision_rope: bool = True,
apply_gated_attention: bool = False,
):
super().__init__()
self.num_attention_heads = num_attention_heads
@@ -274,13 +318,14 @@ class Embeddings1DConnector(torch.nn.Module):
)
self.rope_type = rope_type
self.double_precision_rope = double_precision_rope
self.transformer_1d_blocks = torch.nn.ModuleList(
self.transformer_1d_blocks = nn.ModuleList(
[
_BasicTransformerBlock1D(
dim=self.inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rope_type=rope_type,
apply_gated_attention=apply_gated_attention,
)
for _ in range(num_layers)
]
@@ -288,7 +333,7 @@ class Embeddings1DConnector(torch.nn.Module):
self.num_learnable_registers = num_learnable_registers
if self.num_learnable_registers:
self.learnable_registers = torch.nn.Parameter(
self.learnable_registers = nn.Parameter(
torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
)
@@ -307,7 +352,7 @@ class Embeddings1DConnector(torch.nn.Module):
non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
non_zero_nums = non_zero_hidden_states.shape[1]
pad_length = hidden_states.shape[1] - non_zero_nums
adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
adjusted_hidden_states = nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
flipped_mask = torch.flip(attention_mask_binary, dims=[1])
hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers
@@ -358,9 +403,147 @@ class Embeddings1DConnector(torch.nn.Module):
return hidden_states, attention_mask
class LTX2TextEncoderPostModules(torch.nn.Module):
def __init__(self,):
class LTX2TextEncoderPostModules(nn.Module):
def __init__(
self,
seperated_audio_video: bool = False,
embedding_dim_gemma: int = 3840,
num_layers_gemma: int = 49,
video_attetion_heads: int = 32,
video_attention_head_dim: int = 128,
audio_attention_heads: int = 32,
audio_attention_head_dim: int = 64,
num_connetor_layers: int = 2,
apply_gated_attention: bool = False,
):
super().__init__()
self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
self.embeddings_connector = Embeddings1DConnector()
self.audio_embeddings_connector = Embeddings1DConnector()
if not seperated_audio_video:
self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
self.embeddings_connector = Embeddings1DConnector()
self.audio_embeddings_connector = Embeddings1DConnector()
else:
# LTX-2.3
self.feature_extractor_linear = GemmaSeperatedFeaturesExtractorProjLinear(
num_layers_gemma, embedding_dim_gemma, video_attetion_heads * video_attention_head_dim,
audio_attention_heads * audio_attention_head_dim)
self.embeddings_connector = Embeddings1DConnector(
attention_head_dim=video_attention_head_dim,
num_attention_heads=video_attetion_heads,
num_layers=num_connetor_layers,
apply_gated_attention=apply_gated_attention,
)
self.audio_embeddings_connector = Embeddings1DConnector(
attention_head_dim=audio_attention_head_dim,
num_attention_heads=audio_attention_heads,
num_layers=num_connetor_layers,
apply_gated_attention=apply_gated_attention,
)
def create_embeddings(
    self,
    video_features: torch.Tensor,
    audio_features: torch.Tensor | None,
    additive_attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
    """Run the video and audio connectors over the projected text features.

    Args:
        video_features: Projected text features for the video branch.
        audio_features: Projected text features for the audio branch
            (passed to the audio connector as-is).
        additive_attention_mask: Additive attention mask (0 for valid
            tokens, large negative values for padding).

    Returns:
        Tuple of (video embeddings with padded tokens zeroed, audio
        embeddings, binary {0, 1} mask of shape [B, T, 1]).
    """
    # The two connector passes are independent of each other; the audio
    # branch's returned mask is not needed downstream.
    audio_embeds, _ = self.audio_embeddings_connector(audio_features, additive_attention_mask)
    video_embeds, video_additive_mask = self.embeddings_connector(video_features, additive_attention_mask)
    # Convert the connector's additive mask back into a binary mask and
    # zero out the padded positions of the video embeddings.
    video_embeds, binary_mask = _to_binary_mask(video_embeds, video_additive_mask)
    return video_embeds, audio_embeds, binary_mask
def process_hidden_states(
    self,
    hidden_states: tuple[torch.Tensor, ...],
    attention_mask: torch.Tensor,
    padding_side: str = "left",
):
    """Full text post-processing path: project hidden states, then run connectors.

    Args:
        hidden_states: Per-layer hidden states from the text encoder.
        attention_mask: [B, T] binary attention mask (1 = valid token).
        padding_side: Which side the tokenizer padded on ("left" or "right").

    Returns:
        Tuple of (video encoding, audio encoding, binary mask), as produced
        by ``create_embeddings``.
    """
    video_features, audio_features = self.feature_extractor_linear(hidden_states, attention_mask, padding_side)
    # Connectors expect an additive mask in the features' dtype.
    additive = _convert_to_additive_mask(attention_mask, video_features.dtype)
    return self.create_embeddings(video_features, audio_features, additive)
def _norm_and_concat_padded_batch(
encoded_text: torch.Tensor,
sequence_lengths: torch.Tensor,
padding_side: str = "right",
) -> torch.Tensor:
"""Normalize and flatten multi-layer hidden states, respecting padding.
Performs per-batch, per-layer normalization using masked mean and range,
then concatenates across the layer dimension.
Args:
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
sequence_lengths: Number of valid (non-padded) tokens per batch item.
padding_side: Whether padding is on "left" or "right".
Returns:
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
with padded positions zeroed out.
"""
b, t, d, l = encoded_text.shape # noqa: E741
device = encoded_text.device
# Build mask: [B, T, 1, 1]
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
if padding_side == "right":
# For right padding, valid tokens are from 0 to sequence_length-1
mask = token_indices < sequence_lengths[:, None] # [B, T]
elif padding_side == "left":
# For left padding, valid tokens are from (T - sequence_length) to T-1
start_indices = t - sequence_lengths[:, None] # [B, 1]
mask = token_indices >= start_indices # [B, T]
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
mask = rearrange(mask, "b t -> b t 1 1")
eps = 1e-6
# Compute masked mean: [B, 1, 1, L]
masked = encoded_text.masked_fill(~mask, 0.0)
denom = (sequence_lengths * d).view(b, 1, 1, 1)
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
# Compute masked min/max: [B, 1, 1, L]
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
range_ = x_max - x_min
# Normalize only the valid tokens
normed = 8 * (encoded_text - mean) / (range_ + eps)
# concat to be [Batch, T, D * L] - this preserves the original structure
normed = normed.reshape(b, t, -1) # [B, T, D * L]
# Apply mask to preserve original padding (set padded positions to 0)
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
normed = normed.masked_fill(~mask_flattened, 0.0)
return normed
def _convert_to_additive_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return (attention_mask - 1).to(dtype).reshape(
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
def _to_binary_mask(encoded: torch.Tensor, encoded_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert connector output mask to binary mask and apply to encoded tensor."""
binary_mask = (encoded_mask < 0.000001).to(torch.int64)
binary_mask = binary_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
encoded = encoded * binary_mask
return encoded, binary_mask
def norm_and_concat_per_token_rms(
    encoded_text: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Per-token RMSNorm normalization for V2 models.

    Args:
        encoded_text: [B, T, D, L] stacked per-layer hidden states.
        attention_mask: [B, T] binary mask (1 = valid token).

    Returns:
        [B, T, D*L] normalized tensor with padded tokens zeroed out.
    """
    batch, tokens, dim, layers = encoded_text.shape
    # RMS over the hidden dimension, computed independently per token and per layer.
    inv_rms = torch.rsqrt(torch.mean(encoded_text ** 2, dim=2, keepdim=True) + 1e-6)  # [B,T,1,L]
    flattened = (encoded_text * inv_rms).reshape(batch, tokens, dim * layers)
    keep = attention_mask.bool().unsqueeze(-1)  # [B, T, 1]
    # torch.where (rather than multiplication) guarantees exact zeros at padding.
    return torch.where(keep, flattened, torch.zeros_like(flattened))
def _rescale_norm(x: torch.Tensor, target_dim: int, source_dim: int) -> torch.Tensor:
"""Rescale normalization: x * sqrt(target_dim / source_dim)."""
return x * math.sqrt(target_dim / source_dim)

View File

@@ -555,9 +555,6 @@ class PerChannelStatistics(nn.Module):
super().__init__()
self.register_buffer("std-of-means", torch.empty(latent_channels))
self.register_buffer("mean-of-means", torch.empty(latent_channels))
self.register_buffer("mean-of-stds", torch.empty(latent_channels))
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(latent_channels))
self.register_buffer("channel", torch.empty(latent_channels))
def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(
@@ -1335,27 +1332,49 @@ class LTX2VideoEncoder(nn.Module):
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
encoder_version: str = "ltx-2",
):
super().__init__()
encoder_blocks = [['res_x', {
'num_layers': 4
}], ['compress_space_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 6
}], ['compress_time_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 6
}], ['compress_all_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 2
}], ['compress_all_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 2
}]]
if encoder_version == "ltx-2":
encoder_blocks = [['res_x', {
'num_layers': 4
}], ['compress_space_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 6
}], ['compress_time_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 6
}], ['compress_all_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 2
}], ['compress_all_res', {
'multiplier': 2
}], ['res_x', {
'num_layers': 2
}]]
else:
encoder_blocks = [["res_x", {
"num_layers": 4
}], ["compress_space_res", {
"multiplier": 2
}], ["res_x", {
"num_layers": 6
}], ["compress_time_res", {
"multiplier": 2
}], ["res_x", {
"num_layers": 4
}], ["compress_all_res", {
"multiplier": 2
}], ["res_x", {
"num_layers": 2
}], ["compress_all_res", {
"multiplier": 1
}], ["res_x", {
"num_layers": 2
}]]
self.patch_size = patch_size
self.norm_layer = norm_layer
self.latent_channels = out_channels
@@ -1435,8 +1454,8 @@ class LTX2VideoEncoder(nn.Module):
# Validate frame count
frames_count = sample.shape[2]
if ((frames_count - 1) % 8) != 0:
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames "
"(e.g., 1, 9, 17, ...). Please check your input.")
frames_to_crop = (frames_count - 1) % 8
sample = sample[:, :, :-frames_to_crop, ...]
# Initial spatial compression: trade spatial resolution for channel depth
# This reduces H,W by patch_size and increases channels, making convolutions more efficient
@@ -1712,17 +1731,21 @@ def _make_decoder_block(
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
out_channels = in_channels // block_config.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=convolution_dimensions,
in_channels=in_channels,
stride=(2, 1, 1),
out_channels_reduction_factor=block_config.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
out_channels = in_channels // block_config.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=convolution_dimensions,
in_channels=in_channels,
stride=(1, 2, 2),
out_channels_reduction_factor=block_config.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
@@ -1782,6 +1805,8 @@ class LTX2VideoDecoder(nn.Module):
causal: bool = False,
timestep_conditioning: bool = False,
decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
decoder_version: str = "ltx-2",
base_channels: int = 128,
):
super().__init__()
@@ -1790,28 +1815,49 @@ class LTX2VideoDecoder(nn.Module):
# video inputs by a factor of 8 in the temporal dimension and 32 in
# each spatial dimension (height and width). This parameter determines how
# many video frames and pixels correspond to a single latent cell.
decoder_blocks = [['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}]]
if decoder_version == "ltx-2":
decoder_blocks = [['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}]]
else:
decoder_blocks = [["res_x", {
"num_layers": 4
}], ["compress_space", {
"multiplier": 2
}], ["res_x", {
"num_layers": 6
}], ["compress_time", {
"multiplier": 2
}], ["res_x", {
"num_layers": 4
}], ["compress_all", {
"multiplier": 1
}], ["res_x", {
"num_layers": 2
}], ["compress_all", {
"multiplier": 2
}], ["res_x", {
"num_layers": 2
}]]
self.video_downscale_factors = SpatioTemporalScaleFactors(
time=8,
width=32,
@@ -1833,13 +1879,14 @@ class LTX2VideoDecoder(nn.Module):
# Compute initial feature_channels by going through blocks in reverse
# This determines the channel width at the start of the decoder
feature_channels = in_channels
for block_name, block_params in list(reversed(decoder_blocks)):
block_config = block_params if isinstance(block_params, dict) else {}
if block_name == "res_x_y":
feature_channels = feature_channels * block_config.get("multiplier", 2)
if block_name == "compress_all":
feature_channels = feature_channels * block_config.get("multiplier", 1)
# feature_channels = in_channels
# for block_name, block_params in list(reversed(decoder_blocks)):
# block_config = block_params if isinstance(block_params, dict) else {}
# if block_name == "res_x_y":
# feature_channels = feature_channels * block_config.get("multiplier", 2)
# if block_name == "compress_all":
# feature_channels = feature_channels * block_config.get("multiplier", 1)
feature_channels = base_channels * 8
self.conv_in = make_conv_nd(
dims=convolution_dimensions,

View File

@@ -92,13 +92,13 @@ class LTX2AudioVideoPipeline(BasePipeline):
pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
# Stage 2
if stage2_lora_config is not None:
stage2_lora_config.download_if_necessary()
pipe.stage2_lora_path = stage2_lora_config.path
# Optional, currently not used
pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
@@ -168,8 +168,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
# Shape
height: Optional[int] = 512,
width: Optional[int] = 768,
num_frames=121,
frame_rate=24,
num_frames: Optional[int] = 121,
frame_rate: Optional[int] = 24,
# Classifier-free guidance
cfg_scale: Optional[float] = 3.0,
# Scheduler
@@ -238,7 +238,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
return video, decoded_audio
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None, num_frames=121):
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None):
b, _, f, h, w = latents.shape
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device)
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
@@ -306,121 +306,20 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
output_params=("video_context", "audio_context"),
onload_model_names=("text_encoder", "text_encoder_post_modules"),
)
def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return (attention_mask - 1).to(dtype).reshape(
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
def _run_connectors(self, pipe, encoded_input: torch.Tensor,
attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype)
encoded, encoded_connector_attention_mask = pipe.text_encoder_post_modules.embeddings_connector(
encoded_input,
connector_attention_mask,
)
# restore the mask values to int64
attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64)
attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
encoded = encoded * attention_mask
encoded_for_audio, _ = pipe.text_encoder_post_modules.audio_embeddings_connector(
encoded_input, connector_attention_mask)
return encoded, encoded_for_audio, attention_mask.squeeze(-1)
def _norm_and_concat_padded_batch(
self,
encoded_text: torch.Tensor,
sequence_lengths: torch.Tensor,
padding_side: str = "right",
) -> torch.Tensor:
"""Normalize and flatten multi-layer hidden states, respecting padding.
Performs per-batch, per-layer normalization using masked mean and range,
then concatenates across the layer dimension.
Args:
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
sequence_lengths: Number of valid (non-padded) tokens per batch item.
padding_side: Whether padding is on "left" or "right".
Returns:
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
with padded positions zeroed out.
"""
b, t, d, l = encoded_text.shape # noqa: E741
device = encoded_text.device
# Build mask: [B, T, 1, 1]
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
if padding_side == "right":
# For right padding, valid tokens are from 0 to sequence_length-1
mask = token_indices < sequence_lengths[:, None] # [B, T]
elif padding_side == "left":
# For left padding, valid tokens are from (T - sequence_length) to T-1
start_indices = t - sequence_lengths[:, None] # [B, 1]
mask = token_indices >= start_indices # [B, T]
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
mask = rearrange(mask, "b t -> b t 1 1")
eps = 1e-6
# Compute masked mean: [B, 1, 1, L]
masked = encoded_text.masked_fill(~mask, 0.0)
denom = (sequence_lengths * d).view(b, 1, 1, 1)
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
# Compute masked min/max: [B, 1, 1, L]
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
range_ = x_max - x_min
# Normalize only the valid tokens
normed = 8 * (encoded_text - mean) / (range_ + eps)
# concat to be [Batch, T, D * L] - this preserves the original structure
normed = normed.reshape(b, t, -1) # [B, T, D * L]
# Apply mask to preserve original padding (set padded positions to 0)
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
normed = normed.masked_fill(~mask_flattened, 0.0)
return normed
def _run_feature_extractor(self,
pipe,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
padding_side: str = "right") -> torch.Tensor:
encoded_text_features = torch.stack(hidden_states, dim=-1)
encoded_text_features_dtype = encoded_text_features.dtype
sequence_lengths = attention_mask.sum(dim=-1)
normed_concated_encoded_text_features = self._norm_and_concat_padded_batch(encoded_text_features,
sequence_lengths,
padding_side=padding_side)
return pipe.text_encoder_post_modules.feature_extractor_linear(
normed_concated_encoded_text_features.to(encoded_text_features_dtype))
def _preprocess_text(
self,
pipe,
text: str,
padding_side: str = "left",
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""
Encode a given string into feature tensors suitable for downstream tasks.
Args:
text (str): Input string to encode.
Returns:
tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask.
"""
token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device)
attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device)
outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
projected = self._run_feature_extractor(pipe,
hidden_states=outputs.hidden_states,
attention_mask=attention_mask,
padding_side=padding_side)
return projected, attention_mask
return outputs.hidden_states, attention_mask
def encode_prompt(self, pipe, text, padding_side="left"):
encoded_inputs, attention_mask = self._preprocess_text(pipe, text, padding_side)
video_encoding, audio_encoding, attention_mask = self._run_connectors(pipe, encoded_inputs, attention_mask)
hidden_states, attention_mask = self._preprocess_text(pipe, text)
video_encoding, audio_encoding, attention_mask = pipe.text_encoder_post_modules.process_hidden_states(
hidden_states, attention_mask, padding_side)
return video_encoding, audio_encoding, attention_mask
def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
@@ -539,7 +438,7 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
self.get_image_latent(pipe, img, stage1_height, stage1_width, tiled, tile_size_in_pixels,
tile_overlap_in_pixels) for img in input_images
]
video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength, num_frames=num_frames)
video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength)
output_dicts.update({"video_latents": video_latents, "denoise_mask_video": denoise_mask_video, "input_latents_video": initial_latents})
if use_two_stage_pipeline:
stage2_latents = [
@@ -649,6 +548,7 @@ def model_fn_ltx2(
audio_positions=audio_positions,
audio_context=audio_context,
audio_timesteps=audio_timesteps,
sigma=timestep,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)

View File

@@ -27,6 +27,7 @@ def LTX2VocoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("vocoder."):
new_name = name.replace("vocoder.", "")
# new_name = name.replace("vocoder.", "")
new_name = name[len("vocoder."):]
state_dict_[new_name] = state_dict[name]
return state_dict_

View File

@@ -6,7 +6,8 @@ def LTX2VideoEncoderStateDictConverter(state_dict):
state_dict_[new_name] = state_dict[name]
elif name.startswith("vae.per_channel_statistics."):
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
state_dict_[new_name] = state_dict[name]
if new_name not in ["per_channel_statistics.channel", "per_channel_statistics.mean-of-stds", "per_channel_statistics.mean-of-stds_over_std-of-means"]:
state_dict_[new_name] = state_dict[name]
return state_dict_
@@ -18,5 +19,6 @@ def LTX2VideoDecoderStateDictConverter(state_dict):
state_dict_[new_name] = state_dict[name]
elif name.startswith("vae.per_channel_statistics."):
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
state_dict_[new_name] = state_dict[name]
return state_dict_
if new_name not in ["per_channel_statistics.channel", "per_channel_statistics.mean-of-stds", "per_channel_statistics.mean-of-stds_over_std-of-means"]:
state_dict_[new_name] = state_dict[name]
return state_dict_

View File

@@ -30,7 +30,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
)
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
height, width, num_frames = 512 * 2, 768 * 2, 121
height, width, num_frames = 512, 768, 121
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",

View File

@@ -0,0 +1,70 @@
# Example: LTX-2.3 distilled image-to-audio-video generation (two-stage, with x2 spatial upscaler).
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
from PIL import Image
from modelscope import dataset_snapshot_download
# VRAM management: keep weights in bfloat16 offloaded to CPU and move them to CUDA only while in use.
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
# Load the LTX-2.3 distilled DiT plus the x2 spatial upscaler; Gemma-3 serves as the text encoder.
pipe = LTX2AudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
)
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
negative_prompt = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
# Final output resolution (2x of the stage-1 base resolution); 121 frames at 24 fps is ~5 seconds.
height, width, num_frames = 512 * 2, 768 * 2, 121
# Fetch the example first-frame image used for image-to-video conditioning.
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
)
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
# first frame
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=43,
height=height,
width=width,
num_frames=num_frames,
tiled=True,
use_distilled_pipeline=True,
input_images=[image],
input_images_indexes=[0],
input_images_strength=1.0,
)
# Mux the generated frames and audio waveform into a single mp4 file.
write_video_audio_ltx2(
video=video,
audio=audio,
output_path='ltx2.3_distilled_i2av_first.mp4',
fps=24,
audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,
)

View File

@@ -0,0 +1,56 @@
# Example: LTX-2.3 one-stage (dev checkpoint) image-to-audio-video generation.
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
from PIL import Image
from modelscope import dataset_snapshot_download
# VRAM management: keep weights in bfloat16 offloaded to CPU and move them to CUDA only while in use.
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
# Load the LTX-2.3 dev DiT only (no upscaler); Gemma-3 serves as the text encoder.
pipe = LTX2AudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
)
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
# Single-stage generation at base resolution; 121 frames at 24 fps is ~5 seconds.
height, width, num_frames = 512, 768, 121
# Fetch the example first-frame image used for image-to-video conditioning.
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
)
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
# first frame
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=43,
height=height,
width=width,
num_frames=num_frames,
tiled=False,
input_images=[image],
input_images_indexes=[0],
input_images_strength=1.0,
num_inference_steps=40,
)
# Mux the generated frames and audio waveform into a single mp4 file.
write_video_audio_ltx2(
video=video,
audio=audio,
output_path='ltx2.3_onestage_i2av_first.mp4',
fps=24,
audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,
)

View File

@@ -0,0 +1,71 @@
# Two-stage image-to-audio-video generation with LTX-2.3.
# Stage 1 renders at the base resolution; stage 2 refines at 2x using a
# spatial upscaler plus a distilled LoRA.
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
from PIL import Image
from modelscope import dataset_snapshot_download

# All weights in bfloat16; offloaded weights live on CPU, computation on GPU.
vram_config = dict(
    offload_dtype=torch.bfloat16,
    offload_device="cpu",
    onload_dtype=torch.bfloat16,
    onload_device="cuda",
    preparing_dtype=torch.bfloat16,
    preparing_device="cuda",
    computation_dtype=torch.bfloat16,
    computation_device="cuda",
)

pipe = LTX2AudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        # Gemma-3 text encoder, LTX-2.3 DiT, and the x2 spatial upscaler.
        ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
        ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config),
        ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
    # Distilled LoRA applied only during the second (refinement) stage.
    stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"),
)

prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
negative_prompt = (
    "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
    "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
    "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
    "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
    "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
    "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
    "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
    "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
    "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
    "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
    "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)

# Final output resolution (twice the base 768x512 stage-1 resolution).
height, width, num_frames = 512 * 2, 768 * 2, 121

# Fetch the example first frame and resize it to the output resolution.
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"],
)
first_frame = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))

# Condition generation on the first frame (index 0) at full strength.
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    seed=42,
    height=height,
    width=width,
    num_frames=num_frames,
    tiled=True,
    use_two_stage_pipeline=True,
    num_inference_steps=40,
    input_images=[first_frame],
    input_images_indexes=[0],
    input_images_strength=1.0,
)

write_video_audio_ltx2(
    video=video,
    audio=audio,
    output_path="ltx2.3_twostage_i2av_first.mp4",
    fps=24,
    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,
)

View File

@@ -0,0 +1,57 @@
# Text-to-audio-video generation with the distilled LTX-2.3 model.
# The distilled checkpoint runs a short, fixed schedule, so no explicit
# step count is passed here.
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2

# All weights in bfloat16; offloaded weights live on CPU, computation on GPU.
vram_config = dict(
    offload_dtype=torch.bfloat16,
    offload_device="cpu",
    onload_dtype=torch.bfloat16,
    onload_device="cuda",
    preparing_dtype=torch.bfloat16,
    preparing_device="cuda",
    computation_dtype=torch.bfloat16,
    computation_device="cuda",
)

pipe = LTX2AudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        # Gemma-3 text encoder, distilled LTX-2.3 DiT, and the x2 spatial upscaler.
        ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
        ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled.safetensors", **vram_config),
        ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
)

prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
negative_prompt = (
    "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
    "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
    "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
    "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
    "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
    "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
    "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
    "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
    "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
    "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
    "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)

# Output resolution (twice the base 768x512 resolution).
height, width, num_frames = 512 * 2, 768 * 2, 121

video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    seed=43,
    height=height,
    width=width,
    num_frames=num_frames,
    tiled=True,
    use_distilled_pipeline=True,
)

write_video_audio_ltx2(
    video=video,
    audio=audio,
    output_path="ltx2.3_distilled.mp4",
    fps=24,
    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,
)

View File

@@ -0,0 +1,42 @@
# Single-stage text-to-audio-video generation with the LTX-2.3 dev model
# (no spatial upscaler, base resolution output).
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2

# All weights in bfloat16; offloaded weights live on CPU, computation on GPU.
vram_config = dict(
    offload_dtype=torch.bfloat16,
    offload_device="cpu",
    onload_dtype=torch.bfloat16,
    onload_device="cuda",
    preparing_dtype=torch.bfloat16,
    preparing_device="cuda",
    computation_dtype=torch.bfloat16,
    computation_device="cuda",
)

pipe = LTX2AudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        # Gemma-3 text encoder and the LTX-2.3 dev DiT.
        ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
        ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
)

prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."

# Base output resolution for the single-stage pipeline.
height, width, num_frames = 512, 768, 121

video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    seed=43,
    height=height,
    width=width,
    num_frames=num_frames,
    tiled=True,
)

write_video_audio_ltx2(
    video=video,
    audio=audio,
    output_path="ltx2.3_onestage.mp4",
    fps=24,
    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,
)

View File

@@ -0,0 +1,57 @@
# Two-stage text-to-audio-video generation with LTX-2.3.
# Stage 1 renders at the base resolution; stage 2 refines at 2x using a
# spatial upscaler plus a distilled LoRA.
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2

# All weights in bfloat16; offloaded weights live on CPU, computation on GPU.
vram_config = dict(
    offload_dtype=torch.bfloat16,
    offload_device="cpu",
    onload_dtype=torch.bfloat16,
    onload_device="cuda",
    preparing_dtype=torch.bfloat16,
    preparing_device="cuda",
    computation_dtype=torch.bfloat16,
    computation_device="cuda",
)

pipe = LTX2AudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        # Gemma-3 text encoder, LTX-2.3 DiT, and the x2 spatial upscaler.
        ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
        ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config),
        ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
    # Distilled LoRA applied only during the second (refinement) stage.
    stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"),
)

prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
negative_prompt = (
    "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
    "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
    "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
    "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
    "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
    "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
    "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
    "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
    "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
    "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
    "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)

# Final output resolution (twice the base 768x512 stage-1 resolution).
height, width, num_frames = 512 * 2, 768 * 2, 121

video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    seed=43,
    height=height,
    width=width,
    num_frames=num_frames,
    tiled=True,
    use_two_stage_pipeline=True,
)

write_video_audio_ltx2(
    video=video,
    audio=audio,
    output_path="ltx2.3_twostage.mp4",
    fps=24,
    audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,
)

View File

@@ -31,7 +31,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
)
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
height, width, num_frames = 512 * 2, 768 * 2, 121
height, width, num_frames = 512, 768, 121
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",