import math import functools from dataclasses import dataclass, replace from enum import Enum from typing import Optional, Tuple, Callable import numpy as np import torch from einops import rearrange from .ltx2_common import rms_norm, Modality from ..core.attention.attention import attention_forward from ..core import gradient_checkpoint_forward def get_timestep_embedding( timesteps: torch.Tensor, embedding_dim: int, flip_sin_to_cos: bool = False, downscale_freq_shift: float = 1, scale: float = 1, max_period: int = 10000, ) -> torch.Tensor: """ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. Args timesteps (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. embedding_dim (int): the dimension of the output. flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) downscale_freq_shift (float): Controls the delta between frequencies between dimensions scale (float): Scaling factor applied to the embeddings. max_period (int): Controls the maximum frequency of the embeddings Returns torch.Tensor: an [N x dim] Tensor of positional embeddings. """ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) exponent = exponent / (half_dim - downscale_freq_shift) emb = torch.exp(exponent) emb = timesteps[:, None].float() * emb[None, :] # scale embeddings emb = scale * emb # concat sine and cosine embeddings emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) # flip sine and cosine embeddings if flip_sin_to_cos: emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) # zero pad if embedding_dim % 2 == 1: emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb class TimestepEmbedding(torch.nn.Module): def __init__( self, in_channels: int, time_embed_dim: int, out_dim: int | None = None, post_act_fn: str | None = None, cond_proj_dim: int | None = None, sample_proj_bias: bool = True, ): super().__init__() self.linear_1 = torch.nn.Linear(in_channels, time_embed_dim, sample_proj_bias) if cond_proj_dim is not None: self.cond_proj = torch.nn.Linear(cond_proj_dim, in_channels, bias=False) else: self.cond_proj = None self.act = torch.nn.SiLU() time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim self.linear_2 = torch.nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) if post_act_fn is None: self.post_act = None def forward(self, sample: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor: if condition is not None: sample = sample + self.cond_proj(condition) sample = self.linear_1(sample) if self.act is not None: sample = self.act(sample) sample = self.linear_2(sample) if self.post_act is not None: sample = self.post_act(sample) return sample class Timesteps(torch.nn.Module): def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): super().__init__() self.num_channels = num_channels self.flip_sin_to_cos = flip_sin_to_cos self.downscale_freq_shift = downscale_freq_shift self.scale = scale def forward(self, timesteps: torch.Tensor) -> torch.Tensor: t_emb = get_timestep_embedding( timesteps, self.num_channels, flip_sin_to_cos=self.flip_sin_to_cos, downscale_freq_shift=self.downscale_freq_shift, scale=self.scale, ) return t_emb class PixArtAlphaCombinedTimestepSizeEmbeddings(torch.nn.Module): """ For PixArt-Alpha. Reference: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 """ def __init__( self, embedding_dim: int, size_emb_dim: int, ): super().__init__() self.outdim = size_emb_dim self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) def forward( self, timestep: torch.Tensor, hidden_dtype: torch.dtype, ) -> torch.Tensor: timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) return timesteps_emb class PerturbationType(Enum): """Types of attention perturbations for STG (Spatio-Temporal Guidance).""" SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn" SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn" SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn" SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn" @dataclass(frozen=True) class Perturbation: """A single perturbation specifying which attention type to skip and in which blocks.""" type: PerturbationType blocks: list[int] | None # None means all blocks def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool: if self.type != perturbation_type: return False if self.blocks is None: return True return block in self.blocks @dataclass(frozen=True) class PerturbationConfig: """Configuration holding a list of perturbations for a single sample.""" perturbations: list[Perturbation] | None def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool: if self.perturbations is None: return False return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) @staticmethod def empty() -> "PerturbationConfig": return PerturbationConfig([]) @dataclass(frozen=True) class BatchedPerturbationConfig: """Perturbation configurations for a batch, with utilities for generating attention masks.""" perturbations: list[PerturbationConfig] def mask( self, perturbation_type: PerturbationType, block: int, device, dtype: torch.dtype ) -> torch.Tensor: mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype) for batch_idx, perturbation in enumerate(self.perturbations): if perturbation.is_perturbed(perturbation_type, block): mask[batch_idx] = 0 return mask def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor: mask = self.mask(perturbation_type, block, values.device, values.dtype) return mask.view(mask.numel(), *([1] * len(values.shape[1:]))) def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool: return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool: return all(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) @staticmethod def empty(batch_size: int) -> "BatchedPerturbationConfig": return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)]) ADALN_NUM_BASE_PARAMS = 6 # Cross-attention AdaLN adds 3 more (scale, shift, gate) for the CA norm. ADALN_NUM_CROSS_ATTN_PARAMS = 3 def adaln_embedding_coefficient(cross_attention_adaln: bool) -> int: """Total number of AdaLN parameters per block.""" return ADALN_NUM_BASE_PARAMS + (ADALN_NUM_CROSS_ATTN_PARAMS if cross_attention_adaln else 0) class AdaLayerNormSingle(torch.nn.Module): r""" Norm layer adaptive layer norm single (adaLN-single). As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). Parameters: embedding_dim (`int`): The size of each embedding vector. use_additional_conditions (`bool`): To use additional conditions for normalization or not. """ def __init__(self, embedding_dim: int, embedding_coefficient: int = 6): super().__init__() self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( embedding_dim, size_emb_dim=embedding_dim // 3, ) self.silu = torch.nn.SiLU() self.linear = torch.nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True) def forward( self, timestep: torch.Tensor, hidden_dtype: Optional[torch.dtype] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype) return self.linear(self.silu(embedded_timestep)), embedded_timestep class LTXRopeType(Enum): INTERLEAVED = "interleaved" SPLIT = "split" def apply_rotary_emb( input_tensor: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor], rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, ) -> torch.Tensor: if rope_type == LTXRopeType.INTERLEAVED: return apply_interleaved_rotary_emb(input_tensor, *freqs_cis) elif rope_type == LTXRopeType.SPLIT: return apply_split_rotary_emb(input_tensor, *freqs_cis) else: raise ValueError(f"Invalid rope type: {rope_type}") def apply_interleaved_rotary_emb( input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor ) -> torch.Tensor: t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) t1, t2 = t_dup.unbind(dim=-1) t_dup = torch.stack((-t2, t1), dim=-1) input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs return out def apply_split_rotary_emb( input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor ) -> torch.Tensor: needs_reshape = False if input_tensor.ndim != 4 and cos_freqs.ndim == 4: b, h, t, _ = cos_freqs.shape input_tensor = input_tensor.reshape(b, t, h, -1).swapaxes(1, 2) needs_reshape = True split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2) first_half_input = split_input[..., :1, :] second_half_input = split_input[..., 1:, :] output = split_input * cos_freqs.unsqueeze(-2) first_half_output = output[..., :1, :] second_half_output = output[..., 1:, :] first_half_output.addcmul_(-sin_freqs.unsqueeze(-2), second_half_input) second_half_output.addcmul_(sin_freqs.unsqueeze(-2), first_half_input) output = rearrange(output, "... d r -> ... (d r)") if needs_reshape: output = output.swapaxes(1, 2).reshape(b, t, -1) return output @functools.lru_cache(maxsize=5) def generate_freq_grid_np( positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int ) -> torch.Tensor: theta = positional_embedding_theta start = 1 end = theta n_elem = 2 * positional_embedding_max_pos_count pow_indices = np.power( theta, np.linspace( np.log(start) / np.log(theta), np.log(end) / np.log(theta), inner_dim // n_elem, dtype=np.float64, ), ) return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32) @functools.lru_cache(maxsize=5) def generate_freq_grid_pytorch( positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int ) -> torch.Tensor: theta = positional_embedding_theta start = 1 end = theta n_elem = 2 * positional_embedding_max_pos_count indices = theta ** ( torch.linspace( math.log(start, theta), math.log(end, theta), inner_dim // n_elem, dtype=torch.float32, ) ) indices = indices.to(dtype=torch.float32) indices = indices * math.pi / 2 return indices def get_fractional_positions(indices_grid: torch.Tensor, max_pos: list[int]) -> torch.Tensor: n_pos_dims = indices_grid.shape[1] assert n_pos_dims == len(max_pos), ( f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})" ) fractional_positions = torch.stack( [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)], dim=-1, ) return fractional_positions def generate_freqs( indices: torch.Tensor, indices_grid: torch.Tensor, max_pos: list[int], use_middle_indices_grid: bool ) -> torch.Tensor: if use_middle_indices_grid: assert len(indices_grid.shape) == 4 assert indices_grid.shape[-1] == 2 indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1] indices_grid = (indices_grid_start + indices_grid_end) / 2.0 elif len(indices_grid.shape) == 4: indices_grid = indices_grid[..., 0] fractional_positions = get_fractional_positions(indices_grid, max_pos) indices = indices.to(device=fractional_positions.device) freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) return freqs def split_freqs_cis(freqs: torch.Tensor, pad_size: int, num_attention_heads: int) -> tuple[torch.Tensor, torch.Tensor]: cos_freq = freqs.cos() sin_freq = freqs.sin() if pad_size != 0: cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) # Reshape freqs to be compatible with multi-head attention b = cos_freq.shape[0] t = cos_freq.shape[1] cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1) sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1) cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) return cos_freq, sin_freq def interleaved_freqs_cis(freqs: torch.Tensor, pad_size: int) -> tuple[torch.Tensor, torch.Tensor]: cos_freq = freqs.cos().repeat_interleave(2, dim=-1) sin_freq = freqs.sin().repeat_interleave(2, dim=-1) if pad_size != 0: cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) sin_padding = torch.zeros_like(cos_freq[:, :, :pad_size]) cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) return cos_freq, sin_freq def precompute_freqs_cis( indices_grid: torch.Tensor, dim: int, out_dtype: torch.dtype, theta: float = 10000.0, max_pos: list[int] | None = None, use_middle_indices_grid: bool = False, num_attention_heads: int = 32, rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, freq_grid_generator: Callable[[float, int, int, torch.device], torch.Tensor] = generate_freq_grid_pytorch, ) -> tuple[torch.Tensor, torch.Tensor]: if max_pos is None: max_pos = [20, 2048, 2048] indices = freq_grid_generator(theta, indices_grid.shape[1], dim) freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) if rope_type == LTXRopeType.SPLIT: expected_freqs = dim // 2 current_freqs = freqs.shape[-1] pad_size = expected_freqs - current_freqs cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads) else: # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only n_elem = 2 * indices_grid.shape[1] cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) return cos_freq.to(out_dtype), sin_freq.to(out_dtype) class Attention(torch.nn.Module): def __init__( self, query_dim: int, context_dim: int | None = None, heads: int = 8, dim_head: int = 64, norm_eps: float = 1e-6, rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, apply_gated_attention: bool = False, ) -> None: super().__init__() self.rope_type = rope_type inner_dim = dim_head * heads context_dim = query_dim if context_dim is None else context_dim self.heads = heads self.dim_head = dim_head self.q_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) self.k_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True) self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True) self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True) # Optional per-head gating if apply_gated_attention: self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True) else: self.to_gate_logits = None self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity()) def forward( self, x: torch.Tensor, context: torch.Tensor | None = None, mask: torch.Tensor | None = None, pe: torch.Tensor | None = None, k_pe: torch.Tensor | None = None, perturbation_mask: torch.Tensor | None = None, all_perturbed: bool = False, ) -> torch.Tensor: q = self.to_q(x) context = x if context is None else context k = self.to_k(context) v = self.to_v(context) q = self.q_norm(q) k = self.k_norm(k) if pe is not None: q = apply_rotary_emb(q, pe, self.rope_type) k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type) # Reshape for attention_forward using unflatten q = q.unflatten(-1, (self.heads, self.dim_head)) k = k.unflatten(-1, (self.heads, self.dim_head)) v = v.unflatten(-1, (self.heads, self.dim_head)) out = attention_forward( q=q, k=k, v=v, q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d", attn_mask=mask ) # Reshape back to original format out = out.flatten(2, 3) # Apply per-head gating if enabled if self.to_gate_logits is not None: gate_logits = self.to_gate_logits(x) # (B, T, H) b, t, _ = out.shape # Reshape to (B, T, H, D) for per-head gating out = out.view(b, t, self.heads, self.dim_head) # Apply gating: 2 * sigmoid(x) so that zero-init gives identity (2 * 0.5 = 1.0) gates = 2.0 * torch.sigmoid(gate_logits) # (B, T, H) out = out * gates.unsqueeze(-1) # (B, T, H, D) * (B, T, H, 1) # Reshape back to (B, T, H*D) out = out.view(b, t, self.heads * self.dim_head) return self.to_out(out) class PixArtAlphaTextProjection(torch.nn.Module): """ Projects caption embeddings. Also handles dropout for classifier-free guidance. Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py """ def __init__(self, in_features: int, hidden_size: int, out_features: int | None = None, act_fn: str = "gelu_tanh"): super().__init__() if out_features is None: out_features = hidden_size self.linear_1 = torch.nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) if act_fn == "gelu_tanh": self.act_1 = torch.nn.GELU(approximate="tanh") elif act_fn == "silu": self.act_1 = torch.nn.SiLU() else: raise ValueError(f"Unknown activation function: {act_fn}") self.linear_2 = torch.nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) def forward(self, caption: torch.Tensor) -> torch.Tensor: hidden_states = self.linear_1(caption) hidden_states = self.act_1(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states @dataclass(frozen=True) class TransformerArgs: x: torch.Tensor context: torch.Tensor context_mask: torch.Tensor timesteps: torch.Tensor embedded_timestep: torch.Tensor positional_embeddings: torch.Tensor cross_positional_embeddings: torch.Tensor | None cross_scale_shift_timestep: torch.Tensor | None cross_gate_timestep: torch.Tensor | None enabled: bool prompt_timestep: torch.Tensor | None = None self_attention_mask: torch.Tensor | None = ( None # Additive log-space self-attention bias (B, 1, T, T), None = full attention ) class TransformerArgsPreprocessor: def __init__( # noqa: PLR0913 self, patchify_proj: torch.nn.Linear, adaln: AdaLayerNormSingle, inner_dim: int, max_pos: list[int], num_attention_heads: int, use_middle_indices_grid: bool, timestep_scale_multiplier: int, double_precision_rope: bool, positional_embedding_theta: float, rope_type: LTXRopeType, caption_projection: torch.nn.Module | None = None, prompt_adaln: AdaLayerNormSingle | None = None, ) -> None: self.patchify_proj = patchify_proj self.adaln = adaln self.inner_dim = inner_dim self.max_pos = max_pos self.num_attention_heads = num_attention_heads self.use_middle_indices_grid = use_middle_indices_grid self.timestep_scale_multiplier = timestep_scale_multiplier self.double_precision_rope = double_precision_rope self.positional_embedding_theta = positional_embedding_theta self.rope_type = rope_type self.caption_projection = caption_projection self.prompt_adaln = prompt_adaln def _prepare_timestep( self, timestep: torch.Tensor, adaln: AdaLayerNormSingle, batch_size: int, hidden_dtype: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: """Prepare timestep embeddings.""" timestep_scaled = timestep * self.timestep_scale_multiplier timestep, embedded_timestep = adaln( timestep_scaled.flatten(), hidden_dtype=hidden_dtype, ) # Second dimension is 1 or number of tokens (if timestep_per_token) timestep = timestep.view(batch_size, -1, timestep.shape[-1]) embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) return timestep, embedded_timestep def _prepare_context( self, context: torch.Tensor, x: torch.Tensor, ) -> torch.Tensor: """Prepare context for transformer blocks.""" if self.caption_projection is not None: context = self.caption_projection(context) batch_size = x.shape[0] return context.view(batch_size, -1, x.shape[-1]) def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None: """Prepare attention mask.""" if attention_mask is None or torch.is_floating_point(attention_mask): return attention_mask return (attention_mask - 1).to(x_dtype).reshape( (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) ) * torch.finfo(x_dtype).max def _prepare_self_attention_mask( self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype ) -> torch.Tensor | None: """Prepare self-attention mask by converting [0,1] values to additive log-space bias. Input shape: (B, T, T) with values in [0, 1]. Output shape: (B, 1, T, T) with 0.0 for full attention and a large negative value for masked positions. Positions with attention_mask <= 0 are fully masked (mapped to the dtype's minimum representable value). Strictly positive entries are converted via log-space for smooth attenuation, with small values clamped for numerical stability. Returns None if input is None (no masking). """ if attention_mask is None: return None # Convert [0, 1] attention mask to additive log-space bias: # 1.0 -> log(1.0) = 0.0 (no bias, full attention) # 0.0 -> finfo.min (fully masked) finfo = torch.finfo(x_dtype) eps = finfo.tiny bias = torch.full_like(attention_mask, finfo.min, dtype=x_dtype) positive = attention_mask > 0 if positive.any(): bias[positive] = torch.log(attention_mask[positive].clamp(min=eps)).to(x_dtype) return bias.unsqueeze(1) # (B, 1, T, T) for head broadcast def _prepare_positional_embeddings( self, positions: torch.Tensor, inner_dim: int, max_pos: list[int], use_middle_indices_grid: bool, num_attention_heads: int, x_dtype: torch.dtype, ) -> torch.Tensor: """Prepare positional embeddings.""" freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch pe = precompute_freqs_cis( positions, dim=inner_dim, out_dtype=x_dtype, theta=self.positional_embedding_theta, max_pos=max_pos, use_middle_indices_grid=use_middle_indices_grid, num_attention_heads=num_attention_heads, rope_type=self.rope_type, freq_grid_generator=freq_grid_generator, ) return pe def prepare( self, modality: Modality, cross_modality: Modality | None = None, # noqa: ARG002 ) -> TransformerArgs: x = self.patchify_proj(modality.latent) batch_size = x.shape[0] timestep, embedded_timestep = self._prepare_timestep( modality.timesteps, self.adaln, batch_size, modality.latent.dtype ) prompt_timestep = None if self.prompt_adaln is not None: prompt_timestep, _ = self._prepare_timestep( modality.sigma, self.prompt_adaln, batch_size, modality.latent.dtype ) context = self._prepare_context(modality.context, x) attention_mask = self._prepare_attention_mask(modality.context_mask, modality.latent.dtype) pe = self._prepare_positional_embeddings( positions=modality.positions, inner_dim=self.inner_dim, max_pos=self.max_pos, use_middle_indices_grid=self.use_middle_indices_grid, num_attention_heads=self.num_attention_heads, x_dtype=modality.latent.dtype, ) self_attention_mask = self._prepare_self_attention_mask(modality.attention_mask, modality.latent.dtype) return TransformerArgs( x=x, context=context, context_mask=attention_mask, timesteps=timestep, embedded_timestep=embedded_timestep, positional_embeddings=pe, cross_positional_embeddings=None, cross_scale_shift_timestep=None, cross_gate_timestep=None, enabled=modality.enabled, prompt_timestep=prompt_timestep, self_attention_mask=self_attention_mask, ) class MultiModalTransformerArgsPreprocessor: def __init__( # noqa: PLR0913 self, patchify_proj: torch.nn.Linear, adaln: AdaLayerNormSingle, cross_scale_shift_adaln: AdaLayerNormSingle, cross_gate_adaln: AdaLayerNormSingle, inner_dim: int, max_pos: list[int], num_attention_heads: int, cross_pe_max_pos: int, use_middle_indices_grid: bool, audio_cross_attention_dim: int, timestep_scale_multiplier: int, double_precision_rope: bool, positional_embedding_theta: float, rope_type: LTXRopeType, av_ca_timestep_scale_multiplier: int, caption_projection: torch.nn.Module | None = None, prompt_adaln: AdaLayerNormSingle | None = None, ) -> None: self.simple_preprocessor = TransformerArgsPreprocessor( patchify_proj=patchify_proj, adaln=adaln, inner_dim=inner_dim, max_pos=max_pos, num_attention_heads=num_attention_heads, use_middle_indices_grid=use_middle_indices_grid, timestep_scale_multiplier=timestep_scale_multiplier, double_precision_rope=double_precision_rope, positional_embedding_theta=positional_embedding_theta, rope_type=rope_type, caption_projection=caption_projection, prompt_adaln=prompt_adaln, ) self.cross_scale_shift_adaln = cross_scale_shift_adaln self.cross_gate_adaln = cross_gate_adaln self.cross_pe_max_pos = cross_pe_max_pos self.audio_cross_attention_dim = audio_cross_attention_dim self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier def prepare( self, modality: Modality, cross_modality: Modality | None = None, ) -> TransformerArgs: transformer_args = self.simple_preprocessor.prepare(modality) if cross_modality is None: return transformer_args if cross_modality.sigma.numel() > 1: if cross_modality.sigma.shape[0] != modality.timesteps.shape[0]: raise ValueError("Cross modality sigma must have the same batch size as the modality") if cross_modality.sigma.ndim != 1: raise ValueError("Cross modality sigma must be a 1D tensor") cross_timestep = cross_modality.sigma.view( modality.timesteps.shape[0], 1, *[1] * len(modality.timesteps.shape[2:]) ) cross_pe = self.simple_preprocessor._prepare_positional_embeddings( positions=modality.positions[:, 0:1, :], inner_dim=self.audio_cross_attention_dim, max_pos=[self.cross_pe_max_pos], use_middle_indices_grid=True, num_attention_heads=self.simple_preprocessor.num_attention_heads, x_dtype=modality.latent.dtype, ) cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep( timestep=cross_timestep, timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, batch_size=transformer_args.x.shape[0], hidden_dtype=modality.latent.dtype, ) return replace( transformer_args, cross_positional_embeddings=cross_pe, cross_scale_shift_timestep=cross_scale_shift_timestep, cross_gate_timestep=cross_gate_timestep, ) def _prepare_cross_attention_timestep( self, timestep: torch.Tensor | None, timestep_scale_multiplier: int, batch_size: int, hidden_dtype: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor]: """Prepare cross attention timestep embeddings.""" timestep = timestep * timestep_scale_multiplier av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier scale_shift_timestep, _ = self.cross_scale_shift_adaln( timestep.flatten(), hidden_dtype=hidden_dtype, ) scale_shift_timestep = scale_shift_timestep.view(batch_size, -1, scale_shift_timestep.shape[-1]) gate_noise_timestep, _ = self.cross_gate_adaln( timestep.flatten() * av_ca_factor, hidden_dtype=hidden_dtype, ) gate_noise_timestep = gate_noise_timestep.view(batch_size, -1, gate_noise_timestep.shape[-1]) return scale_shift_timestep, gate_noise_timestep @dataclass class TransformerConfig: dim: int heads: int d_head: int context_dim: int apply_gated_attention: bool = False cross_attention_adaln: bool = False class BasicAVTransformerBlock(torch.nn.Module): def __init__( self, idx: int, video: TransformerConfig | None = None, audio: TransformerConfig | None = None, rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, norm_eps: float = 1e-6, ): super().__init__() self.idx = idx if video is not None: self.attn1 = Attention( query_dim=video.dim, heads=video.heads, dim_head=video.d_head, context_dim=None, rope_type=rope_type, norm_eps=norm_eps, apply_gated_attention=video.apply_gated_attention, ) self.attn2 = Attention( query_dim=video.dim, context_dim=video.context_dim, heads=video.heads, dim_head=video.d_head, rope_type=rope_type, norm_eps=norm_eps, apply_gated_attention=video.apply_gated_attention, ) self.ff = FeedForward(video.dim, dim_out=video.dim) video_sst_size = adaln_embedding_coefficient(video.cross_attention_adaln) self.scale_shift_table = torch.nn.Parameter(torch.empty(video_sst_size, video.dim)) if audio is not None: self.audio_attn1 = Attention( query_dim=audio.dim, heads=audio.heads, dim_head=audio.d_head, context_dim=None, rope_type=rope_type, norm_eps=norm_eps, apply_gated_attention=audio.apply_gated_attention, ) self.audio_attn2 = Attention( query_dim=audio.dim, context_dim=audio.context_dim, heads=audio.heads, dim_head=audio.d_head, rope_type=rope_type, norm_eps=norm_eps, apply_gated_attention=audio.apply_gated_attention, ) self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim) audio_sst_size = adaln_embedding_coefficient(audio.cross_attention_adaln) self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(audio_sst_size, audio.dim)) if audio is not None and video is not None: # Q: Video, K,V: Audio self.audio_to_video_attn = Attention( query_dim=video.dim, context_dim=audio.dim, heads=audio.heads, dim_head=audio.d_head, rope_type=rope_type, norm_eps=norm_eps, apply_gated_attention=video.apply_gated_attention, ) # Q: Audio, K,V: Video self.video_to_audio_attn = Attention( query_dim=audio.dim, context_dim=video.dim, heads=audio.heads, dim_head=audio.d_head, rope_type=rope_type, norm_eps=norm_eps, apply_gated_attention=audio.apply_gated_attention, ) self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim)) self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim)) self.cross_attention_adaln = (video is not None and video.cross_attention_adaln) or ( audio is not None and audio.cross_attention_adaln ) if self.cross_attention_adaln and video is not None: self.prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, video.dim)) if self.cross_attention_adaln and audio is not None: self.audio_prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, audio.dim)) self.norm_eps = norm_eps def get_ada_values( self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice ) -> tuple[torch.Tensor, ...]: num_ada_params = scale_shift_table.shape[0] ada_values = ( scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype) + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :] ).unbind(dim=2) return ada_values def get_av_ca_ada_values( self, scale_shift_table: torch.Tensor, batch_size: int, scale_shift_timestep: torch.Tensor, gate_timestep: torch.Tensor, scale_shift_indices: slice, num_scale_shift_values: int = 4, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: scale_shift_ada_values = self.get_ada_values( scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, scale_shift_indices ) gate_ada_values = self.get_ada_values( scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None) ) scale, shift = (t.squeeze(2) for t in scale_shift_ada_values) (gate,) = (t.squeeze(2) for t in gate_ada_values) return scale, shift, gate def _apply_text_cross_attention( self, x: torch.Tensor, context: torch.Tensor, attn: Attention, scale_shift_table: torch.Tensor, prompt_scale_shift_table: torch.Tensor | None, timestep: torch.Tensor, prompt_timestep: torch.Tensor | None, context_mask: torch.Tensor | None, cross_attention_adaln: bool = False, ) -> torch.Tensor: """Apply text cross-attention, with optional AdaLN modulation.""" if cross_attention_adaln: shift_q, scale_q, gate = self.get_ada_values(scale_shift_table, x.shape[0], timestep, slice(6, 9)) return apply_cross_attention_adaln( x, context, attn, shift_q, scale_q, gate, prompt_scale_shift_table, prompt_timestep, context_mask, self.norm_eps, ) return attn(rms_norm(x, eps=self.norm_eps), context=context, mask=context_mask) def forward( # noqa: PLR0915 self, video: TransformerArgs | None, audio: TransformerArgs | None, perturbations: BatchedPerturbationConfig | None = None, ) -> tuple[TransformerArgs | None, TransformerArgs | None]: if video is None and audio is None: raise ValueError("At least one of video or audio must be provided") batch_size = (video or audio).x.shape[0] if perturbations is None: perturbations = BatchedPerturbationConfig.empty(batch_size) vx = video.x if video is not None else None ax = audio.x if audio is not None else None run_vx = video is not None and video.enabled and vx.numel() > 0 run_ax = audio is not None and audio.enabled and ax.numel() > 0 run_a2v = run_vx and (audio is not None and ax.numel() > 0) run_v2a = run_ax and (video is not None and vx.numel() > 0) if run_vx: vshift_msa, vscale_msa, vgate_msa = self.get_ada_values( self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3) ) norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa del vshift_msa, vscale_msa all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx) none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx) v_mask = ( perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx) if not all_perturbed and not none_perturbed else None ) vx = ( vx + self.attn1( norm_vx, pe=video.positional_embeddings, mask=video.self_attention_mask, perturbation_mask=v_mask, all_perturbed=all_perturbed, ) * vgate_msa ) del vgate_msa, norm_vx, v_mask vx = vx + self._apply_text_cross_attention( vx, video.context, self.attn2, self.scale_shift_table, getattr(self, "prompt_scale_shift_table", None), video.timesteps, video.prompt_timestep, video.context_mask, cross_attention_adaln=self.cross_attention_adaln, ) if run_ax: ashift_msa, ascale_msa, agate_msa = self.get_ada_values( self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3) ) norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa del ashift_msa, ascale_msa all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx) none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx) a_mask = ( perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax) if not all_perturbed and not none_perturbed else None ) ax = ( ax + self.audio_attn1( norm_ax, pe=audio.positional_embeddings, mask=audio.self_attention_mask, perturbation_mask=a_mask, all_perturbed=all_perturbed, ) * agate_msa ) del agate_msa, norm_ax, a_mask ax = ax + self._apply_text_cross_attention( ax, audio.context, self.audio_attn2, self.audio_scale_shift_table, getattr(self, "audio_prompt_scale_shift_table", None), audio.timesteps, audio.prompt_timestep, audio.context_mask, cross_attention_adaln=self.cross_attention_adaln, ) # Audio - Video cross attention. if run_a2v or run_v2a: vx_norm3 = rms_norm(vx, eps=self.norm_eps) ax_norm3 = rms_norm(ax, eps=self.norm_eps) if run_a2v and not perturbations.all_in_batch(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx): scale_ca_video_a2v, shift_ca_video_a2v, gate_out_a2v = self.get_av_ca_ada_values( self.scale_shift_table_a2v_ca_video, vx.shape[0], video.cross_scale_shift_timestep, video.cross_gate_timestep, slice(0, 2), ) vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v del scale_ca_video_a2v, shift_ca_video_a2v scale_ca_audio_a2v, shift_ca_audio_a2v, _ = self.get_av_ca_ada_values( self.scale_shift_table_a2v_ca_audio, ax.shape[0], audio.cross_scale_shift_timestep, audio.cross_gate_timestep, slice(0, 2), ) ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v del scale_ca_audio_a2v, shift_ca_audio_a2v a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx) vx = vx + ( self.audio_to_video_attn( vx_scaled, context=ax_scaled, pe=video.cross_positional_embeddings, k_pe=audio.cross_positional_embeddings, ) * gate_out_a2v * a2v_mask ) del gate_out_a2v, a2v_mask, vx_scaled, ax_scaled if run_v2a and not perturbations.all_in_batch(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx): scale_ca_audio_v2a, shift_ca_audio_v2a, gate_out_v2a = self.get_av_ca_ada_values( self.scale_shift_table_a2v_ca_audio, ax.shape[0], audio.cross_scale_shift_timestep, audio.cross_gate_timestep, slice(2, 4), ) ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a del scale_ca_audio_v2a, shift_ca_audio_v2a scale_ca_video_v2a, shift_ca_video_v2a, _ = self.get_av_ca_ada_values( self.scale_shift_table_a2v_ca_video, vx.shape[0], video.cross_scale_shift_timestep, video.cross_gate_timestep, slice(2, 4), ) vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a del scale_ca_video_v2a, shift_ca_video_v2a v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax) ax = ax + ( self.video_to_audio_attn( ax_scaled, context=vx_scaled, pe=audio.cross_positional_embeddings, k_pe=video.cross_positional_embeddings, ) * gate_out_v2a * v2a_mask ) del gate_out_v2a, v2a_mask, ax_scaled, vx_scaled del vx_norm3, ax_norm3 if run_vx: vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values( self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6) ) vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp vx = vx + self.ff(vx_scaled) * vgate_mlp del vshift_mlp, vscale_mlp, vgate_mlp, vx_scaled if run_ax: ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values( self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6) ) ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp ax = ax + self.audio_ff(ax_scaled) * agate_mlp del ashift_mlp, ascale_mlp, agate_mlp, ax_scaled return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None def apply_cross_attention_adaln( x: torch.Tensor, context: torch.Tensor, attn: Attention, q_shift: torch.Tensor, q_scale: torch.Tensor, q_gate: torch.Tensor, prompt_scale_shift_table: torch.Tensor, prompt_timestep: torch.Tensor, context_mask: torch.Tensor | None = None, norm_eps: float = 1e-6, ) -> torch.Tensor: batch_size = x.shape[0] shift_kv, scale_kv = ( prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1) ).unbind(dim=2) attn_input = rms_norm(x, eps=norm_eps) * (1 + q_scale) + q_shift encoder_hidden_states = context * (1 + scale_kv) + shift_kv return attn(attn_input, context=encoder_hidden_states, mask=context_mask) * q_gate class GELUApprox(torch.nn.Module): def __init__(self, dim_in: int, dim_out: int) -> None: super().__init__() self.proj = torch.nn.Linear(dim_in, dim_out) def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.gelu(self.proj(x), approximate="tanh") class FeedForward(torch.nn.Module): def __init__(self, dim: int, dim_out: int, mult: int = 4) -> None: super().__init__() inner_dim = int(dim * mult) project_in = GELUApprox(dim, inner_dim) self.net = torch.nn.Sequential(project_in, torch.nn.Identity(), torch.nn.Linear(inner_dim, dim_out)) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x) class LTXModelType(Enum): AudioVideo = "ltx av model" VideoOnly = "ltx video only model" AudioOnly = "ltx audio only model" def is_video_enabled(self) -> bool: return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly) def is_audio_enabled(self) -> bool: return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly) class LTXModel(torch.nn.Module): """ LTX model transformer implementation. This class implements the transformer blocks for the LTX model. """ def __init__( # noqa: PLR0913 self, *, model_type: LTXModelType = LTXModelType.AudioVideo, num_attention_heads: int = 32, attention_head_dim: int = 128, in_channels: int = 128, out_channels: int = 128, num_layers: int = 48, cross_attention_dim: int = 4096, norm_eps: float = 1e-06, caption_channels: int = 3840, positional_embedding_theta: float = 10000.0, positional_embedding_max_pos: list[int] | None = [20, 2048, 2048], timestep_scale_multiplier: int = 1000, use_middle_indices_grid: bool = True, audio_num_attention_heads: int = 32, audio_attention_head_dim: int = 64, audio_in_channels: int = 128, audio_out_channels: int = 128, audio_cross_attention_dim: int = 2048, audio_positional_embedding_max_pos: list[int] | None = [20], av_ca_timestep_scale_multiplier: int = 1000, rope_type: LTXRopeType = LTXRopeType.SPLIT, double_precision_rope: bool = True, apply_gated_attention: bool = False, cross_attention_adaln: bool = False, ): super().__init__() self._enable_gradient_checkpointing = False self.use_middle_indices_grid = use_middle_indices_grid self.rope_type = rope_type self.double_precision_rope = double_precision_rope self.timestep_scale_multiplier = timestep_scale_multiplier self.positional_embedding_theta = positional_embedding_theta self.model_type = model_type self.cross_attention_adaln = cross_attention_adaln cross_pe_max_pos = None if model_type.is_video_enabled(): if positional_embedding_max_pos is None: positional_embedding_max_pos = [20, 2048, 2048] self.positional_embedding_max_pos = positional_embedding_max_pos self.num_attention_heads = num_attention_heads self.inner_dim = num_attention_heads * attention_head_dim self._init_video( in_channels=in_channels, out_channels=out_channels, caption_channels=caption_channels, norm_eps=norm_eps, ) if model_type.is_audio_enabled(): if audio_positional_embedding_max_pos is None: audio_positional_embedding_max_pos = [20] self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos self.audio_num_attention_heads = audio_num_attention_heads self.audio_inner_dim = self.audio_num_attention_heads * audio_attention_head_dim self._init_audio( in_channels=audio_in_channels, out_channels=audio_out_channels, caption_channels=caption_channels, norm_eps=norm_eps, ) if model_type.is_video_enabled() and model_type.is_audio_enabled(): cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]) self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier self.audio_cross_attention_dim = audio_cross_attention_dim self._init_audio_video(num_scale_shift_values=4) self._init_preprocessors(cross_pe_max_pos) # Initialize transformer blocks self._init_transformer_blocks( num_layers=num_layers, attention_head_dim=attention_head_dim if model_type.is_video_enabled() else 0, cross_attention_dim=cross_attention_dim, audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0, audio_cross_attention_dim=audio_cross_attention_dim, norm_eps=norm_eps, apply_gated_attention=apply_gated_attention, ) @property def _adaln_embedding_coefficient(self) -> int: return adaln_embedding_coefficient(self.cross_attention_adaln) def _init_video( self, in_channels: int, out_channels: int, caption_channels: int, norm_eps: float, ) -> None: """Initialize video-specific components.""" # Video input components self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True) self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=self._adaln_embedding_coefficient) self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None # Video caption projection if caption_channels is not None: self.caption_projection = PixArtAlphaTextProjection( in_features=caption_channels, hidden_size=self.inner_dim, ) # Video output components self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim)) self.norm_out = torch.nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=norm_eps) self.proj_out = torch.nn.Linear(self.inner_dim, out_channels) def _init_audio( self, in_channels: int, out_channels: int, caption_channels: int, norm_eps: float, ) -> None: """Initialize audio-specific components.""" # Audio input components self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True) self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=self._adaln_embedding_coefficient) self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None # Audio caption projection if caption_channels is not None: self.audio_caption_projection = PixArtAlphaTextProjection( in_features=caption_channels, hidden_size=self.audio_inner_dim, ) # Audio output components self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim)) self.audio_norm_out = torch.nn.LayerNorm(self.audio_inner_dim, elementwise_affine=False, eps=norm_eps) self.audio_proj_out = torch.nn.Linear(self.audio_inner_dim, out_channels) def _init_audio_video( self, num_scale_shift_values: int, ) -> None: """Initialize audio-video cross-attention components.""" self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( self.inner_dim, embedding_coefficient=num_scale_shift_values, ) self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle( self.audio_inner_dim, embedding_coefficient=num_scale_shift_values, ) self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle( self.inner_dim, embedding_coefficient=1, ) self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle( self.audio_inner_dim, embedding_coefficient=1, ) def _init_preprocessors( self, cross_pe_max_pos: int | None = None, ) -> None: """Initialize preprocessors for LTX.""" if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled(): self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor( patchify_proj=self.patchify_proj, adaln=self.adaln_single, cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single, cross_gate_adaln=self.av_ca_a2v_gate_adaln_single, inner_dim=self.inner_dim, max_pos=self.positional_embedding_max_pos, num_attention_heads=self.num_attention_heads, cross_pe_max_pos=cross_pe_max_pos, use_middle_indices_grid=self.use_middle_indices_grid, audio_cross_attention_dim=self.audio_cross_attention_dim, timestep_scale_multiplier=self.timestep_scale_multiplier, double_precision_rope=self.double_precision_rope, positional_embedding_theta=self.positional_embedding_theta, rope_type=self.rope_type, av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier, caption_projection=getattr(self, "caption_projection", None), prompt_adaln=getattr(self, "prompt_adaln_single", None), ) self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor( patchify_proj=self.audio_patchify_proj, adaln=self.audio_adaln_single, cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single, cross_gate_adaln=self.av_ca_v2a_gate_adaln_single, inner_dim=self.audio_inner_dim, max_pos=self.audio_positional_embedding_max_pos, num_attention_heads=self.audio_num_attention_heads, cross_pe_max_pos=cross_pe_max_pos, use_middle_indices_grid=self.use_middle_indices_grid, audio_cross_attention_dim=self.audio_cross_attention_dim, timestep_scale_multiplier=self.timestep_scale_multiplier, double_precision_rope=self.double_precision_rope, positional_embedding_theta=self.positional_embedding_theta, rope_type=self.rope_type, av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier, caption_projection=getattr(self, "audio_caption_projection", None), prompt_adaln=getattr(self, "audio_prompt_adaln_single", None), ) elif self.model_type.is_video_enabled(): self.video_args_preprocessor = TransformerArgsPreprocessor( patchify_proj=self.patchify_proj, adaln=self.adaln_single, inner_dim=self.inner_dim, max_pos=self.positional_embedding_max_pos, num_attention_heads=self.num_attention_heads, use_middle_indices_grid=self.use_middle_indices_grid, timestep_scale_multiplier=self.timestep_scale_multiplier, double_precision_rope=self.double_precision_rope, positional_embedding_theta=self.positional_embedding_theta, rope_type=self.rope_type, caption_projection=getattr(self, "caption_projection", None), prompt_adaln=getattr(self, "prompt_adaln_single", None), ) elif self.model_type.is_audio_enabled(): self.audio_args_preprocessor = TransformerArgsPreprocessor( patchify_proj=self.audio_patchify_proj, adaln=self.audio_adaln_single, inner_dim=self.audio_inner_dim, max_pos=self.audio_positional_embedding_max_pos, num_attention_heads=self.audio_num_attention_heads, use_middle_indices_grid=self.use_middle_indices_grid, timestep_scale_multiplier=self.timestep_scale_multiplier, double_precision_rope=self.double_precision_rope, positional_embedding_theta=self.positional_embedding_theta, rope_type=self.rope_type, caption_projection=getattr(self, "audio_caption_projection", None), prompt_adaln=getattr(self, "audio_prompt_adaln_single", None), ) def _init_transformer_blocks( self, num_layers: int, attention_head_dim: int, cross_attention_dim: int, audio_attention_head_dim: int, audio_cross_attention_dim: int, norm_eps: float, apply_gated_attention: bool, ) -> None: """Initialize transformer blocks for LTX.""" video_config = ( TransformerConfig( dim=self.inner_dim, heads=self.num_attention_heads, d_head=attention_head_dim, context_dim=cross_attention_dim, apply_gated_attention=apply_gated_attention, cross_attention_adaln=self.cross_attention_adaln, ) if self.model_type.is_video_enabled() else None ) audio_config = ( TransformerConfig( dim=self.audio_inner_dim, heads=self.audio_num_attention_heads, d_head=audio_attention_head_dim, context_dim=audio_cross_attention_dim, apply_gated_attention=apply_gated_attention, cross_attention_adaln=self.cross_attention_adaln, ) if self.model_type.is_audio_enabled() else None ) self.transformer_blocks = torch.nn.ModuleList( [ BasicAVTransformerBlock( idx=idx, video=video_config, audio=audio_config, rope_type=self.rope_type, norm_eps=norm_eps, ) for idx in range(num_layers) ] ) def set_gradient_checkpointing(self, enable: bool) -> None: """Enable or disable gradient checkpointing for transformer blocks. Gradient checkpointing trades compute for memory by recomputing activations during the backward pass instead of storing them. This can significantly reduce memory usage at the cost of ~20-30% slower training. Args: enable: Whether to enable gradient checkpointing """ self._enable_gradient_checkpointing = enable def _process_transformer_blocks( self, video: TransformerArgs | None, audio: TransformerArgs | None, perturbations: BatchedPerturbationConfig, use_gradient_checkpointing: bool = False, use_gradient_checkpointing_offload: bool = False, ) -> tuple[TransformerArgs, TransformerArgs]: """Process transformer blocks for LTXAV.""" # Process transformer blocks for block in self.transformer_blocks: video, audio = gradient_checkpoint_forward( block, use_gradient_checkpointing, use_gradient_checkpointing_offload, video=video, audio=audio, perturbations=perturbations, ) return video, audio def _process_output( self, scale_shift_table: torch.Tensor, norm_out: torch.nn.LayerNorm, proj_out: torch.nn.Linear, x: torch.Tensor, embedded_timestep: torch.Tensor, ) -> torch.Tensor: """Process output for LTXV.""" # Apply scale-shift modulation scale_shift_values = ( scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] ) shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] x = norm_out(x) x = x * (1 + scale) + shift x = proj_out(x) return x def _forward( self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig, use_gradient_checkpointing: bool = False, use_gradient_checkpointing_offload: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ Forward pass for LTX models. Returns: Processed output tensors """ if not self.model_type.is_video_enabled() and video is not None: raise ValueError("Video is not enabled for this model") if not self.model_type.is_audio_enabled() and audio is not None: raise ValueError("Audio is not enabled for this model") video_args = self.video_args_preprocessor.prepare(video, audio) if video is not None else None audio_args = self.audio_args_preprocessor.prepare(audio, video) if audio is not None else None # Process transformer blocks video_out, audio_out = self._process_transformer_blocks( video=video_args, audio=audio_args, perturbations=perturbations, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, ) # Process output vx = ( self._process_output( self.scale_shift_table, self.norm_out, self.proj_out, video_out.x, video_out.embedded_timestep ) if video_out is not None else None ) ax = ( self._process_output( self.audio_scale_shift_table, self.audio_norm_out, self.audio_proj_out, audio_out.x, audio_out.embedded_timestep, ) if audio_out is not None else None ) return vx, ax def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps, sigma, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False): cross_pe_max_pos = None if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled(): cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]) self._init_preprocessors(cross_pe_max_pos) video = Modality(video_latents, sigma, video_timesteps, video_positions, video_context) audio = Modality(audio_latents, sigma, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None vx, ax = self._forward(video=video, audio=audio, perturbations=None, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload) return vx, ax