support ltx2.3 inference

This commit is contained in:
mi804
2026-03-06 16:07:17 +08:00
parent c5aaa1da41
commit 73b13f4c86
17 changed files with 1608 additions and 351 deletions

View File

@@ -225,6 +225,17 @@ class BatchedPerturbationConfig:
return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)])
ADALN_NUM_BASE_PARAMS = 6
# Cross-attention AdaLN adds 3 more (scale, shift, gate) for the CA norm.
ADALN_NUM_CROSS_ATTN_PARAMS = 3
def adaln_embedding_coefficient(cross_attention_adaln: bool) -> int:
"""Total number of AdaLN parameters per block."""
return ADALN_NUM_BASE_PARAMS + (ADALN_NUM_CROSS_ATTN_PARAMS if cross_attention_adaln else 0)
class AdaLayerNormSingle(torch.nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
@@ -460,6 +471,7 @@ class Attention(torch.nn.Module):
dim_head: int = 64,
norm_eps: float = 1e-6,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
apply_gated_attention: bool = False,
) -> None:
super().__init__()
self.rope_type = rope_type
@@ -477,6 +489,12 @@ class Attention(torch.nn.Module):
self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True)
self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True)
# Optional per-head gating
if apply_gated_attention:
self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
else:
self.to_gate_logits = None
self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity())
def forward(
@@ -486,6 +504,8 @@ class Attention(torch.nn.Module):
mask: torch.Tensor | None = None,
pe: torch.Tensor | None = None,
k_pe: torch.Tensor | None = None,
perturbation_mask: torch.Tensor | None = None,
all_perturbed: bool = False,
) -> torch.Tensor:
q = self.to_q(x)
context = x if context is None else context
@@ -517,6 +537,19 @@ class Attention(torch.nn.Module):
# Reshape back to original format
out = out.flatten(2, 3)
# Apply per-head gating if enabled
if self.to_gate_logits is not None:
gate_logits = self.to_gate_logits(x) # (B, T, H)
b, t, _ = out.shape
# Reshape to (B, T, H, D) for per-head gating
out = out.view(b, t, self.heads, self.dim_head)
# Apply gating: 2 * sigmoid(x) so that zero-init gives identity (2 * 0.5 = 1.0)
gates = 2.0 * torch.sigmoid(gate_logits) # (B, T, H)
out = out * gates.unsqueeze(-1) # (B, T, H, D) * (B, T, H, 1)
# Reshape back to (B, T, H*D)
out = out.view(b, t, self.heads * self.dim_head)
return self.to_out(out)
@@ -545,7 +578,6 @@ class PixArtAlphaTextProjection(torch.nn.Module):
hidden_states = self.linear_2(hidden_states)
return hidden_states
@dataclass(frozen=True)
class TransformerArgs:
x: torch.Tensor
@@ -558,7 +590,10 @@ class TransformerArgs:
cross_scale_shift_timestep: torch.Tensor | None
cross_gate_timestep: torch.Tensor | None
enabled: bool
prompt_timestep: torch.Tensor | None = None
self_attention_mask: torch.Tensor | None = (
None # Additive log-space self-attention bias (B, 1, T, T), None = full attention
)
class TransformerArgsPreprocessor:
@@ -566,7 +601,6 @@ class TransformerArgsPreprocessor:
self,
patchify_proj: torch.nn.Linear,
adaln: AdaLayerNormSingle,
caption_projection: PixArtAlphaTextProjection,
inner_dim: int,
max_pos: list[int],
num_attention_heads: int,
@@ -575,10 +609,11 @@ class TransformerArgsPreprocessor:
double_precision_rope: bool,
positional_embedding_theta: float,
rope_type: LTXRopeType,
caption_projection: torch.nn.Module | None = None,
prompt_adaln: AdaLayerNormSingle | None = None,
) -> None:
self.patchify_proj = patchify_proj
self.adaln = adaln
self.caption_projection = caption_projection
self.inner_dim = inner_dim
self.max_pos = max_pos
self.num_attention_heads = num_attention_heads
@@ -587,18 +622,18 @@ class TransformerArgsPreprocessor:
self.double_precision_rope = double_precision_rope
self.positional_embedding_theta = positional_embedding_theta
self.rope_type = rope_type
self.caption_projection = caption_projection
self.prompt_adaln = prompt_adaln
def _prepare_timestep(
self, timestep: torch.Tensor, batch_size: int, hidden_dtype: torch.dtype
self, timestep: torch.Tensor, adaln: AdaLayerNormSingle, batch_size: int, hidden_dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare timestep embeddings."""
timestep = timestep * self.timestep_scale_multiplier
timestep, embedded_timestep = self.adaln(
timestep.flatten(),
timestep_scaled = timestep * self.timestep_scale_multiplier
timestep, embedded_timestep = adaln(
timestep_scaled.flatten(),
hidden_dtype=hidden_dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
@@ -608,14 +643,12 @@ class TransformerArgsPreprocessor:
self,
context: torch.Tensor,
x: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
) -> torch.Tensor:
"""Prepare context for transformer blocks."""
if self.caption_projection is not None:
context = self.caption_projection(context)
batch_size = x.shape[0]
context = self.caption_projection(context)
context = context.view(batch_size, -1, x.shape[-1])
return context, attention_mask
return context.view(batch_size, -1, x.shape[-1])
def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None:
"""Prepare attention mask."""
@@ -626,6 +659,34 @@ class TransformerArgsPreprocessor:
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
) * torch.finfo(x_dtype).max
def _prepare_self_attention_mask(
self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype
) -> torch.Tensor | None:
"""Prepare self-attention mask by converting [0,1] values to additive log-space bias.
Input shape: (B, T, T) with values in [0, 1].
Output shape: (B, 1, T, T) with 0.0 for full attention and a large negative value
for masked positions.
Positions with attention_mask <= 0 are fully masked (mapped to the dtype's minimum
representable value). Strictly positive entries are converted via log-space for
smooth attenuation, with small values clamped for numerical stability.
Returns None if input is None (no masking).
"""
if attention_mask is None:
return None
# Convert [0, 1] attention mask to additive log-space bias:
# 1.0 -> log(1.0) = 0.0 (no bias, full attention)
# 0.0 -> finfo.min (fully masked)
finfo = torch.finfo(x_dtype)
eps = finfo.tiny
bias = torch.full_like(attention_mask, finfo.min, dtype=x_dtype)
positive = attention_mask > 0
if positive.any():
bias[positive] = torch.log(attention_mask[positive].clamp(min=eps)).to(x_dtype)
return bias.unsqueeze(1) # (B, 1, T, T) for head broadcast
def _prepare_positional_embeddings(
self,
positions: torch.Tensor,
@@ -653,11 +714,20 @@ class TransformerArgsPreprocessor:
def prepare(
self,
modality: Modality,
cross_modality: Modality | None = None, # noqa: ARG002
) -> TransformerArgs:
x = self.patchify_proj(modality.latent)
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], modality.latent.dtype)
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self._prepare_timestep(
modality.timesteps, self.adaln, batch_size, modality.latent.dtype
)
prompt_timestep = None
if self.prompt_adaln is not None:
prompt_timestep, _ = self._prepare_timestep(
modality.sigma, self.prompt_adaln, batch_size, modality.latent.dtype
)
context = self._prepare_context(modality.context, x)
attention_mask = self._prepare_attention_mask(modality.context_mask, modality.latent.dtype)
pe = self._prepare_positional_embeddings(
positions=modality.positions,
inner_dim=self.inner_dim,
@@ -666,6 +736,7 @@ class TransformerArgsPreprocessor:
num_attention_heads=self.num_attention_heads,
x_dtype=modality.latent.dtype,
)
self_attention_mask = self._prepare_self_attention_mask(modality.attention_mask, modality.latent.dtype)
return TransformerArgs(
x=x,
context=context,
@@ -677,6 +748,8 @@ class TransformerArgsPreprocessor:
cross_scale_shift_timestep=None,
cross_gate_timestep=None,
enabled=modality.enabled,
prompt_timestep=prompt_timestep,
self_attention_mask=self_attention_mask,
)
@@ -685,7 +758,6 @@ class MultiModalTransformerArgsPreprocessor:
self,
patchify_proj: torch.nn.Linear,
adaln: AdaLayerNormSingle,
caption_projection: PixArtAlphaTextProjection,
cross_scale_shift_adaln: AdaLayerNormSingle,
cross_gate_adaln: AdaLayerNormSingle,
inner_dim: int,
@@ -699,11 +771,12 @@ class MultiModalTransformerArgsPreprocessor:
positional_embedding_theta: float,
rope_type: LTXRopeType,
av_ca_timestep_scale_multiplier: int,
caption_projection: torch.nn.Module | None = None,
prompt_adaln: AdaLayerNormSingle | None = None,
) -> None:
self.simple_preprocessor = TransformerArgsPreprocessor(
patchify_proj=patchify_proj,
adaln=adaln,
caption_projection=caption_projection,
inner_dim=inner_dim,
max_pos=max_pos,
num_attention_heads=num_attention_heads,
@@ -712,6 +785,8 @@ class MultiModalTransformerArgsPreprocessor:
double_precision_rope=double_precision_rope,
positional_embedding_theta=positional_embedding_theta,
rope_type=rope_type,
caption_projection=caption_projection,
prompt_adaln=prompt_adaln,
)
self.cross_scale_shift_adaln = cross_scale_shift_adaln
self.cross_gate_adaln = cross_gate_adaln
@@ -722,8 +797,22 @@ class MultiModalTransformerArgsPreprocessor:
def prepare(
self,
modality: Modality,
cross_modality: Modality | None = None,
) -> TransformerArgs:
transformer_args = self.simple_preprocessor.prepare(modality)
if cross_modality is None:
return transformer_args
if cross_modality.sigma.numel() > 1:
if cross_modality.sigma.shape[0] != modality.timesteps.shape[0]:
raise ValueError("Cross modality sigma must have the same batch size as the modality")
if cross_modality.sigma.ndim != 1:
raise ValueError("Cross modality sigma must be a 1D tensor")
cross_timestep = cross_modality.sigma.view(
modality.timesteps.shape[0], 1, *[1] * len(modality.timesteps.shape[2:])
)
cross_pe = self.simple_preprocessor._prepare_positional_embeddings(
positions=modality.positions[:, 0:1, :],
inner_dim=self.audio_cross_attention_dim,
@@ -734,7 +823,7 @@ class MultiModalTransformerArgsPreprocessor:
)
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
timestep=modality.timesteps,
timestep=cross_timestep,
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
batch_size=transformer_args.x.shape[0],
hidden_dtype=modality.latent.dtype,
@@ -749,7 +838,7 @@ class MultiModalTransformerArgsPreprocessor:
def _prepare_cross_attention_timestep(
self,
timestep: torch.Tensor,
timestep: torch.Tensor | None,
timestep_scale_multiplier: int,
batch_size: int,
hidden_dtype: torch.dtype,
@@ -779,6 +868,8 @@ class TransformerConfig:
heads: int
d_head: int
context_dim: int
apply_gated_attention: bool = False
cross_attention_adaln: bool = False
class BasicAVTransformerBlock(torch.nn.Module):
@@ -801,6 +892,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
context_dim=None,
rope_type=rope_type,
norm_eps=norm_eps,
apply_gated_attention=video.apply_gated_attention,
)
self.attn2 = Attention(
query_dim=video.dim,
@@ -809,9 +901,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
dim_head=video.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
apply_gated_attention=video.apply_gated_attention,
)
self.ff = FeedForward(video.dim, dim_out=video.dim)
self.scale_shift_table = torch.nn.Parameter(torch.empty(6, video.dim))
video_sst_size = adaln_embedding_coefficient(video.cross_attention_adaln)
self.scale_shift_table = torch.nn.Parameter(torch.empty(video_sst_size, video.dim))
if audio is not None:
self.audio_attn1 = Attention(
@@ -821,6 +915,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
context_dim=None,
rope_type=rope_type,
norm_eps=norm_eps,
apply_gated_attention=audio.apply_gated_attention,
)
self.audio_attn2 = Attention(
query_dim=audio.dim,
@@ -829,9 +924,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
apply_gated_attention=audio.apply_gated_attention,
)
self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(6, audio.dim))
audio_sst_size = adaln_embedding_coefficient(audio.cross_attention_adaln)
self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(audio_sst_size, audio.dim))
if audio is not None and video is not None:
# Q: Video, K,V: Audio
@@ -842,6 +939,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
apply_gated_attention=video.apply_gated_attention,
)
# Q: Audio, K,V: Video
@@ -852,11 +950,21 @@ class BasicAVTransformerBlock(torch.nn.Module):
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
apply_gated_attention=audio.apply_gated_attention,
)
self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim))
self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim))
self.cross_attention_adaln = (video is not None and video.cross_attention_adaln) or (
audio is not None and audio.cross_attention_adaln
)
if self.cross_attention_adaln and video is not None:
self.prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, video.dim))
if self.cross_attention_adaln and audio is not None:
self.audio_prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, audio.dim))
self.norm_eps = norm_eps
def get_ada_values(
@@ -876,19 +984,49 @@ class BasicAVTransformerBlock(torch.nn.Module):
batch_size: int,
scale_shift_timestep: torch.Tensor,
gate_timestep: torch.Tensor,
scale_shift_indices: slice,
num_scale_shift_values: int = 4,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
scale_shift_ada_values = self.get_ada_values(
scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, slice(None, None)
scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, scale_shift_indices
)
gate_ada_values = self.get_ada_values(
scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None)
)
scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
scale, shift = (t.squeeze(2) for t in scale_shift_ada_values)
(gate,) = (t.squeeze(2) for t in gate_ada_values)
return (*scale_shift_chunks, *gate_ada_values)
return scale, shift, gate
def _apply_text_cross_attention(
self,
x: torch.Tensor,
context: torch.Tensor,
attn: Attention,
scale_shift_table: torch.Tensor,
prompt_scale_shift_table: torch.Tensor | None,
timestep: torch.Tensor,
prompt_timestep: torch.Tensor | None,
context_mask: torch.Tensor | None,
cross_attention_adaln: bool = False,
) -> torch.Tensor:
"""Apply text cross-attention, with optional AdaLN modulation."""
if cross_attention_adaln:
shift_q, scale_q, gate = self.get_ada_values(scale_shift_table, x.shape[0], timestep, slice(6, 9))
return apply_cross_attention_adaln(
x,
context,
attn,
shift_q,
scale_q,
gate,
prompt_scale_shift_table,
prompt_timestep,
context_mask,
self.norm_eps,
)
return attn(rms_norm(x, eps=self.norm_eps), context=context, mask=context_mask)
def forward( # noqa: PLR0915
self,
@@ -896,7 +1034,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
audio: TransformerArgs | None,
perturbations: BatchedPerturbationConfig | None = None,
) -> tuple[TransformerArgs | None, TransformerArgs | None]:
batch_size = video.x.shape[0]
if video is None and audio is None:
raise ValueError("At least one of video or audio must be provided")
batch_size = (video or audio).x.shape[0]
if perturbations is None:
perturbations = BatchedPerturbationConfig.empty(batch_size)
@@ -913,63 +1055,103 @@ class BasicAVTransformerBlock(torch.nn.Module):
vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
)
if not perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx):
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
v_mask = perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx)
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa * v_mask
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
del vshift_msa, vscale_msa
vx = vx + self.attn2(rms_norm(vx, eps=self.norm_eps), context=video.context, mask=video.context_mask)
del vshift_msa, vscale_msa, vgate_msa
all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)
none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)
v_mask = (
perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx)
if not all_perturbed and not none_perturbed
else None
)
vx = (
vx
+ self.attn1(
norm_vx,
pe=video.positional_embeddings,
mask=video.self_attention_mask,
perturbation_mask=v_mask,
all_perturbed=all_perturbed,
)
* vgate_msa
)
del vgate_msa, norm_vx, v_mask
vx = vx + self._apply_text_cross_attention(
vx,
video.context,
self.attn2,
self.scale_shift_table,
getattr(self, "prompt_scale_shift_table", None),
video.timesteps,
video.prompt_timestep,
video.context_mask,
cross_attention_adaln=self.cross_attention_adaln,
)
if run_ax:
ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
)
if not perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx):
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
a_mask = perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax)
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa * a_mask
ax = ax + self.audio_attn2(rms_norm(ax, eps=self.norm_eps), context=audio.context, mask=audio.context_mask)
del ashift_msa, ascale_msa, agate_msa
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
del ashift_msa, ascale_msa
all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)
none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)
a_mask = (
perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax)
if not all_perturbed and not none_perturbed
else None
)
ax = (
ax
+ self.audio_attn1(
norm_ax,
pe=audio.positional_embeddings,
mask=audio.self_attention_mask,
perturbation_mask=a_mask,
all_perturbed=all_perturbed,
)
* agate_msa
)
del agate_msa, norm_ax, a_mask
ax = ax + self._apply_text_cross_attention(
ax,
audio.context,
self.audio_attn2,
self.audio_scale_shift_table,
getattr(self, "audio_prompt_scale_shift_table", None),
audio.timesteps,
audio.prompt_timestep,
audio.context_mask,
cross_attention_adaln=self.cross_attention_adaln,
)
# Audio - Video cross attention.
if run_a2v or run_v2a:
vx_norm3 = rms_norm(vx, eps=self.norm_eps)
ax_norm3 = rms_norm(ax, eps=self.norm_eps)
(
scale_ca_audio_hidden_states_a2v,
shift_ca_audio_hidden_states_a2v,
scale_ca_audio_hidden_states_v2a,
shift_ca_audio_hidden_states_v2a,
gate_out_v2a,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_audio,
ax.shape[0],
audio.cross_scale_shift_timestep,
audio.cross_gate_timestep,
)
if run_a2v and not perturbations.all_in_batch(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx):
scale_ca_video_a2v, shift_ca_video_a2v, gate_out_a2v = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_video,
vx.shape[0],
video.cross_scale_shift_timestep,
video.cross_gate_timestep,
slice(0, 2),
)
vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v
del scale_ca_video_a2v, shift_ca_video_a2v
(
scale_ca_video_hidden_states_a2v,
shift_ca_video_hidden_states_a2v,
scale_ca_video_hidden_states_v2a,
shift_ca_video_hidden_states_v2a,
gate_out_a2v,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_video,
vx.shape[0],
video.cross_scale_shift_timestep,
video.cross_gate_timestep,
)
if run_a2v:
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) + shift_ca_video_hidden_states_a2v
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
scale_ca_audio_a2v, shift_ca_audio_a2v, _ = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_audio,
ax.shape[0],
audio.cross_scale_shift_timestep,
audio.cross_gate_timestep,
slice(0, 2),
)
ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v
del scale_ca_audio_a2v, shift_ca_audio_a2v
a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx)
vx = vx + (
self.audio_to_video_attn(
@@ -981,10 +1163,27 @@ class BasicAVTransformerBlock(torch.nn.Module):
* gate_out_a2v
* a2v_mask
)
del gate_out_a2v, a2v_mask, vx_scaled, ax_scaled
if run_v2a:
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
if run_v2a and not perturbations.all_in_batch(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx):
scale_ca_audio_v2a, shift_ca_audio_v2a, gate_out_v2a = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_audio,
ax.shape[0],
audio.cross_scale_shift_timestep,
audio.cross_gate_timestep,
slice(2, 4),
)
ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a
del scale_ca_audio_v2a, shift_ca_audio_v2a
scale_ca_video_v2a, shift_ca_video_v2a, _ = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_video,
vx.shape[0],
video.cross_scale_shift_timestep,
video.cross_gate_timestep,
slice(2, 4),
)
vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a
del scale_ca_video_v2a, shift_ca_video_v2a
v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax)
ax = ax + (
self.video_to_audio_attn(
@@ -996,40 +1195,53 @@ class BasicAVTransformerBlock(torch.nn.Module):
* gate_out_v2a
* v2a_mask
)
del gate_out_v2a, v2a_mask, ax_scaled, vx_scaled
del gate_out_a2v, gate_out_v2a
del (
scale_ca_video_hidden_states_a2v,
shift_ca_video_hidden_states_a2v,
scale_ca_audio_hidden_states_a2v,
shift_ca_audio_hidden_states_a2v,
scale_ca_video_hidden_states_v2a,
shift_ca_video_hidden_states_v2a,
scale_ca_audio_hidden_states_v2a,
shift_ca_audio_hidden_states_v2a,
)
del vx_norm3, ax_norm3
if run_vx:
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None)
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6)
)
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
vx = vx + self.ff(vx_scaled) * vgate_mlp
del vshift_mlp, vscale_mlp, vgate_mlp
del vshift_mlp, vscale_mlp, vgate_mlp, vx_scaled
if run_ax:
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None)
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6)
)
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
ax = ax + self.audio_ff(ax_scaled) * agate_mlp
del ashift_mlp, ascale_mlp, agate_mlp
del ashift_mlp, ascale_mlp, agate_mlp, ax_scaled
return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None
def apply_cross_attention_adaln(
x: torch.Tensor,
context: torch.Tensor,
attn: Attention,
q_shift: torch.Tensor,
q_scale: torch.Tensor,
q_gate: torch.Tensor,
prompt_scale_shift_table: torch.Tensor,
prompt_timestep: torch.Tensor,
context_mask: torch.Tensor | None = None,
norm_eps: float = 1e-6,
) -> torch.Tensor:
batch_size = x.shape[0]
shift_kv, scale_kv = (
prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
+ prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
).unbind(dim=2)
attn_input = rms_norm(x, eps=norm_eps) * (1 + q_scale) + q_shift
encoder_hidden_states = context * (1 + scale_kv) + shift_kv
return attn(attn_input, context=encoder_hidden_states, mask=context_mask) * q_gate
class GELUApprox(torch.nn.Module):
def __init__(self, dim_in: int, dim_out: int) -> None:
super().__init__()
@@ -1094,6 +1306,8 @@ class LTXModel(torch.nn.Module):
av_ca_timestep_scale_multiplier: int = 1000,
rope_type: LTXRopeType = LTXRopeType.SPLIT,
double_precision_rope: bool = True,
apply_gated_attention: bool = False,
cross_attention_adaln: bool = False,
):
super().__init__()
self._enable_gradient_checkpointing = False
@@ -1103,6 +1317,7 @@ class LTXModel(torch.nn.Module):
self.timestep_scale_multiplier = timestep_scale_multiplier
self.positional_embedding_theta = positional_embedding_theta
self.model_type = model_type
self.cross_attention_adaln = cross_attention_adaln
cross_pe_max_pos = None
if model_type.is_video_enabled():
if positional_embedding_max_pos is None:
@@ -1145,8 +1360,13 @@ class LTXModel(torch.nn.Module):
audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0,
audio_cross_attention_dim=audio_cross_attention_dim,
norm_eps=norm_eps,
apply_gated_attention=apply_gated_attention,
)
@property
def _adaln_embedding_coefficient(self) -> int:
return adaln_embedding_coefficient(self.cross_attention_adaln)
def _init_video(
self,
in_channels: int,
@@ -1157,14 +1377,15 @@ class LTXModel(torch.nn.Module):
"""Initialize video-specific components."""
# Video input components
self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True)
self.adaln_single = AdaLayerNormSingle(self.inner_dim)
self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=self._adaln_embedding_coefficient)
self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None
# Video caption projection
self.caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels,
hidden_size=self.inner_dim,
)
if caption_channels is not None:
self.caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels,
hidden_size=self.inner_dim,
)
# Video output components
self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim))
@@ -1183,15 +1404,15 @@ class LTXModel(torch.nn.Module):
# Audio input components
self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True)
self.audio_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
)
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=self._adaln_embedding_coefficient)
self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None
# Audio caption projection
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels,
hidden_size=self.audio_inner_dim,
)
if caption_channels is not None:
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels,
hidden_size=self.audio_inner_dim,
)
# Audio output components
self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim))
@@ -1233,7 +1454,6 @@ class LTXModel(torch.nn.Module):
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.patchify_proj,
adaln=self.adaln_single,
caption_projection=self.caption_projection,
cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
inner_dim=self.inner_dim,
@@ -1247,11 +1467,12 @@ class LTXModel(torch.nn.Module):
positional_embedding_theta=self.positional_embedding_theta,
rope_type=self.rope_type,
av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
caption_projection=getattr(self, "caption_projection", None),
prompt_adaln=getattr(self, "prompt_adaln_single", None),
)
self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.audio_patchify_proj,
adaln=self.audio_adaln_single,
caption_projection=self.audio_caption_projection,
cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
inner_dim=self.audio_inner_dim,
@@ -1265,12 +1486,13 @@ class LTXModel(torch.nn.Module):
positional_embedding_theta=self.positional_embedding_theta,
rope_type=self.rope_type,
av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
caption_projection=getattr(self, "audio_caption_projection", None),
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
)
elif self.model_type.is_video_enabled():
self.video_args_preprocessor = TransformerArgsPreprocessor(
patchify_proj=self.patchify_proj,
adaln=self.adaln_single,
caption_projection=self.caption_projection,
inner_dim=self.inner_dim,
max_pos=self.positional_embedding_max_pos,
num_attention_heads=self.num_attention_heads,
@@ -1279,12 +1501,13 @@ class LTXModel(torch.nn.Module):
double_precision_rope=self.double_precision_rope,
positional_embedding_theta=self.positional_embedding_theta,
rope_type=self.rope_type,
caption_projection=getattr(self, "caption_projection", None),
prompt_adaln=getattr(self, "prompt_adaln_single", None),
)
elif self.model_type.is_audio_enabled():
self.audio_args_preprocessor = TransformerArgsPreprocessor(
patchify_proj=self.audio_patchify_proj,
adaln=self.audio_adaln_single,
caption_projection=self.audio_caption_projection,
inner_dim=self.audio_inner_dim,
max_pos=self.audio_positional_embedding_max_pos,
num_attention_heads=self.audio_num_attention_heads,
@@ -1293,6 +1516,8 @@ class LTXModel(torch.nn.Module):
double_precision_rope=self.double_precision_rope,
positional_embedding_theta=self.positional_embedding_theta,
rope_type=self.rope_type,
caption_projection=getattr(self, "audio_caption_projection", None),
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
)
def _init_transformer_blocks(
@@ -1303,6 +1528,7 @@ class LTXModel(torch.nn.Module):
audio_attention_head_dim: int,
audio_cross_attention_dim: int,
norm_eps: float,
apply_gated_attention: bool,
) -> None:
"""Initialize transformer blocks for LTX."""
video_config = (
@@ -1311,6 +1537,8 @@ class LTXModel(torch.nn.Module):
heads=self.num_attention_heads,
d_head=attention_head_dim,
context_dim=cross_attention_dim,
apply_gated_attention=apply_gated_attention,
cross_attention_adaln=self.cross_attention_adaln,
)
if self.model_type.is_video_enabled()
else None
@@ -1321,6 +1549,8 @@ class LTXModel(torch.nn.Module):
heads=self.audio_num_attention_heads,
d_head=audio_attention_head_dim,
context_dim=audio_cross_attention_dim,
apply_gated_attention=apply_gated_attention,
cross_attention_adaln=self.cross_attention_adaln,
)
if self.model_type.is_audio_enabled()
else None
@@ -1409,8 +1639,8 @@ class LTXModel(torch.nn.Module):
if not self.model_type.is_audio_enabled() and audio is not None:
raise ValueError("Audio is not enabled for this model")
video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
video_args = self.video_args_preprocessor.prepare(video, audio) if video is not None else None
audio_args = self.audio_args_preprocessor.prepare(audio, video) if audio is not None else None
# Process transformer blocks
video_out, audio_out = self._process_transformer_blocks(
video=video_args,
@@ -1441,12 +1671,12 @@ class LTXModel(torch.nn.Module):
)
return vx, ax
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps, sigma, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
cross_pe_max_pos = None
if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():
cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
self._init_preprocessors(cross_pe_max_pos)
video = Modality(video_latents, video_timesteps, video_positions, video_context)
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None
video = Modality(video_latents, sigma, video_timesteps, video_positions, video_context)
audio = Modality(audio_latents, sigma, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None
vx, ax = self._forward(video=video, audio=audio, perturbations=None, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload)
return vx, ax