mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support ltx2.3 inference
This commit is contained in:
@@ -1279,9 +1279,268 @@ class LTX2AudioDecoder(torch.nn.Module):
|
||||
return torch.tanh(h) if self.tanh_out else h
|
||||
|
||||
|
||||
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return the per-side padding for a dilated 1-D convolution.

    The value is ``(kernel_size * dilation - dilation) / 2`` truncated toward
    zero. For odd kernel sizes this is the "same" padding of a stride-1
    convolution with the given dilation.

    Args:
        kernel_size: Number of taps in the convolution kernel.
        dilation: Spacing between kernel taps.

    Returns:
        The padding to apply on each side, as an integer.
    """
    dilated_span = kernel_size * dilation - dilation
    return int(dilated_span / 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
|
||||
# Adopted from https://github.com/NVIDIA/BigVGAN
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sinc(x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.where(
|
||||
x == 0,
|
||||
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
||||
torch.sin(math.pi * x) / math.pi / x,
|
||||
)
|
||||
|
||||
|
||||
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
    """Build a Kaiser-windowed sinc low-pass FIR kernel.

    Args:
        cutoff: Normalized cutoff frequency; callers in this file pass values
            in ``[0, 0.5]``. A cutoff of 0 yields an all-zero kernel.
        half_width: Half of the transition-band width, in the same normalized
            units (presumably cycles per sample; confirm against the BigVGAN
            reference). It controls the Kaiser ``beta``.
        kernel_size: Number of filter taps.

    Returns:
        A float tensor of shape ``(1, 1, kernel_size)``. For a non-zero cutoff
        the taps are normalized to sum to 1.
    """
    is_even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # Kaiser's empirical design rule: estimate the attenuation implied by the
    # kernel length and transition width, then map it to a window beta.
    delta_f = 4 * half_width
    amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if amplitude > 50.0:
        beta = 0.1102 * (amplitude - 8.7)
    elif amplitude >= 21.0:
        beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # Tap positions centered on zero; even-length kernels sit on half-integers.
    if is_even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size

    if cutoff == 0:
        # Take the dtype from the window rather than ``time``: for odd kernel
        # sizes ``time`` is int64, which would give an integer kernel.
        filter_ = torch.zeros_like(window)
    else:
        filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
        # Unit DC gain. Only meaningful for a non-zero kernel; normalizing the
        # all-zero kernel would divide 0 by 0.
        filter_ /= filter_.sum()
    return filter_.view(1, 1, kernel_size)
|
||||
|
||||
|
||||
class LowPassFilter1d(nn.Module):
    """Depthwise Kaiser-sinc low-pass filter over the last axis of a 3-D tensor.

    The same kernel is applied independently to every channel (grouped
    convolution with ``groups == n_channels``). It can be strided to decimate.

    Args:
        cutoff: Normalized cutoff frequency; must lie in ``[0, 0.5]``.
        half_width: Transition-band half width passed to the filter design.
        stride: Convolution stride (decimation factor).
        padding: Whether to pad the input before filtering.
        padding_mode: Mode forwarded to ``F.pad`` when ``padding`` is true.
        kernel_size: Number of filter taps.

    Raises:
        ValueError: If ``cutoff`` is negative or greater than 0.5.
    """

    def __init__(
        self,
        cutoff: float = 0.5,
        half_width: float = 0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ) -> None:
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        # Even kernels have no center tap, so pad one sample less on the left.
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Filter ``x`` along its last axis.

        Args:
            x: Tensor with three dimensions; the second is the channel axis.

        Returns:
            The filtered (and, for ``stride > 1``, decimated) tensor.
        """
        _, n_channels, _ = x.shape
        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        # Match the input's dtype/device as UpSample1d.forward does, so a
        # buffer left in another precision does not break conv1d.
        kernel = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1)
        return F.conv1d(x, kernel, stride=self.stride, groups=n_channels)
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ratio: int = 2,
|
||||
kernel_size: int | None = None,
|
||||
persistent: bool = True,
|
||||
window_type: str = "kaiser",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.stride = ratio
|
||||
|
||||
if window_type == "hann":
|
||||
# Hann-windowed sinc filter equivalent to torchaudio.functional.resample
|
||||
rolloff = 0.99
|
||||
lowpass_filter_width = 6
|
||||
width = math.ceil(lowpass_filter_width / rolloff)
|
||||
self.kernel_size = 2 * width * ratio + 1
|
||||
self.pad = width
|
||||
self.pad_left = 2 * width * ratio
|
||||
self.pad_right = self.kernel_size - ratio
|
||||
time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
|
||||
time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
|
||||
window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
|
||||
sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
|
||||
else:
|
||||
# Kaiser-windowed sinc filter (BigVGAN default).
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
sinc_filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
kernel_size=self.kernel_size,
|
||||
)
|
||||
|
||||
self.register_buffer("filter", sinc_filter, persistent=persistent)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, n_channels, _ = x.shape
|
||||
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
||||
filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1)
|
||||
x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels)
|
||||
return x[..., self.pad_left : -self.pad_right]
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None:
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.lowpass = LowPassFilter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
stride=ratio,
|
||||
kernel_size=self.kernel_size,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.lowpass(x)
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
    """Apply an activation between an upsampling and a downsampling stage.

    The input is upsampled, passed through ``activation``, then downsampled
    again.

    Args:
        activation: Module applied at the upsampled rate.
        up_ratio: Upsampling factor.
        down_ratio: Downsampling factor.
        up_kernel_size: Kernel size of the upsampling filter.
        down_kernel_size: Kernel size of the downsampling filter.
    """

    def __init__(
        self,
        activation: nn.Module,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ) -> None:
        super().__init__()
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``downsample(act(upsample(x)))``."""
        upsampled = self.upsample(x)
        activated = self.act(upsampled)
        return self.downsample(activated)
|
||||
|
||||
|
||||
class Snake(nn.Module):
    """Snake activation: ``x + sin(alpha * x) ** 2 / alpha`` with a per-channel ``alpha``.

    Args:
        in_features: Number of channels (size of the ``alpha`` parameter).
        alpha: Initial value of ``alpha`` when ``alpha_logscale`` is false.
            In log-scale mode the parameter starts at zero instead.
        alpha_trainable: Whether ``alpha`` requires gradients.
        alpha_logscale: If true, the stored parameter is ``log(alpha)`` and is
            exponentiated in ``forward``.
    """

    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = True,
    ) -> None:
        super().__init__()
        self.alpha_logscale = alpha_logscale
        if alpha_logscale:
            initial_alpha = torch.zeros(in_features)
        else:
            initial_alpha = torch.ones(in_features) * alpha
        self.alpha = nn.Parameter(initial_alpha)
        self.alpha.requires_grad = alpha_trainable
        # Guards the reciprocal against a zero alpha.
        self.eps = 1e-9

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation; ``alpha`` is broadcast as ``(1, in_features, 1)``."""
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        inverse_alpha = 1.0 / (alpha + self.eps)
        periodic = torch.sin(x * alpha).pow(2)
        return x + inverse_alpha * periodic
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
    """Snake activation with separate frequency and magnitude parameters.

    Computes ``x + sin(alpha * x) ** 2 / beta`` with per-channel ``alpha``
    and ``beta``.

    Args:
        in_features: Number of channels (size of ``alpha`` and ``beta``).
        alpha: Initial value for both ``alpha`` and ``beta`` when
            ``alpha_logscale`` is false. In log-scale mode both start at zero.
        alpha_trainable: Whether ``alpha`` and ``beta`` require gradients.
        alpha_logscale: If true, the stored parameters are logarithms and are
            exponentiated in ``forward``.
    """

    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = True,
    ) -> None:
        super().__init__()
        self.alpha_logscale = alpha_logscale

        if alpha_logscale:
            initial_alpha = torch.zeros(in_features)
        else:
            initial_alpha = torch.ones(in_features) * alpha
        self.alpha = nn.Parameter(initial_alpha)
        self.alpha.requires_grad = alpha_trainable

        # beta is initialized from the same ``alpha`` argument.
        if alpha_logscale:
            initial_beta = torch.zeros(in_features)
        else:
            initial_beta = torch.ones(in_features) * alpha
        self.beta = nn.Parameter(initial_beta)
        self.beta.requires_grad = alpha_trainable

        # Guards the reciprocal against a zero beta.
        self.eps = 1e-9

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation; parameters are broadcast as ``(1, in_features, 1)``."""
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        inverse_beta = 1.0 / (beta + self.eps)
        periodic = torch.sin(x * alpha).pow(2)
        return x + inverse_beta * periodic
|
||||
|
||||
|
||||
class AMPBlock1(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int = 3,
|
||||
dilation: tuple[int, int, int] = (1, 3, 5),
|
||||
activation: str = "snake",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
act_cls = SnakeBeta if activation == "snakebeta" else Snake
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
),
|
||||
nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
),
|
||||
nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
|
||||
nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
|
||||
nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
|
||||
]
|
||||
)
|
||||
|
||||
self.acts1 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs1))])
|
||||
self.acts2 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs2))])
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True):
|
||||
xt = a1(x)
|
||||
xt = c1(xt)
|
||||
xt = a2(xt)
|
||||
xt = c2(xt)
|
||||
x = x + xt
|
||||
return x
|
||||
|
||||
|
||||
class LTX2Vocoder(torch.nn.Module):
|
||||
"""
|
||||
Vocoder model for synthesizing audio from Mel spectrograms.
|
||||
LTX2Vocoder model for synthesizing audio from Mel spectrograms.
|
||||
Args:
|
||||
resblock_kernel_sizes: List of kernel sizes for the residual blocks.
|
||||
This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`.
|
||||
@@ -1293,28 +1552,33 @@ class LTX2Vocoder(torch.nn.Module):
|
||||
This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`.
|
||||
upsample_initial_channel: Initial number of channels for the upsampling layers.
|
||||
This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`.
|
||||
stereo: Whether to use stereo output.
|
||||
This value is read from the checkpoint at `config.vocoder.stereo`.
|
||||
resblock: Type of residual block to use.
|
||||
resblock: Type of residual block to use ("1", "2", or "AMP1").
|
||||
This value is read from the checkpoint at `config.vocoder.resblock`.
|
||||
output_sample_rate: Waveform sample rate.
|
||||
This value is read from the checkpoint at `config.vocoder.output_sample_rate`.
|
||||
output_sampling_rate: Waveform sample rate.
|
||||
This value is read from the checkpoint at `config.vocoder.output_sampling_rate`.
|
||||
activation: Activation type for BigVGAN v2 ("snake" or "snakebeta"). Only used when resblock="AMP1".
|
||||
use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True).
|
||||
apply_final_activation: Whether to apply the final tanh/clamp activation.
|
||||
use_bias_at_final: Whether to use bias in the final conv layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
def __init__( # noqa: PLR0913
|
||||
self,
|
||||
resblock_kernel_sizes: List[int] | None = [3, 7, 11],
|
||||
upsample_rates: List[int] | None = [6, 5, 2, 2, 2],
|
||||
upsample_kernel_sizes: List[int] | None = [16, 15, 8, 4, 4],
|
||||
resblock_dilation_sizes: List[List[int]] | None = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
upsample_initial_channel: int = 1024,
|
||||
stereo: bool = True,
|
||||
resblock: str = "1",
|
||||
output_sample_rate: int = 24000,
|
||||
):
|
||||
output_sampling_rate: int = 24000,
|
||||
activation: str = "snake",
|
||||
use_tanh_at_final: bool = True,
|
||||
apply_final_activation: bool = True,
|
||||
use_bias_at_final: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Initialize default values if not provided. Note that mutable default values are not supported.
|
||||
# Mutable default values are not supported as default arguments.
|
||||
if resblock_kernel_sizes is None:
|
||||
resblock_kernel_sizes = [3, 7, 11]
|
||||
if upsample_rates is None:
|
||||
@@ -1324,36 +1588,60 @@ class LTX2Vocoder(torch.nn.Module):
|
||||
if resblock_dilation_sizes is None:
|
||||
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.output_sampling_rate = output_sampling_rate
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
in_channels = 128 if stereo else 64
|
||||
self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
||||
self.use_tanh_at_final = use_tanh_at_final
|
||||
self.apply_final_activation = apply_final_activation
|
||||
self.is_amp = resblock == "AMP1"
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True)):
|
||||
self.ups.append(
|
||||
nn.ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=(kernel_size - stride) // 2,
|
||||
)
|
||||
# All production checkpoints are stereo: 128 input channels (2 stereo channels x 64 mel
|
||||
# bins each), 2 output channels.
|
||||
self.conv_pre = nn.Conv1d(
|
||||
in_channels=128,
|
||||
out_channels=upsample_initial_channel,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=3,
|
||||
)
|
||||
resblock_cls = ResBlock1 if resblock == "1" else AMPBlock1
|
||||
|
||||
self.ups = nn.ModuleList(
|
||||
nn.ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=(kernel_size - stride) // 2,
|
||||
)
|
||||
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True))
|
||||
)
|
||||
|
||||
final_channels = upsample_initial_channel // (2 ** len(upsample_rates))
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i, _ in enumerate(self.ups):
|
||||
|
||||
for i in range(len(upsample_rates)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):
|
||||
self.resblocks.append(resblock_class(ch, kernel_size, dilations))
|
||||
if self.is_amp:
|
||||
self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation))
|
||||
else:
|
||||
self.resblocks.append(resblock_cls(ch, kernel_size, dilations))
|
||||
|
||||
out_channels = 2 if stereo else 1
|
||||
final_channels = upsample_initial_channel // (2**self.num_upsamples)
|
||||
self.conv_post = nn.Conv1d(final_channels, out_channels, 7, 1, padding=3)
|
||||
if self.is_amp:
|
||||
self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels))
|
||||
else:
|
||||
self.act_post = nn.LeakyReLU()
|
||||
|
||||
self.upsample_factor = math.prod(layer.stride[0] for layer in self.ups)
|
||||
# All production checkpoints are stereo: this final conv maps `final_channels` to 2 output channels (stereo).
|
||||
self.conv_post = nn.Conv1d(
|
||||
in_channels=final_channels,
|
||||
out_channels=2,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=3,
|
||||
bias=use_bias_at_final,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
@@ -1374,7 +1662,8 @@ class LTX2Vocoder(torch.nn.Module):
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
if not self.is_amp:
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
start = i * self.num_kernels
|
||||
end = start + self.num_kernels
|
||||
@@ -1386,23 +1675,198 @@ class LTX2Vocoder(torch.nn.Module):
|
||||
[self.resblocks[idx](x) for idx in range(start, end)],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
x = block_outputs.mean(dim=0)
|
||||
|
||||
x = self.conv_post(F.leaky_relu(x))
|
||||
return torch.tanh(x)
|
||||
x = self.act_post(x)
|
||||
x = self.conv_post(x)
|
||||
|
||||
if self.apply_final_activation:
|
||||
x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def decode_audio(latent: torch.Tensor, audio_decoder: "LTX2AudioDecoder", vocoder: "LTX2Vocoder") -> torch.Tensor:
|
||||
class _STFTFn(nn.Module):
|
||||
"""Implements STFT as a convolution with precomputed DFT x Hann-window bases.
|
||||
The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
|
||||
Hann window are stored as buffers and loaded from the checkpoint. Using the exact
|
||||
bfloat16 bases from training ensures the mel values fed to the BWE generator are
|
||||
bit-identical to what it was trained on.
|
||||
"""
|
||||
Decode an audio latent representation using the provided audio decoder and vocoder.
|
||||
Args:
|
||||
latent: Input audio latent tensor.
|
||||
audio_decoder: Model to decode the latent to waveform features.
|
||||
vocoder: Model to convert decoded features to audio waveform.
|
||||
Returns:
|
||||
Decoded audio as a float tensor.
|
||||
|
||||
def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
|
||||
super().__init__()
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
n_freqs = filter_length // 2 + 1
|
||||
self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
||||
self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
||||
|
||||
def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute magnitude and phase spectrogram from a batch of waveforms.
|
||||
Applies causal (left-only) padding of win_length - hop_length samples so that
|
||||
each output frame depends only on past and present input — no lookahead.
|
||||
Args:
|
||||
y: Waveform tensor of shape (B, T).
|
||||
Returns:
|
||||
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
|
||||
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
|
||||
"""
|
||||
if y.dim() == 2:
|
||||
y = y.unsqueeze(1) # (B, 1, T)
|
||||
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
|
||||
y = F.pad(y, (left_pad, 0))
|
||||
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
|
||||
n_freqs = spec.shape[1] // 2
|
||||
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
|
||||
magnitude = torch.sqrt(real**2 + imag**2)
|
||||
phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
|
||||
return magnitude, phase
|
||||
|
||||
|
||||
class MelSTFT(nn.Module):
    """Causal log-mel spectrogram module whose buffers come from the checkpoint.

    Runs the causal STFT (``_STFTFn``) on the waveform and projects the linear
    magnitude spectrum onto the mel filterbank. The state-dict layout matches
    the ``mel_stft.*`` checkpoint keys (``mel_basis``,
    ``stft_fn.forward_basis``, ``stft_fn.inverse_basis``).
    """

    def __init__(
        self,
        filter_length: int,
        hop_length: int,
        win_length: int,
        n_mel_channels: int,
    ) -> None:
        super().__init__()
        self.stft_fn = _STFTFn(filter_length, hop_length, win_length)

        # Zero placeholder of shape [n_mels, n_freqs]; load_state_dict replaces
        # it with the checkpoint's exact bfloat16 filterbank
        # (vocoder.mel_stft.mel_basis).
        n_freqs = filter_length // 2 + 1
        placeholder = torch.zeros(n_mel_channels, n_freqs)
        self.register_buffer("mel_basis", placeholder)

    def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the log-mel spectrogram and auxiliary spectral quantities.

        Args:
            y: Waveform tensor of shape (B, T).

        Returns:
            log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
            energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
        """
        magnitude, phase = self.stft_fn(y)
        energy = torch.norm(magnitude, dim=1)
        filterbank = self.mel_basis.to(magnitude.dtype)
        mel = torch.matmul(filterbank, magnitude)
        # Floor at 1e-5 so the logarithm stays finite on silent frames.
        log_mel = torch.log(torch.clamp(mel, min=1e-5))
        return log_mel, magnitude, phase, energy
|
||||
|
||||
|
||||
class LTX2VocoderWithBWE(nn.Module):
    """LTX2Vocoder followed by a bandwidth-extension (BWE) stage.

    A mel-to-waveform vocoder produces audio at ``input_sampling_rate``. The
    BWE stage recomputes a mel spectrogram from that audio, feeds it to a
    second generator that predicts a residual, and adds the residual to a
    sinc-resampled copy of the vocoder output at ``output_sampling_rate``.

    Args:
        input_sampling_rate: Sample rate of the first vocoder's output.
        output_sampling_rate: Sample rate after bandwidth extension.
        hop_length: Hop length of the mel STFT applied to the vocoder output.
    """

    def __init__(
        self,
        input_sampling_rate: int = 16000,
        output_sampling_rate: int = 48000,
        hop_length: int = 80,
    ) -> None:
        super().__init__()
        # Settings common to both generators; only the upsampling layout, width,
        # final activation and nominal sample rate differ.
        amp_settings = {
            "resblock": "AMP1",
            "activation": "snakebeta",
            "use_tanh_at_final": False,
            "use_bias_at_final": False,
        }
        self.vocoder = LTX2Vocoder(
            resblock_kernel_sizes=[3, 7, 11],
            upsample_rates=[5, 2, 2, 2, 2, 2],
            upsample_kernel_sizes=[11, 4, 4, 4, 4, 4],
            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            upsample_initial_channel=1536,
            apply_final_activation=True,
            output_sampling_rate=input_sampling_rate,
            **amp_settings,
        )
        self.bwe_generator = LTX2Vocoder(
            resblock_kernel_sizes=[3, 7, 11],
            upsample_rates=[6, 5, 2, 2, 2],
            upsample_kernel_sizes=[12, 11, 4, 4, 4],
            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            upsample_initial_channel=512,
            apply_final_activation=False,
            output_sampling_rate=output_sampling_rate,
            **amp_settings,
        )

        self.mel_stft = MelSTFT(
            filter_length=512,
            hop_length=hop_length,
            win_length=512,
            n_mel_channels=64,
        )
        self.input_sampling_rate = input_sampling_rate
        self.output_sampling_rate = output_sampling_rate
        self.hop_length = hop_length
        # Build the resampler on CPU so the sinc filter is materialized even
        # when the model is constructed on the meta device
        # (SingleGPUModelBuilder pattern). The filter is not stored in the
        # checkpoint (persistent=False).
        resample_ratio = output_sampling_rate // input_sampling_rate
        with torch.device("cpu"):
            self.resampler = UpSample1d(ratio=resample_ratio, persistent=False, window_type="hann")

    @property
    def conv_pre(self) -> nn.Conv1d:
        """Input convolution of the first-stage vocoder."""
        return self.vocoder.conv_pre

    @property
    def conv_post(self) -> nn.Conv1d:
        """Output convolution of the first-stage vocoder."""
        return self.vocoder.conv_post

    def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor:
        """Compute the log-mel spectrogram of every channel of ``audio``.

        Args:
            audio: Waveform tensor of shape (B, C, T).

        Returns:
            mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames).
        """
        batch, n_channels, _ = audio.shape
        # Fold channels into the batch axis: the mel module expects (B*C, T).
        per_channel = audio.reshape(batch * n_channels, -1)
        mel, _, _, _ = self.mel_stft.mel_spectrogram(per_channel)  # (B*C, n_mels, T_frames)
        n_mels, n_frames = mel.shape[1], mel.shape[2]
        return mel.reshape(batch, n_channels, n_mels, n_frames)

    def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
        """Run the full vocoder + BWE forward pass.

        Args:
            mel_spec: Mel spectrogram in the layout accepted by
                ``LTX2Vocoder.forward`` (that method is not visible here;
                confirm the expected shape against it).

        Returns:
            Waveform tensor clipped to [-1, 1] and trimmed along the last axis
            to ``T_low * output_sampling_rate // input_sampling_rate`` samples,
            where ``T_low`` is the first vocoder's output length.
        """
        x = self.vocoder(mel_spec)
        _, _, low_rate_length = x.shape
        target_length = low_rate_length * self.output_sampling_rate // self.input_sampling_rate

        # Right-pad to a multiple of hop_length for an exact mel frame count.
        shortfall = -low_rate_length % self.hop_length
        if shortfall != 0:
            x = F.pad(x, (0, shortfall))

        # (B, C, n_mels, T_frames) -> (B, C, T_frames, mel_bins), the layout
        # LTX2Vocoder.forward expects.
        mel_for_bwe = self._compute_mel(x).transpose(2, 3)
        residual = self.bwe_generator(mel_for_bwe)
        skip = self.resampler(x)
        assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"

        combined = residual + skip
        return combined.clamp(-1, 1)[..., :target_length]
|
||||
|
||||
@@ -251,11 +251,27 @@ class Modality:
|
||||
Input data for a single modality (video or audio) in the transformer.
|
||||
Bundles the latent tokens, timestep embeddings, positional information,
|
||||
and text conditioning context for processing by the diffusion transformer.
|
||||
Attributes:
|
||||
latent: Patchified latent tokens, shape ``(B, T, D)`` where *B* is
|
||||
the batch size, *T* is the total number of tokens (noisy +
|
||||
conditioning), and *D* is the input dimension.
|
||||
timesteps: Per-token timestep embeddings, shape ``(B, T)``.
|
||||
positions: Positional coordinates, shape ``(B, 3, T)`` for video
|
||||
(time, height, width) or ``(B, 1, T)`` for audio.
|
||||
context: Text conditioning embeddings from the prompt encoder.
|
||||
enabled: Whether this modality is active in the current forward pass.
|
||||
context_mask: Optional mask for the text context tokens.
|
||||
attention_mask: Optional 2-D self-attention mask, shape ``(B, T, T)``.
|
||||
Values in ``[0, 1]`` where ``1`` = full attention and ``0`` = no
|
||||
attention. ``None`` means unrestricted (full) attention between
|
||||
all tokens. Built incrementally by conditioning items; see
|
||||
:class:`~ltx_core.conditioning.types.attention_strength_wrapper.ConditioningItemAttentionStrengthWrapper`.
|
||||
"""
|
||||
|
||||
latent: (
|
||||
torch.Tensor
|
||||
) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
|
||||
sigma: torch.Tensor # Shape: (B,). Current sigma value, used for cross-attention timestep calculation.
|
||||
timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps
|
||||
positions: (
|
||||
torch.Tensor
|
||||
@@ -263,6 +279,7 @@ class Modality:
|
||||
context: torch.Tensor
|
||||
enabled: bool = True
|
||||
context_mask: torch.Tensor | None = None
|
||||
attention_mask: torch.Tensor | None = None
|
||||
|
||||
|
||||
def to_denoised(
|
||||
|
||||
@@ -225,6 +225,17 @@ class BatchedPerturbationConfig:
|
||||
return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)])
|
||||
|
||||
|
||||
|
||||
# Number of AdaLN parameters every block uses (presumably shift/scale/gate for
# the self-attention and feed-forward paths; confirm against the block code).
ADALN_NUM_BASE_PARAMS = 6
# Cross-attention AdaLN adds 3 more (scale, shift, gate) for the CA norm.
ADALN_NUM_CROSS_ATTN_PARAMS = 3


def adaln_embedding_coefficient(cross_attention_adaln: bool) -> int:
    """Return the total number of AdaLN parameters per block.

    Args:
        cross_attention_adaln: Whether the block also modulates its
            cross-attention norm.

    Returns:
        The base count, plus the cross-attention count when enabled.
    """
    if cross_attention_adaln:
        return ADALN_NUM_BASE_PARAMS + ADALN_NUM_CROSS_ATTN_PARAMS
    return ADALN_NUM_BASE_PARAMS + 0
|
||||
|
||||
|
||||
class AdaLayerNormSingle(torch.nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm single (adaLN-single).
|
||||
@@ -460,6 +471,7 @@ class Attention(torch.nn.Module):
|
||||
dim_head: int = 64,
|
||||
norm_eps: float = 1e-6,
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
apply_gated_attention: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.rope_type = rope_type
|
||||
@@ -477,6 +489,12 @@ class Attention(torch.nn.Module):
|
||||
self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True)
|
||||
self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True)
|
||||
|
||||
# Optional per-head gating
|
||||
if apply_gated_attention:
|
||||
self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
|
||||
else:
|
||||
self.to_gate_logits = None
|
||||
|
||||
self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity())
|
||||
|
||||
def forward(
|
||||
@@ -486,6 +504,8 @@ class Attention(torch.nn.Module):
|
||||
mask: torch.Tensor | None = None,
|
||||
pe: torch.Tensor | None = None,
|
||||
k_pe: torch.Tensor | None = None,
|
||||
perturbation_mask: torch.Tensor | None = None,
|
||||
all_perturbed: bool = False,
|
||||
) -> torch.Tensor:
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
@@ -517,6 +537,19 @@ class Attention(torch.nn.Module):
|
||||
|
||||
# Reshape back to original format
|
||||
out = out.flatten(2, 3)
|
||||
|
||||
# Apply per-head gating if enabled
|
||||
if self.to_gate_logits is not None:
|
||||
gate_logits = self.to_gate_logits(x) # (B, T, H)
|
||||
b, t, _ = out.shape
|
||||
# Reshape to (B, T, H, D) for per-head gating
|
||||
out = out.view(b, t, self.heads, self.dim_head)
|
||||
# Apply gating: 2 * sigmoid(x) so that zero-init gives identity (2 * 0.5 = 1.0)
|
||||
gates = 2.0 * torch.sigmoid(gate_logits) # (B, T, H)
|
||||
out = out * gates.unsqueeze(-1) # (B, T, H, D) * (B, T, H, 1)
|
||||
# Reshape back to (B, T, H*D)
|
||||
out = out.view(b, t, self.heads * self.dim_head)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
@@ -545,7 +578,6 @@ class PixArtAlphaTextProjection(torch.nn.Module):
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TransformerArgs:
|
||||
x: torch.Tensor
|
||||
@@ -558,7 +590,10 @@ class TransformerArgs:
|
||||
cross_scale_shift_timestep: torch.Tensor | None
|
||||
cross_gate_timestep: torch.Tensor | None
|
||||
enabled: bool
|
||||
|
||||
prompt_timestep: torch.Tensor | None = None
|
||||
self_attention_mask: torch.Tensor | None = (
|
||||
None # Additive log-space self-attention bias (B, 1, T, T), None = full attention
|
||||
)
|
||||
|
||||
|
||||
class TransformerArgsPreprocessor:
|
||||
@@ -566,7 +601,6 @@ class TransformerArgsPreprocessor:
|
||||
self,
|
||||
patchify_proj: torch.nn.Linear,
|
||||
adaln: AdaLayerNormSingle,
|
||||
caption_projection: PixArtAlphaTextProjection,
|
||||
inner_dim: int,
|
||||
max_pos: list[int],
|
||||
num_attention_heads: int,
|
||||
@@ -575,10 +609,11 @@ class TransformerArgsPreprocessor:
|
||||
double_precision_rope: bool,
|
||||
positional_embedding_theta: float,
|
||||
rope_type: LTXRopeType,
|
||||
caption_projection: torch.nn.Module | None = None,
|
||||
prompt_adaln: AdaLayerNormSingle | None = None,
|
||||
) -> None:
|
||||
self.patchify_proj = patchify_proj
|
||||
self.adaln = adaln
|
||||
self.caption_projection = caption_projection
|
||||
self.inner_dim = inner_dim
|
||||
self.max_pos = max_pos
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -587,18 +622,18 @@ class TransformerArgsPreprocessor:
|
||||
self.double_precision_rope = double_precision_rope
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.rope_type = rope_type
|
||||
self.caption_projection = caption_projection
|
||||
self.prompt_adaln = prompt_adaln
|
||||
|
||||
def _prepare_timestep(
|
||||
self, timestep: torch.Tensor, batch_size: int, hidden_dtype: torch.dtype
|
||||
self, timestep: torch.Tensor, adaln: AdaLayerNormSingle, batch_size: int, hidden_dtype: torch.dtype
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Prepare timestep embeddings."""
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep, embedded_timestep = self.adaln(
|
||||
timestep.flatten(),
|
||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||
timestep, embedded_timestep = adaln(
|
||||
timestep_scaled.flatten(),
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
||||
@@ -608,14 +643,12 @@ class TransformerArgsPreprocessor:
|
||||
self,
|
||||
context: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
) -> torch.Tensor:
|
||||
"""Prepare context for transformer blocks."""
|
||||
if self.caption_projection is not None:
|
||||
context = self.caption_projection(context)
|
||||
batch_size = x.shape[0]
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(batch_size, -1, x.shape[-1])
|
||||
|
||||
return context, attention_mask
|
||||
return context.view(batch_size, -1, x.shape[-1])
|
||||
|
||||
def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None:
|
||||
"""Prepare attention mask."""
|
||||
@@ -626,6 +659,34 @@ class TransformerArgsPreprocessor:
|
||||
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
) * torch.finfo(x_dtype).max
|
||||
|
||||
def _prepare_self_attention_mask(
|
||||
self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype
|
||||
) -> torch.Tensor | None:
|
||||
"""Prepare self-attention mask by converting [0,1] values to additive log-space bias.
|
||||
Input shape: (B, T, T) with values in [0, 1].
|
||||
Output shape: (B, 1, T, T) with 0.0 for full attention and a large negative value
|
||||
for masked positions.
|
||||
Positions with attention_mask <= 0 are fully masked (mapped to the dtype's minimum
|
||||
representable value). Strictly positive entries are converted via log-space for
|
||||
smooth attenuation, with small values clamped for numerical stability.
|
||||
Returns None if input is None (no masking).
|
||||
"""
|
||||
if attention_mask is None:
|
||||
return None
|
||||
|
||||
# Convert [0, 1] attention mask to additive log-space bias:
|
||||
# 1.0 -> log(1.0) = 0.0 (no bias, full attention)
|
||||
# 0.0 -> finfo.min (fully masked)
|
||||
finfo = torch.finfo(x_dtype)
|
||||
eps = finfo.tiny
|
||||
|
||||
bias = torch.full_like(attention_mask, finfo.min, dtype=x_dtype)
|
||||
positive = attention_mask > 0
|
||||
if positive.any():
|
||||
bias[positive] = torch.log(attention_mask[positive].clamp(min=eps)).to(x_dtype)
|
||||
|
||||
return bias.unsqueeze(1) # (B, 1, T, T) for head broadcast
|
||||
|
||||
def _prepare_positional_embeddings(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -653,11 +714,20 @@ class TransformerArgsPreprocessor:
|
||||
def prepare(
|
||||
self,
|
||||
modality: Modality,
|
||||
cross_modality: Modality | None = None, # noqa: ARG002
|
||||
) -> TransformerArgs:
|
||||
x = self.patchify_proj(modality.latent)
|
||||
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], modality.latent.dtype)
|
||||
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
||||
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
||||
batch_size = x.shape[0]
|
||||
timestep, embedded_timestep = self._prepare_timestep(
|
||||
modality.timesteps, self.adaln, batch_size, modality.latent.dtype
|
||||
)
|
||||
prompt_timestep = None
|
||||
if self.prompt_adaln is not None:
|
||||
prompt_timestep, _ = self._prepare_timestep(
|
||||
modality.sigma, self.prompt_adaln, batch_size, modality.latent.dtype
|
||||
)
|
||||
context = self._prepare_context(modality.context, x)
|
||||
attention_mask = self._prepare_attention_mask(modality.context_mask, modality.latent.dtype)
|
||||
pe = self._prepare_positional_embeddings(
|
||||
positions=modality.positions,
|
||||
inner_dim=self.inner_dim,
|
||||
@@ -666,6 +736,7 @@ class TransformerArgsPreprocessor:
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
x_dtype=modality.latent.dtype,
|
||||
)
|
||||
self_attention_mask = self._prepare_self_attention_mask(modality.attention_mask, modality.latent.dtype)
|
||||
return TransformerArgs(
|
||||
x=x,
|
||||
context=context,
|
||||
@@ -677,6 +748,8 @@ class TransformerArgsPreprocessor:
|
||||
cross_scale_shift_timestep=None,
|
||||
cross_gate_timestep=None,
|
||||
enabled=modality.enabled,
|
||||
prompt_timestep=prompt_timestep,
|
||||
self_attention_mask=self_attention_mask,
|
||||
)
|
||||
|
||||
|
||||
@@ -685,7 +758,6 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
self,
|
||||
patchify_proj: torch.nn.Linear,
|
||||
adaln: AdaLayerNormSingle,
|
||||
caption_projection: PixArtAlphaTextProjection,
|
||||
cross_scale_shift_adaln: AdaLayerNormSingle,
|
||||
cross_gate_adaln: AdaLayerNormSingle,
|
||||
inner_dim: int,
|
||||
@@ -699,11 +771,12 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
positional_embedding_theta: float,
|
||||
rope_type: LTXRopeType,
|
||||
av_ca_timestep_scale_multiplier: int,
|
||||
caption_projection: torch.nn.Module | None = None,
|
||||
prompt_adaln: AdaLayerNormSingle | None = None,
|
||||
) -> None:
|
||||
self.simple_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=patchify_proj,
|
||||
adaln=adaln,
|
||||
caption_projection=caption_projection,
|
||||
inner_dim=inner_dim,
|
||||
max_pos=max_pos,
|
||||
num_attention_heads=num_attention_heads,
|
||||
@@ -712,6 +785,8 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
double_precision_rope=double_precision_rope,
|
||||
positional_embedding_theta=positional_embedding_theta,
|
||||
rope_type=rope_type,
|
||||
caption_projection=caption_projection,
|
||||
prompt_adaln=prompt_adaln,
|
||||
)
|
||||
self.cross_scale_shift_adaln = cross_scale_shift_adaln
|
||||
self.cross_gate_adaln = cross_gate_adaln
|
||||
@@ -722,8 +797,22 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
def prepare(
|
||||
self,
|
||||
modality: Modality,
|
||||
cross_modality: Modality | None = None,
|
||||
) -> TransformerArgs:
|
||||
transformer_args = self.simple_preprocessor.prepare(modality)
|
||||
if cross_modality is None:
|
||||
return transformer_args
|
||||
|
||||
if cross_modality.sigma.numel() > 1:
|
||||
if cross_modality.sigma.shape[0] != modality.timesteps.shape[0]:
|
||||
raise ValueError("Cross modality sigma must have the same batch size as the modality")
|
||||
if cross_modality.sigma.ndim != 1:
|
||||
raise ValueError("Cross modality sigma must be a 1D tensor")
|
||||
|
||||
cross_timestep = cross_modality.sigma.view(
|
||||
modality.timesteps.shape[0], 1, *[1] * len(modality.timesteps.shape[2:])
|
||||
)
|
||||
|
||||
cross_pe = self.simple_preprocessor._prepare_positional_embeddings(
|
||||
positions=modality.positions[:, 0:1, :],
|
||||
inner_dim=self.audio_cross_attention_dim,
|
||||
@@ -734,7 +823,7 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
)
|
||||
|
||||
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
|
||||
timestep=modality.timesteps,
|
||||
timestep=cross_timestep,
|
||||
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
||||
batch_size=transformer_args.x.shape[0],
|
||||
hidden_dtype=modality.latent.dtype,
|
||||
@@ -749,7 +838,7 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
|
||||
def _prepare_cross_attention_timestep(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
timestep: torch.Tensor | None,
|
||||
timestep_scale_multiplier: int,
|
||||
batch_size: int,
|
||||
hidden_dtype: torch.dtype,
|
||||
@@ -779,6 +868,8 @@ class TransformerConfig:
|
||||
heads: int
|
||||
d_head: int
|
||||
context_dim: int
|
||||
apply_gated_attention: bool = False
|
||||
cross_attention_adaln: bool = False
|
||||
|
||||
|
||||
class BasicAVTransformerBlock(torch.nn.Module):
|
||||
@@ -801,6 +892,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
context_dim=None,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
apply_gated_attention=video.apply_gated_attention,
|
||||
)
|
||||
self.attn2 = Attention(
|
||||
query_dim=video.dim,
|
||||
@@ -809,9 +901,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
dim_head=video.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
apply_gated_attention=video.apply_gated_attention,
|
||||
)
|
||||
self.ff = FeedForward(video.dim, dim_out=video.dim)
|
||||
self.scale_shift_table = torch.nn.Parameter(torch.empty(6, video.dim))
|
||||
video_sst_size = adaln_embedding_coefficient(video.cross_attention_adaln)
|
||||
self.scale_shift_table = torch.nn.Parameter(torch.empty(video_sst_size, video.dim))
|
||||
|
||||
if audio is not None:
|
||||
self.audio_attn1 = Attention(
|
||||
@@ -821,6 +915,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
context_dim=None,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
apply_gated_attention=audio.apply_gated_attention,
|
||||
)
|
||||
self.audio_attn2 = Attention(
|
||||
query_dim=audio.dim,
|
||||
@@ -829,9 +924,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
dim_head=audio.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
apply_gated_attention=audio.apply_gated_attention,
|
||||
)
|
||||
self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
|
||||
self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(6, audio.dim))
|
||||
audio_sst_size = adaln_embedding_coefficient(audio.cross_attention_adaln)
|
||||
self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(audio_sst_size, audio.dim))
|
||||
|
||||
if audio is not None and video is not None:
|
||||
# Q: Video, K,V: Audio
|
||||
@@ -842,6 +939,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
dim_head=audio.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
apply_gated_attention=video.apply_gated_attention,
|
||||
)
|
||||
|
||||
# Q: Audio, K,V: Video
|
||||
@@ -852,11 +950,21 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
dim_head=audio.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
apply_gated_attention=audio.apply_gated_attention,
|
||||
)
|
||||
|
||||
self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim))
|
||||
self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim))
|
||||
|
||||
self.cross_attention_adaln = (video is not None and video.cross_attention_adaln) or (
|
||||
audio is not None and audio.cross_attention_adaln
|
||||
)
|
||||
|
||||
if self.cross_attention_adaln and video is not None:
|
||||
self.prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, video.dim))
|
||||
if self.cross_attention_adaln and audio is not None:
|
||||
self.audio_prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, audio.dim))
|
||||
|
||||
self.norm_eps = norm_eps
|
||||
|
||||
def get_ada_values(
|
||||
@@ -876,19 +984,49 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
batch_size: int,
|
||||
scale_shift_timestep: torch.Tensor,
|
||||
gate_timestep: torch.Tensor,
|
||||
scale_shift_indices: slice,
|
||||
num_scale_shift_values: int = 4,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
scale_shift_ada_values = self.get_ada_values(
|
||||
scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, slice(None, None)
|
||||
scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, scale_shift_indices
|
||||
)
|
||||
gate_ada_values = self.get_ada_values(
|
||||
scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None)
|
||||
)
|
||||
|
||||
scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
|
||||
gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
|
||||
scale, shift = (t.squeeze(2) for t in scale_shift_ada_values)
|
||||
(gate,) = (t.squeeze(2) for t in gate_ada_values)
|
||||
|
||||
return (*scale_shift_chunks, *gate_ada_values)
|
||||
return scale, shift, gate
|
||||
|
||||
def _apply_text_cross_attention(
    self,
    x: torch.Tensor,
    context: torch.Tensor,
    attn: Attention,
    scale_shift_table: torch.Tensor,
    prompt_scale_shift_table: torch.Tensor | None,
    timestep: torch.Tensor,
    prompt_timestep: torch.Tensor | None,
    context_mask: torch.Tensor | None,
    cross_attention_adaln: bool = False,
) -> torch.Tensor:
    """Run text cross-attention for one modality, optionally AdaLN-modulated.

    Args:
        x: Hidden states used as attention queries.
        context: Text (prompt) hidden states used as keys/values.
        attn: Cross-attention module to invoke.
        scale_shift_table: Per-block modulation table handed to
            ``get_ada_values`` when AdaLN is enabled.
        prompt_scale_shift_table: Modulation table for the context; only
            used on the AdaLN path.
        timestep: Timestep embedding handed to ``get_ada_values``.
        prompt_timestep: Timestep embedding for the context modulation; only
            used on the AdaLN path.
        context_mask: Optional attention mask forwarded to ``attn``.
        cross_attention_adaln: Selects the AdaLN-modulated path when True.

    Returns:
        The cross-attention output. The residual addition is left to the
        caller.
    """
    # Plain path: normalise the queries and attend to the raw context.
    if not cross_attention_adaln:
        normed_x = rms_norm(x, eps=self.norm_eps)
        return attn(normed_x, context=context, mask=context_mask)

    # AdaLN path: fetch the slice(6, 9) modulation entries
    # (presumably the cross-attention shift/scale/gate rows — TODO confirm).
    num_batches = x.shape[0]
    q_shift, q_scale, q_gate = self.get_ada_values(scale_shift_table, num_batches, timestep, slice(6, 9))
    return apply_cross_attention_adaln(
        x=x,
        context=context,
        attn=attn,
        q_shift=q_shift,
        q_scale=q_scale,
        q_gate=q_gate,
        prompt_scale_shift_table=prompt_scale_shift_table,
        prompt_timestep=prompt_timestep,
        context_mask=context_mask,
        norm_eps=self.norm_eps,
    )
|
||||
|
||||
def forward( # noqa: PLR0915
|
||||
self,
|
||||
@@ -896,7 +1034,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
audio: TransformerArgs | None,
|
||||
perturbations: BatchedPerturbationConfig | None = None,
|
||||
) -> tuple[TransformerArgs | None, TransformerArgs | None]:
|
||||
batch_size = video.x.shape[0]
|
||||
if video is None and audio is None:
|
||||
raise ValueError("At least one of video or audio must be provided")
|
||||
|
||||
batch_size = (video or audio).x.shape[0]
|
||||
|
||||
if perturbations is None:
|
||||
perturbations = BatchedPerturbationConfig.empty(batch_size)
|
||||
|
||||
@@ -913,63 +1055,103 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
|
||||
)
|
||||
if not perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx):
|
||||
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
||||
v_mask = perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx)
|
||||
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa * v_mask
|
||||
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
||||
del vshift_msa, vscale_msa
|
||||
|
||||
vx = vx + self.attn2(rms_norm(vx, eps=self.norm_eps), context=video.context, mask=video.context_mask)
|
||||
|
||||
del vshift_msa, vscale_msa, vgate_msa
|
||||
all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)
|
||||
none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)
|
||||
v_mask = (
|
||||
perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx)
|
||||
if not all_perturbed and not none_perturbed
|
||||
else None
|
||||
)
|
||||
vx = (
|
||||
vx
|
||||
+ self.attn1(
|
||||
norm_vx,
|
||||
pe=video.positional_embeddings,
|
||||
mask=video.self_attention_mask,
|
||||
perturbation_mask=v_mask,
|
||||
all_perturbed=all_perturbed,
|
||||
)
|
||||
* vgate_msa
|
||||
)
|
||||
del vgate_msa, norm_vx, v_mask
|
||||
vx = vx + self._apply_text_cross_attention(
|
||||
vx,
|
||||
video.context,
|
||||
self.attn2,
|
||||
self.scale_shift_table,
|
||||
getattr(self, "prompt_scale_shift_table", None),
|
||||
video.timesteps,
|
||||
video.prompt_timestep,
|
||||
video.context_mask,
|
||||
cross_attention_adaln=self.cross_attention_adaln,
|
||||
)
|
||||
|
||||
if run_ax:
|
||||
ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
|
||||
)
|
||||
|
||||
if not perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx):
|
||||
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
|
||||
a_mask = perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax)
|
||||
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa * a_mask
|
||||
|
||||
ax = ax + self.audio_attn2(rms_norm(ax, eps=self.norm_eps), context=audio.context, mask=audio.context_mask)
|
||||
|
||||
del ashift_msa, ascale_msa, agate_msa
|
||||
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
|
||||
del ashift_msa, ascale_msa
|
||||
all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)
|
||||
none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)
|
||||
a_mask = (
|
||||
perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax)
|
||||
if not all_perturbed and not none_perturbed
|
||||
else None
|
||||
)
|
||||
ax = (
|
||||
ax
|
||||
+ self.audio_attn1(
|
||||
norm_ax,
|
||||
pe=audio.positional_embeddings,
|
||||
mask=audio.self_attention_mask,
|
||||
perturbation_mask=a_mask,
|
||||
all_perturbed=all_perturbed,
|
||||
)
|
||||
* agate_msa
|
||||
)
|
||||
del agate_msa, norm_ax, a_mask
|
||||
ax = ax + self._apply_text_cross_attention(
|
||||
ax,
|
||||
audio.context,
|
||||
self.audio_attn2,
|
||||
self.audio_scale_shift_table,
|
||||
getattr(self, "audio_prompt_scale_shift_table", None),
|
||||
audio.timesteps,
|
||||
audio.prompt_timestep,
|
||||
audio.context_mask,
|
||||
cross_attention_adaln=self.cross_attention_adaln,
|
||||
)
|
||||
|
||||
# Audio - Video cross attention.
|
||||
if run_a2v or run_v2a:
|
||||
vx_norm3 = rms_norm(vx, eps=self.norm_eps)
|
||||
ax_norm3 = rms_norm(ax, eps=self.norm_eps)
|
||||
|
||||
(
|
||||
scale_ca_audio_hidden_states_a2v,
|
||||
shift_ca_audio_hidden_states_a2v,
|
||||
scale_ca_audio_hidden_states_v2a,
|
||||
shift_ca_audio_hidden_states_v2a,
|
||||
gate_out_v2a,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
audio.cross_scale_shift_timestep,
|
||||
audio.cross_gate_timestep,
|
||||
)
|
||||
if run_a2v and not perturbations.all_in_batch(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx):
|
||||
scale_ca_video_a2v, shift_ca_video_a2v, gate_out_a2v = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
video.cross_scale_shift_timestep,
|
||||
video.cross_gate_timestep,
|
||||
slice(0, 2),
|
||||
)
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v
|
||||
del scale_ca_video_a2v, shift_ca_video_a2v
|
||||
|
||||
(
|
||||
scale_ca_video_hidden_states_a2v,
|
||||
shift_ca_video_hidden_states_a2v,
|
||||
scale_ca_video_hidden_states_v2a,
|
||||
shift_ca_video_hidden_states_v2a,
|
||||
gate_out_a2v,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
video.cross_scale_shift_timestep,
|
||||
video.cross_gate_timestep,
|
||||
)
|
||||
|
||||
if run_a2v:
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) + shift_ca_video_hidden_states_a2v
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
|
||||
scale_ca_audio_a2v, shift_ca_audio_a2v, _ = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
audio.cross_scale_shift_timestep,
|
||||
audio.cross_gate_timestep,
|
||||
slice(0, 2),
|
||||
)
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v
|
||||
del scale_ca_audio_a2v, shift_ca_audio_a2v
|
||||
a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx)
|
||||
vx = vx + (
|
||||
self.audio_to_video_attn(
|
||||
@@ -981,10 +1163,27 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
* gate_out_a2v
|
||||
* a2v_mask
|
||||
)
|
||||
del gate_out_a2v, a2v_mask, vx_scaled, ax_scaled
|
||||
|
||||
if run_v2a:
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
|
||||
if run_v2a and not perturbations.all_in_batch(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx):
|
||||
scale_ca_audio_v2a, shift_ca_audio_v2a, gate_out_v2a = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
audio.cross_scale_shift_timestep,
|
||||
audio.cross_gate_timestep,
|
||||
slice(2, 4),
|
||||
)
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a
|
||||
del scale_ca_audio_v2a, shift_ca_audio_v2a
|
||||
scale_ca_video_v2a, shift_ca_video_v2a, _ = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
video.cross_scale_shift_timestep,
|
||||
video.cross_gate_timestep,
|
||||
slice(2, 4),
|
||||
)
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a
|
||||
del scale_ca_video_v2a, shift_ca_video_v2a
|
||||
v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax)
|
||||
ax = ax + (
|
||||
self.video_to_audio_attn(
|
||||
@@ -996,40 +1195,53 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
||||
* gate_out_v2a
|
||||
* v2a_mask
|
||||
)
|
||||
del gate_out_v2a, v2a_mask, ax_scaled, vx_scaled
|
||||
|
||||
del gate_out_a2v, gate_out_v2a
|
||||
del (
|
||||
scale_ca_video_hidden_states_a2v,
|
||||
shift_ca_video_hidden_states_a2v,
|
||||
scale_ca_audio_hidden_states_a2v,
|
||||
shift_ca_audio_hidden_states_a2v,
|
||||
scale_ca_video_hidden_states_v2a,
|
||||
shift_ca_video_hidden_states_v2a,
|
||||
scale_ca_audio_hidden_states_v2a,
|
||||
shift_ca_audio_hidden_states_v2a,
|
||||
)
|
||||
del vx_norm3, ax_norm3
|
||||
|
||||
if run_vx:
|
||||
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None)
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6)
|
||||
)
|
||||
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
|
||||
vx = vx + self.ff(vx_scaled) * vgate_mlp
|
||||
|
||||
del vshift_mlp, vscale_mlp, vgate_mlp
|
||||
del vshift_mlp, vscale_mlp, vgate_mlp, vx_scaled
|
||||
|
||||
if run_ax:
|
||||
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None)
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6)
|
||||
)
|
||||
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
|
||||
ax = ax + self.audio_ff(ax_scaled) * agate_mlp
|
||||
|
||||
del ashift_mlp, ascale_mlp, agate_mlp
|
||||
del ashift_mlp, ascale_mlp, agate_mlp, ax_scaled
|
||||
|
||||
return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None
|
||||
|
||||
|
||||
def apply_cross_attention_adaln(
    x: torch.Tensor,
    context: torch.Tensor,
    attn: Attention,
    q_shift: torch.Tensor,
    q_scale: torch.Tensor,
    q_gate: torch.Tensor,
    prompt_scale_shift_table: torch.Tensor,
    prompt_timestep: torch.Tensor,
    context_mask: torch.Tensor | None = None,
    norm_eps: float = 1e-6,
) -> torch.Tensor:
    """Cross-attention with AdaLN modulation on both the query and context side.

    Args:
        x: Hidden states used as queries; RMS-normalised, then scaled/shifted.
        context: Key/value hidden states; scaled/shifted without normalisation.
        attn: Attention module called as ``attn(q_in, context=..., mask=...)``.
        q_shift: Additive modulation for the normalised queries.
        q_scale: Multiplicative modulation for the normalised queries,
            applied as ``1 + q_scale``.
        q_gate: Gate multiplied onto the attention output.
        prompt_scale_shift_table: Table added to the reshaped
            ``prompt_timestep`` to form the context shift and scale.
        prompt_timestep: Timestep embedding for the context; reshaped to
            ``(batch, prompt_timestep.shape[1], 2, -1)``.
        context_mask: Optional mask forwarded to ``attn``.
        norm_eps: Epsilon for the RMS normalisation of ``x``.

    Returns:
        The gated attention output.
    """
    num_batches = x.shape[0]

    # Combine the table with the timestep embedding, then split along the
    # size-2 axis into (shift, scale) for the context.
    table = prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
    prompt_embedding = prompt_timestep.reshape(num_batches, prompt_timestep.shape[1], 2, -1)
    shift_kv, scale_kv = torch.unbind(table + prompt_embedding, dim=2)

    query_input = rms_norm(x, eps=norm_eps) * (1 + q_scale) + q_shift
    modulated_context = context * (1 + scale_kv) + shift_kv

    attended = attn(query_input, context=modulated_context, mask=context_mask)
    return attended * q_gate
|
||||
|
||||
|
||||
class GELUApprox(torch.nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int) -> None:
|
||||
super().__init__()
|
||||
@@ -1094,6 +1306,8 @@ class LTXModel(torch.nn.Module):
|
||||
av_ca_timestep_scale_multiplier: int = 1000,
|
||||
rope_type: LTXRopeType = LTXRopeType.SPLIT,
|
||||
double_precision_rope: bool = True,
|
||||
apply_gated_attention: bool = False,
|
||||
cross_attention_adaln: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self._enable_gradient_checkpointing = False
|
||||
@@ -1103,6 +1317,7 @@ class LTXModel(torch.nn.Module):
|
||||
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.model_type = model_type
|
||||
self.cross_attention_adaln = cross_attention_adaln
|
||||
cross_pe_max_pos = None
|
||||
if model_type.is_video_enabled():
|
||||
if positional_embedding_max_pos is None:
|
||||
@@ -1145,8 +1360,13 @@ class LTXModel(torch.nn.Module):
|
||||
audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0,
|
||||
audio_cross_attention_dim=audio_cross_attention_dim,
|
||||
norm_eps=norm_eps,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
)
|
||||
|
||||
@property
def _adaln_embedding_coefficient(self) -> int:
    """Embedding coefficient for the main AdaLN modules of this model.

    Delegates to ``adaln_embedding_coefficient`` with the model's
    ``cross_attention_adaln`` flag.
    """
    uses_cross_attention_adaln = self.cross_attention_adaln
    return adaln_embedding_coefficient(uses_cross_attention_adaln)
|
||||
|
||||
def _init_video(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -1157,14 +1377,15 @@ class LTXModel(torch.nn.Module):
|
||||
"""Initialize video-specific components."""
|
||||
# Video input components
|
||||
self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True)
|
||||
|
||||
self.adaln_single = AdaLayerNormSingle(self.inner_dim)
|
||||
self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=self._adaln_embedding_coefficient)
|
||||
self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None
|
||||
|
||||
# Video caption projection
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
)
|
||||
if caption_channels is not None:
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
)
|
||||
|
||||
# Video output components
|
||||
self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim))
|
||||
@@ -1183,15 +1404,15 @@ class LTXModel(torch.nn.Module):
|
||||
# Audio input components
|
||||
self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True)
|
||||
|
||||
self.audio_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
)
|
||||
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=self._adaln_embedding_coefficient)
|
||||
self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None
|
||||
|
||||
# Audio caption projection
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
)
|
||||
if caption_channels is not None:
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
)
|
||||
|
||||
# Audio output components
|
||||
self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim))
|
||||
@@ -1233,7 +1454,6 @@ class LTXModel(torch.nn.Module):
|
||||
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||
patchify_proj=self.patchify_proj,
|
||||
adaln=self.adaln_single,
|
||||
caption_projection=self.caption_projection,
|
||||
cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
|
||||
cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
|
||||
inner_dim=self.inner_dim,
|
||||
@@ -1247,11 +1467,12 @@ class LTXModel(torch.nn.Module):
|
||||
positional_embedding_theta=self.positional_embedding_theta,
|
||||
rope_type=self.rope_type,
|
||||
av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
|
||||
caption_projection=getattr(self, "caption_projection", None),
|
||||
prompt_adaln=getattr(self, "prompt_adaln_single", None),
|
||||
)
|
||||
self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||
patchify_proj=self.audio_patchify_proj,
|
||||
adaln=self.audio_adaln_single,
|
||||
caption_projection=self.audio_caption_projection,
|
||||
cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
|
||||
cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
|
||||
inner_dim=self.audio_inner_dim,
|
||||
@@ -1265,12 +1486,13 @@ class LTXModel(torch.nn.Module):
|
||||
positional_embedding_theta=self.positional_embedding_theta,
|
||||
rope_type=self.rope_type,
|
||||
av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
|
||||
caption_projection=getattr(self, "audio_caption_projection", None),
|
||||
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
|
||||
)
|
||||
elif self.model_type.is_video_enabled():
|
||||
self.video_args_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=self.patchify_proj,
|
||||
adaln=self.adaln_single,
|
||||
caption_projection=self.caption_projection,
|
||||
inner_dim=self.inner_dim,
|
||||
max_pos=self.positional_embedding_max_pos,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
@@ -1279,12 +1501,13 @@ class LTXModel(torch.nn.Module):
|
||||
double_precision_rope=self.double_precision_rope,
|
||||
positional_embedding_theta=self.positional_embedding_theta,
|
||||
rope_type=self.rope_type,
|
||||
caption_projection=getattr(self, "caption_projection", None),
|
||||
prompt_adaln=getattr(self, "prompt_adaln_single", None),
|
||||
)
|
||||
elif self.model_type.is_audio_enabled():
|
||||
self.audio_args_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=self.audio_patchify_proj,
|
||||
adaln=self.audio_adaln_single,
|
||||
caption_projection=self.audio_caption_projection,
|
||||
inner_dim=self.audio_inner_dim,
|
||||
max_pos=self.audio_positional_embedding_max_pos,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
@@ -1293,6 +1516,8 @@ class LTXModel(torch.nn.Module):
|
||||
double_precision_rope=self.double_precision_rope,
|
||||
positional_embedding_theta=self.positional_embedding_theta,
|
||||
rope_type=self.rope_type,
|
||||
caption_projection=getattr(self, "audio_caption_projection", None),
|
||||
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
|
||||
)
|
||||
|
||||
def _init_transformer_blocks(
|
||||
@@ -1303,6 +1528,7 @@ class LTXModel(torch.nn.Module):
|
||||
audio_attention_head_dim: int,
|
||||
audio_cross_attention_dim: int,
|
||||
norm_eps: float,
|
||||
apply_gated_attention: bool,
|
||||
) -> None:
|
||||
"""Initialize transformer blocks for LTX."""
|
||||
video_config = (
|
||||
@@ -1311,6 +1537,8 @@ class LTXModel(torch.nn.Module):
|
||||
heads=self.num_attention_heads,
|
||||
d_head=attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
cross_attention_adaln=self.cross_attention_adaln,
|
||||
)
|
||||
if self.model_type.is_video_enabled()
|
||||
else None
|
||||
@@ -1321,6 +1549,8 @@ class LTXModel(torch.nn.Module):
|
||||
heads=self.audio_num_attention_heads,
|
||||
d_head=audio_attention_head_dim,
|
||||
context_dim=audio_cross_attention_dim,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
cross_attention_adaln=self.cross_attention_adaln,
|
||||
)
|
||||
if self.model_type.is_audio_enabled()
|
||||
else None
|
||||
@@ -1409,8 +1639,8 @@ class LTXModel(torch.nn.Module):
|
||||
if not self.model_type.is_audio_enabled() and audio is not None:
|
||||
raise ValueError("Audio is not enabled for this model")
|
||||
|
||||
video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
|
||||
audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
|
||||
video_args = self.video_args_preprocessor.prepare(video, audio) if video is not None else None
|
||||
audio_args = self.audio_args_preprocessor.prepare(audio, video) if audio is not None else None
|
||||
# Process transformer blocks
|
||||
video_out, audio_out = self._process_transformer_blocks(
|
||||
video=video_args,
|
||||
@@ -1441,12 +1671,12 @@ class LTXModel(torch.nn.Module):
|
||||
)
|
||||
return vx, ax
|
||||
|
||||
def forward(
    self,
    video_latents,
    video_positions,
    video_context,
    video_timesteps,
    audio_latents,
    audio_positions,
    audio_context,
    audio_timesteps,
    sigma,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
):
    """Bundle the raw inputs into ``Modality`` objects and run ``_forward``.

    The argument preprocessors are rebuilt on every call through
    ``_init_preprocessors``. ``sigma`` is handed to both modalities. The audio
    modality is only created when ``audio_latents`` is not ``None``.

    Returns:
        The ``(video, audio)`` pair produced by ``_forward``.
    """
    model_type = self.model_type
    # A shared cross-modal max position is only defined when both streams are enabled.
    if model_type.is_video_enabled() and model_type.is_audio_enabled():
        cross_pe_max_pos = max(
            self.positional_embedding_max_pos[0],
            self.audio_positional_embedding_max_pos[0],
        )
    else:
        cross_pe_max_pos = None
    self._init_preprocessors(cross_pe_max_pos)

    video = Modality(video_latents, sigma, video_timesteps, video_positions, video_context)
    audio = None
    if audio_latents is not None:
        audio = Modality(audio_latents, sigma, audio_timesteps, audio_positions, audio_context)

    video_out, audio_out = self._forward(
        video=video,
        audio=audio,
        perturbations=None,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )
    return video_out, audio_out
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer
|
||||
from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention,
|
||||
FeedForward)
|
||||
@@ -147,14 +150,14 @@ class LTXVGemmaTokenizer:
|
||||
return out
|
||||
|
||||
|
||||
class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
|
||||
class GemmaFeaturesExtractorProjLinear(nn.Module):
|
||||
"""
|
||||
Feature extractor module for Gemma models.
|
||||
This module applies a single linear projection to the input tensor.
|
||||
It expects a flattened feature tensor of shape (batch_size, 3840*49).
|
||||
The linear layer maps this to a (batch_size, 3840) embedding.
|
||||
Attributes:
|
||||
aggregate_embed (torch.nn.Linear): Linear projection layer.
|
||||
aggregate_embed (nn.Linear): Linear projection layer.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
@@ -163,26 +166,65 @@ class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
|
||||
The input dimension is expected to be 3840 * 49, and the output is 3840.
|
||||
"""
|
||||
super().__init__()
|
||||
self.aggregate_embed = torch.nn.Linear(3840 * 49, 3840, bias=False)
|
||||
self.aggregate_embed = nn.Linear(3840 * 49, 3840, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for the feature extractor.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor of shape (batch_size, 3840 * 49).
|
||||
Returns:
|
||||
torch.Tensor: Output tensor of shape (batch_size, 3840).
|
||||
"""
|
||||
return self.aggregate_embed(x)
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_side: str = "left",
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
encoded = torch.stack(hidden_states, dim=-1) if isinstance(hidden_states, (list, tuple)) else hidden_states
|
||||
dtype = encoded.dtype
|
||||
sequence_lengths = attention_mask.sum(dim=-1)
|
||||
normed = _norm_and_concat_padded_batch(encoded, sequence_lengths, padding_side)
|
||||
features = self.aggregate_embed(normed.to(dtype))
|
||||
return features, features
|
||||
|
||||
|
||||
class _BasicTransformerBlock1D(torch.nn.Module):
|
||||
class GemmaSeperatedFeaturesExtractorProjLinear(nn.Module):
|
||||
"""22B: per-token RMS norm → rescale → dual aggregate embeds"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int,
|
||||
embedding_dim: int,
|
||||
video_inner_dim: int,
|
||||
audio_inner_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
in_dim = embedding_dim * num_layers
|
||||
self.video_aggregate_embed = torch.nn.Linear(in_dim, video_inner_dim, bias=True)
|
||||
self.audio_aggregate_embed = torch.nn.Linear(in_dim, audio_inner_dim, bias=True)
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_side: str = "left", # noqa: ARG002
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
encoded = torch.stack(hidden_states, dim=-1) if isinstance(hidden_states, (list, tuple)) else hidden_states
|
||||
normed = norm_and_concat_per_token_rms(encoded, attention_mask)
|
||||
normed = normed.to(encoded.dtype)
|
||||
v_dim = self.video_aggregate_embed.out_features
|
||||
video = self.video_aggregate_embed(_rescale_norm(normed, v_dim, self.embedding_dim))
|
||||
audio = None
|
||||
if self.audio_aggregate_embed is not None:
|
||||
a_dim = self.audio_aggregate_embed.out_features
|
||||
audio = self.audio_aggregate_embed(_rescale_norm(normed, a_dim, self.embedding_dim))
|
||||
return video, audio
|
||||
|
||||
|
||||
|
||||
class _BasicTransformerBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
heads: int,
|
||||
dim_head: int,
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
apply_gated_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -191,6 +233,7 @@ class _BasicTransformerBlock1D(torch.nn.Module):
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
)
|
||||
|
||||
self.ff = FeedForward(
|
||||
@@ -231,7 +274,7 @@ class _BasicTransformerBlock1D(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Embeddings1DConnector(torch.nn.Module):
|
||||
class Embeddings1DConnector(nn.Module):
|
||||
"""
|
||||
Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
|
||||
other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can
|
||||
@@ -263,6 +306,7 @@ class Embeddings1DConnector(torch.nn.Module):
|
||||
num_learnable_registers: int | None = 128,
|
||||
rope_type: LTXRopeType = LTXRopeType.SPLIT,
|
||||
double_precision_rope: bool = True,
|
||||
apply_gated_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -274,13 +318,14 @@ class Embeddings1DConnector(torch.nn.Module):
|
||||
)
|
||||
self.rope_type = rope_type
|
||||
self.double_precision_rope = double_precision_rope
|
||||
self.transformer_1d_blocks = torch.nn.ModuleList(
|
||||
self.transformer_1d_blocks = nn.ModuleList(
|
||||
[
|
||||
_BasicTransformerBlock1D(
|
||||
dim=self.inner_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
@@ -288,7 +333,7 @@ class Embeddings1DConnector(torch.nn.Module):
|
||||
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
if self.num_learnable_registers:
|
||||
self.learnable_registers = torch.nn.Parameter(
|
||||
self.learnable_registers = nn.Parameter(
|
||||
torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
|
||||
)
|
||||
|
||||
@@ -307,7 +352,7 @@ class Embeddings1DConnector(torch.nn.Module):
|
||||
non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
|
||||
non_zero_nums = non_zero_hidden_states.shape[1]
|
||||
pad_length = hidden_states.shape[1] - non_zero_nums
|
||||
adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
|
||||
adjusted_hidden_states = nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
|
||||
flipped_mask = torch.flip(attention_mask_binary, dims=[1])
|
||||
hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers
|
||||
|
||||
@@ -358,9 +403,147 @@ class Embeddings1DConnector(torch.nn.Module):
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
class LTX2TextEncoderPostModules(nn.Module):
    """Modules applied to Gemma hidden states after the text encoder.

    Holds a feature extractor plus one ``Embeddings1DConnector`` per modality
    (video and audio).

    Args:
        seperated_audio_video: When False, a ``GemmaFeaturesExtractorProjLinear``
            and default-configured connectors are built and every other argument
            is ignored. When True (LTX-2.3), a
            ``GemmaSeperatedFeaturesExtractorProjLinear`` and per-modality
            connectors are built from the arguments below.
        embedding_dim_gemma: Hidden size passed to the separated feature extractor.
        num_layers_gemma: Number of hidden-state layers passed to the separated
            feature extractor.
        video_attetion_heads: Attention head count of the video connector
            (parameter spelling kept for caller compatibility).
        video_attention_head_dim: Per-head width of the video connector.
        audio_attention_heads: Attention head count of the audio connector.
        audio_attention_head_dim: Per-head width of the audio connector.
        num_connetor_layers: Number of layers in each connector
            (parameter spelling kept for caller compatibility).
        apply_gated_attention: Forwarded to both connectors.
    """

    def __init__(
        self,
        seperated_audio_video: bool = False,
        embedding_dim_gemma: int = 3840,
        num_layers_gemma: int = 49,
        video_attetion_heads: int = 32,
        video_attention_head_dim: int = 128,
        audio_attention_heads: int = 32,
        audio_attention_head_dim: int = 64,
        num_connetor_layers: int = 2,
        apply_gated_attention: bool = False,
    ):
        super().__init__()
        if not seperated_audio_video:
            self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
            self.embeddings_connector = Embeddings1DConnector()
            self.audio_embeddings_connector = Embeddings1DConnector()
        else:
            # LTX-2.3
            self.feature_extractor_linear = GemmaSeperatedFeaturesExtractorProjLinear(
                num_layers_gemma,
                embedding_dim_gemma,
                video_attetion_heads * video_attention_head_dim,
                audio_attention_heads * audio_attention_head_dim,
            )
            self.embeddings_connector = Embeddings1DConnector(
                attention_head_dim=video_attention_head_dim,
                num_attention_heads=video_attetion_heads,
                num_layers=num_connetor_layers,
                apply_gated_attention=apply_gated_attention,
            )
            self.audio_embeddings_connector = Embeddings1DConnector(
                attention_head_dim=audio_attention_head_dim,
                num_attention_heads=audio_attention_heads,
                num_layers=num_connetor_layers,
                apply_gated_attention=apply_gated_attention,
            )

    def create_embeddings(
        self,
        video_features: torch.Tensor,
        audio_features: torch.Tensor | None,
        additive_attention_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
        """Run the per-modality connectors on the extracted features.

        Args:
            video_features: Features fed to the video connector.
            audio_features: Features fed to the audio connector, or ``None`` to
                skip the audio branch.
            additive_attention_mask: Mask passed unchanged to both connectors.

        Returns:
            ``(video_encoded, audio_encoded, binary_mask)``. ``binary_mask`` is
            derived from the video connector's output mask; ``audio_encoded`` is
            ``None`` when ``audio_features`` is ``None``.
        """
        video_encoded, video_mask = self.embeddings_connector(video_features, additive_attention_mask)
        video_encoded, binary_mask = _to_binary_mask(video_encoded, video_mask)

        # The signature allows missing audio features; skip the connector instead
        # of handing it ``None``.
        audio_encoded = None
        if audio_features is not None:
            audio_encoded, _ = self.audio_embeddings_connector(audio_features, additive_attention_mask)

        return video_encoded, audio_encoded, binary_mask

    def process_hidden_states(
        self,
        hidden_states: tuple[torch.Tensor, ...],
        attention_mask: torch.Tensor,
        padding_side: str = "left",
    ):
        """Turn text-encoder hidden states into video/audio conditioning embeddings.

        Extracts per-modality features, converts ``attention_mask`` into an
        additive mask in the video features' dtype, and runs the connectors.

        Returns:
            ``(video_enc, audio_enc, binary_mask)`` as produced by
            ``create_embeddings``.
        """
        video_feats, audio_feats = self.feature_extractor_linear(hidden_states, attention_mask, padding_side)
        additive_mask = _convert_to_additive_mask(attention_mask, video_feats.dtype)
        video_enc, audio_enc, binary_mask = self.create_embeddings(video_feats, audio_feats, additive_mask)
        return video_enc, audio_enc, binary_mask
|
||||
|
||||
|
||||
def _norm_and_concat_padded_batch(
|
||||
encoded_text: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
padding_side: str = "right",
|
||||
) -> torch.Tensor:
|
||||
"""Normalize and flatten multi-layer hidden states, respecting padding.
|
||||
Performs per-batch, per-layer normalization using masked mean and range,
|
||||
then concatenates across the layer dimension.
|
||||
Args:
|
||||
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
|
||||
sequence_lengths: Number of valid (non-padded) tokens per batch item.
|
||||
padding_side: Whether padding is on "left" or "right".
|
||||
Returns:
|
||||
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
|
||||
with padded positions zeroed out.
|
||||
"""
|
||||
b, t, d, l = encoded_text.shape # noqa: E741
|
||||
device = encoded_text.device
|
||||
# Build mask: [B, T, 1, 1]
|
||||
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [B, T]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = t - sequence_lengths[:, None] # [B, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = rearrange(mask, "b t -> b t 1 1")
|
||||
eps = 1e-6
|
||||
# Compute masked mean: [B, 1, 1, L]
|
||||
masked = encoded_text.masked_fill(~mask, 0.0)
|
||||
denom = (sequence_lengths * d).view(b, 1, 1, 1)
|
||||
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
|
||||
# Compute masked min/max: [B, 1, 1, L]
|
||||
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
range_ = x_max - x_min
|
||||
# Normalize only the valid tokens
|
||||
normed = 8 * (encoded_text - mean) / (range_ + eps)
|
||||
# concat to be [Batch, T, D * L] - this preserves the original structure
|
||||
normed = normed.reshape(b, t, -1) # [B, T, D * L]
|
||||
# Apply mask to preserve original padding (set padded positions to 0)
|
||||
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
|
||||
normed = normed.masked_fill(~mask_flattened, 0.0)
|
||||
|
||||
return normed
|
||||
|
||||
|
||||
def _convert_to_additive_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
return (attention_mask - 1).to(dtype).reshape(
|
||||
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
|
||||
|
||||
def _to_binary_mask(encoded: torch.Tensor, encoded_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert connector output mask to binary mask and apply to encoded tensor."""
|
||||
binary_mask = (encoded_mask < 0.000001).to(torch.int64)
|
||||
binary_mask = binary_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
|
||||
encoded = encoded * binary_mask
|
||||
return encoded, binary_mask
|
||||
|
||||
|
||||
def norm_and_concat_per_token_rms(
    encoded_text: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Per-token RMSNorm normalization for V2 models.

    Each token is divided by the root-mean-square of its hidden dimension,
    separately for every layer, and the hidden and layer axes are then merged.

    Args:
        encoded_text: [B, T, D, L]
        attention_mask: [B, T] binary mask

    Returns:
        [B, T, D*L] normalized tensor with padding zeroed out.
    """
    batch, tokens, hidden, layers = encoded_text.shape
    mean_square = torch.mean(encoded_text**2, dim=2, keepdim=True)  # [B, T, 1, L]
    inv_rms = torch.rsqrt(mean_square + 1e-6)
    flat = (encoded_text * inv_rms).reshape(batch, tokens, hidden * layers)
    keep = attention_mask.bool().unsqueeze(-1)  # [B, T, 1]
    return torch.where(keep, flat, torch.zeros_like(flat))
|
||||
|
||||
|
||||
def _rescale_norm(x: torch.Tensor, target_dim: int, source_dim: int) -> torch.Tensor:
|
||||
"""Rescale normalization: x * sqrt(target_dim / source_dim)."""
|
||||
return x * math.sqrt(target_dim / source_dim)
|
||||
|
||||
@@ -555,9 +555,6 @@ class PerChannelStatistics(nn.Module):
|
||||
super().__init__()
|
||||
self.register_buffer("std-of-means", torch.empty(latent_channels))
|
||||
self.register_buffer("mean-of-means", torch.empty(latent_channels))
|
||||
self.register_buffer("mean-of-stds", torch.empty(latent_channels))
|
||||
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(latent_channels))
|
||||
self.register_buffer("channel", torch.empty(latent_channels))
|
||||
|
||||
def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(
|
||||
@@ -1335,27 +1332,49 @@ class LTX2VideoEncoder(nn.Module):
|
||||
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||
latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
|
||||
encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
encoder_version: str = "ltx-2",
|
||||
):
|
||||
super().__init__()
|
||||
encoder_blocks = [['res_x', {
|
||||
'num_layers': 4
|
||||
}], ['compress_space_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 6
|
||||
}], ['compress_time_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 6
|
||||
}], ['compress_all_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 2
|
||||
}], ['compress_all_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 2
|
||||
}]]
|
||||
if encoder_version == "ltx-2":
|
||||
encoder_blocks = [['res_x', {
|
||||
'num_layers': 4
|
||||
}], ['compress_space_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 6
|
||||
}], ['compress_time_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 6
|
||||
}], ['compress_all_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 2
|
||||
}], ['compress_all_res', {
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 2
|
||||
}]]
|
||||
else:
|
||||
encoder_blocks = [["res_x", {
|
||||
"num_layers": 4
|
||||
}], ["compress_space_res", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 6
|
||||
}], ["compress_time_res", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 4
|
||||
}], ["compress_all_res", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 2
|
||||
}], ["compress_all_res", {
|
||||
"multiplier": 1
|
||||
}], ["res_x", {
|
||||
"num_layers": 2
|
||||
}]]
|
||||
self.patch_size = patch_size
|
||||
self.norm_layer = norm_layer
|
||||
self.latent_channels = out_channels
|
||||
@@ -1435,8 +1454,8 @@ class LTX2VideoEncoder(nn.Module):
|
||||
# Validate frame count
|
||||
frames_count = sample.shape[2]
|
||||
if ((frames_count - 1) % 8) != 0:
|
||||
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames "
|
||||
"(e.g., 1, 9, 17, ...). Please check your input.")
|
||||
frames_to_crop = (frames_count - 1) % 8
|
||||
sample = sample[:, :, :-frames_to_crop, ...]
|
||||
|
||||
# Initial spatial compression: trade spatial resolution for channel depth
|
||||
# This reduces H,W by patch_size and increases channels, making convolutions more efficient
|
||||
@@ -1712,17 +1731,21 @@ def _make_decoder_block(
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
out_channels = in_channels // block_config.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
stride=(2, 1, 1),
|
||||
out_channels_reduction_factor=block_config.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
out_channels = in_channels // block_config.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
stride=(1, 2, 2),
|
||||
out_channels_reduction_factor=block_config.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
@@ -1782,6 +1805,8 @@ class LTX2VideoDecoder(nn.Module):
|
||||
causal: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||
decoder_version: str = "ltx-2",
|
||||
base_channels: int = 128,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -1790,28 +1815,49 @@ class LTX2VideoDecoder(nn.Module):
|
||||
# video inputs by a factor of 8 in the temporal dimension and 32 in
|
||||
# each spatial dimension (height and width). This parameter determines how
|
||||
# many video frames and pixels correspond to a single latent cell.
|
||||
decoder_blocks = [['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}]]
|
||||
if decoder_version == "ltx-2":
|
||||
decoder_blocks = [['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}], ['compress_all', {
|
||||
'residual': True,
|
||||
'multiplier': 2
|
||||
}], ['res_x', {
|
||||
'num_layers': 5,
|
||||
'inject_noise': False
|
||||
}]]
|
||||
else:
|
||||
decoder_blocks = [["res_x", {
|
||||
"num_layers": 4
|
||||
}], ["compress_space", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 6
|
||||
}], ["compress_time", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 4
|
||||
}], ["compress_all", {
|
||||
"multiplier": 1
|
||||
}], ["res_x", {
|
||||
"num_layers": 2
|
||||
}], ["compress_all", {
|
||||
"multiplier": 2
|
||||
}], ["res_x", {
|
||||
"num_layers": 2
|
||||
}]]
|
||||
self.video_downscale_factors = SpatioTemporalScaleFactors(
|
||||
time=8,
|
||||
width=32,
|
||||
@@ -1833,13 +1879,14 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
# Compute initial feature_channels by going through blocks in reverse
|
||||
# This determines the channel width at the start of the decoder
|
||||
feature_channels = in_channels
|
||||
for block_name, block_params in list(reversed(decoder_blocks)):
|
||||
block_config = block_params if isinstance(block_params, dict) else {}
|
||||
if block_name == "res_x_y":
|
||||
feature_channels = feature_channels * block_config.get("multiplier", 2)
|
||||
if block_name == "compress_all":
|
||||
feature_channels = feature_channels * block_config.get("multiplier", 1)
|
||||
# feature_channels = in_channels
|
||||
# for block_name, block_params in list(reversed(decoder_blocks)):
|
||||
# block_config = block_params if isinstance(block_params, dict) else {}
|
||||
# if block_name == "res_x_y":
|
||||
# feature_channels = feature_channels * block_config.get("multiplier", 2)
|
||||
# if block_name == "compress_all":
|
||||
# feature_channels = feature_channels * block_config.get("multiplier", 1)
|
||||
feature_channels = base_channels * 8
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims=convolution_dimensions,
|
||||
|
||||
Reference in New Issue
Block a user