# Mirror of https://github.com/modelscope/DiffSynth-Studio.git
# Synced 2026-03-19 06:39:43 +00:00
import math
|
|
import functools
|
|
from dataclasses import dataclass, replace
|
|
from enum import Enum
|
|
from typing import Optional, Tuple, Callable
|
|
import numpy as np
|
|
import torch
|
|
from einops import rearrange
|
|
from .ltx2_common import rms_norm, Modality
|
|
from ..core.attention.attention import attention_forward
|
|
from ..core import gradient_checkpoint_forward
|
|
|
|
|
|
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings, matching the implementation in
    Denoising Diffusion Probabilistic Models.

    Args:
        timesteps (torch.Tensor): a 1-D tensor of N indices, one per batch
            element. These may be fractional.
        embedding_dim (int): the dimension of the output.
        flip_sin_to_cos (bool): whether the embedding order should be
            `cos, sin` (True) or `sin, cos` (False).
        downscale_freq_shift (float): controls the delta between frequencies
            between dimensions.
        scale (float): scaling factor applied to the embeddings.
        max_period (int): controls the maximum frequency of the embeddings.

    Returns:
        torch.Tensor: an [N x embedding_dim] tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    # Log-spaced frequencies: exp(-log(max_period) * i / (half_dim - shift)).
    freq_exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    freq_exponent = freq_exponent / (half_dim - downscale_freq_shift)

    frequencies = torch.exp(freq_exponent)
    angles = timesteps[:, None].float() * frequencies[None, :]

    # Optional amplitude scaling of the angles before sin/cos.
    angles = scale * angles

    # Concatenate sine then cosine halves.
    embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    # Optionally swap the two halves so cosine comes first.
    if flip_sin_to_cos:
        embedding = torch.cat([embedding[:, half_dim:], embedding[:, :half_dim]], dim=-1)

    # Odd target dims get one zero column of padding at the end.
    if embedding_dim % 2 == 1:
        embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))
    return embedding
class TimestepEmbedding(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
in_channels: int,
|
|
time_embed_dim: int,
|
|
out_dim: int | None = None,
|
|
post_act_fn: str | None = None,
|
|
cond_proj_dim: int | None = None,
|
|
sample_proj_bias: bool = True,
|
|
):
|
|
super().__init__()
|
|
|
|
self.linear_1 = torch.nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
|
|
|
|
if cond_proj_dim is not None:
|
|
self.cond_proj = torch.nn.Linear(cond_proj_dim, in_channels, bias=False)
|
|
else:
|
|
self.cond_proj = None
|
|
|
|
self.act = torch.nn.SiLU()
|
|
time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
|
|
|
|
self.linear_2 = torch.nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
|
|
|
|
if post_act_fn is None:
|
|
self.post_act = None
|
|
|
|
def forward(self, sample: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor:
|
|
if condition is not None:
|
|
sample = sample + self.cond_proj(condition)
|
|
sample = self.linear_1(sample)
|
|
|
|
if self.act is not None:
|
|
sample = self.act(sample)
|
|
|
|
sample = self.linear_2(sample)
|
|
|
|
if self.post_act is not None:
|
|
sample = self.post_act(sample)
|
|
return sample
|
|
|
|
|
|
class Timesteps(torch.nn.Module):
    """Stateless module wrapping `get_timestep_embedding` with fixed settings."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        # Stored verbatim and forwarded to get_timestep_embedding on each call.
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Return sinusoidal embeddings of shape (N, num_channels) for 1-D `timesteps`."""
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
class PixArtAlphaCombinedTimestepSizeEmbeddings(torch.nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim: int, size_emb_dim: int):
        super().__init__()
        self.outdim = size_emb_dim
        # Fixed 256-channel sinusoidal projection followed by an MLP embedder.
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timestep: torch.Tensor, hidden_dtype: torch.dtype) -> torch.Tensor:
        """Return (N, embedding_dim) timestep embeddings in `hidden_dtype`."""
        sinusoidal = self.time_proj(timestep)
        # Cast to the working dtype before the MLP so its weights see matching inputs.
        return self.timestep_embedder(sinusoidal.to(dtype=hidden_dtype))
class PerturbationType(Enum):
    """Types of attention perturbations for STG (Spatio-Temporal Guidance)."""

    # Skip the audio-to-video cross-attention update.
    SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn"
    # Skip the video-to-audio cross-attention update.
    SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn"
    # Skip the video self-attention update.
    SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn"
    # Skip the audio self-attention update.
    SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn"
@dataclass(frozen=True)
class Perturbation:
    """A single perturbation specifying which attention type to skip and in which blocks."""

    type: PerturbationType
    blocks: list[int] | None  # None means all blocks

    def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
        """Return True when this perturbation targets `perturbation_type` at `block`."""
        if self.type != perturbation_type:
            return False
        # `blocks is None` means the perturbation applies to every block.
        return self.blocks is None or block in self.blocks
@dataclass(frozen=True)
class PerturbationConfig:
    """Configuration holding a list of perturbations for a single sample."""

    perturbations: list[Perturbation] | None

    def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
        """Return True when any configured perturbation matches the type/block pair."""
        if self.perturbations is None:
            return False
        return any(item.is_perturbed(perturbation_type, block) for item in self.perturbations)

    @staticmethod
    def empty() -> "PerturbationConfig":
        """Return a config with no perturbations (never matches anything)."""
        return PerturbationConfig([])
@dataclass(frozen=True)
class BatchedPerturbationConfig:
    """Perturbation configurations for a batch, with utilities for generating attention masks."""

    perturbations: list[PerturbationConfig]

    def mask(
        self, perturbation_type: PerturbationType, block: int, device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Per-sample multiplicative gate: 0 where the sample is perturbed, 1 elsewhere."""
        gate = torch.ones((len(self.perturbations),), device=device, dtype=dtype)
        for sample_idx, config in enumerate(self.perturbations):
            if config.is_perturbed(perturbation_type, block):
                gate[sample_idx] = 0
        return gate

    def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor:
        """Same as `mask`, reshaped with trailing singleton dims to broadcast over `values`."""
        gate = self.mask(perturbation_type, block, values.device, values.dtype)
        trailing = [1] * (len(values.shape) - 1)
        return gate.view(gate.numel(), *trailing)

    def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
        """Return True if at least one sample in the batch is perturbed for type/block."""
        return any(config.is_perturbed(perturbation_type, block) for config in self.perturbations)

    def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
        """Return True if every sample in the batch is perturbed for type/block."""
        return all(config.is_perturbed(perturbation_type, block) for config in self.perturbations)

    @staticmethod
    def empty(batch_size: int) -> "BatchedPerturbationConfig":
        """Return a batch of `batch_size` empty (no-op) configs."""
        return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)])
# Base AdaLN parameter count per block (shift/scale/gate for attn and MLP norms).
ADALN_NUM_BASE_PARAMS = 6
# Cross-attention AdaLN adds 3 more (scale, shift, gate) for the CA norm.
ADALN_NUM_CROSS_ATTN_PARAMS = 3


def adaln_embedding_coefficient(cross_attention_adaln: bool) -> int:
    """Total number of AdaLN parameters per block."""
    if cross_attention_adaln:
        return ADALN_NUM_BASE_PARAMS + ADALN_NUM_CROSS_ATTN_PARAMS
    return ADALN_NUM_BASE_PARAMS
class AdaLayerNormSingle(torch.nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        embedding_coefficient (`int`): How many AdaLN parameter vectors to emit
            per embedding (the linear output is `embedding_coefficient * embedding_dim`).
    """

    def __init__(self, embedding_dim: int, embedding_coefficient: int = 6):
        super().__init__()
        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim,
            size_emb_dim=embedding_dim // 3,
        )
        self.silu = torch.nn.SiLU()
        self.linear = torch.nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (AdaLN parameter vectors, raw embedded timestep)."""
        embedded = self.emb(timestep, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded)), embedded
class LTXRopeType(Enum):
    """Rotary positional embedding layout (dispatched in apply_rotary_emb)."""

    # cos/sin values repeated pairwise along the channel dimension.
    INTERLEAVED = "interleaved"
    # channels split into first/second halves rotated against each other.
    SPLIT = "split"
def apply_rotary_emb(
    input_tensor: torch.Tensor,
    freqs_cis: Tuple[torch.Tensor, torch.Tensor],
    rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
) -> torch.Tensor:
    """Apply rotary positional embeddings, dispatching on the RoPE layout.

    `freqs_cis` is a (cos, sin) pair of precomputed frequency tables.
    Raises ValueError for unknown `rope_type` values.
    """
    cos_freqs, sin_freqs = freqs_cis
    if rope_type == LTXRopeType.INTERLEAVED:
        return apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs)
    if rope_type == LTXRopeType.SPLIT:
        return apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs)
    raise ValueError(f"Invalid rope type: {rope_type}")
def apply_interleaved_rotary_emb(
    input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor
) -> torch.Tensor:
    """Apply RoPE with interleaved layout: channels are (x0, y0, x1, y1, ...) pairs.

    Each pair (a, b) is rotated to (a*cos - b*sin, b*cos + a*sin), implemented
    as input * cos + rotate90(input) * sin where rotate90 maps (a, b) -> (-b, a).
    """
    pairs = input_tensor.unflatten(-1, (-1, 2))
    even, odd = pairs.unbind(dim=-1)
    rotated = torch.stack((-odd, even), dim=-1).flatten(-2)
    return input_tensor * cos_freqs + rotated * sin_freqs
def apply_split_rotary_emb(
    input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor
) -> torch.Tensor:
    """Apply RoPE with split layout: the channel dim is split into two halves
    that are rotated against each other.

    If `input_tensor` is not 4-D while the frequency tables are (B, H, T, D/2),
    the input is assumed flattened as (B, T, H*D), reshaped to (B, H, T, D) for
    the rotation, and flattened back on return.
    """
    needs_reshape = False
    if input_tensor.ndim != 4 and cos_freqs.ndim == 4:
        b, h, t, _ = cos_freqs.shape
        input_tensor = input_tensor.reshape(b, t, h, -1).swapaxes(1, 2)
        needs_reshape = True

    # Split channels into two halves stacked along a new axis: (..., 2, D/2).
    halves = input_tensor.unflatten(-1, (2, -1))
    first_in = halves[..., 0:1, :]
    second_in = halves[..., 1:2, :]

    cos = cos_freqs.unsqueeze(-2)
    sin = sin_freqs.unsqueeze(-2)

    # Standard rotation: (a, b) -> (a*cos - b*sin, b*cos + a*sin).
    first_out = first_in * cos - second_in * sin
    second_out = second_in * cos + first_in * sin
    output = torch.cat((first_out, second_out), dim=-2).flatten(-2)

    if needs_reshape:
        output = output.swapaxes(1, 2).reshape(b, t, -1)

    return output
@functools.lru_cache(maxsize=5)
def generate_freq_grid_np(
    positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int
) -> torch.Tensor:
    """Base RoPE frequency grid computed in float64 via NumPy, cast to float32.

    Produces `inner_dim // (2 * positional_embedding_max_pos_count)` frequencies
    log-spaced from 1 to theta, scaled by pi/2. Cached since the grid only
    depends on the (hashable) scalar arguments.
    """
    theta = positional_embedding_theta
    num_freqs = inner_dim // (2 * positional_embedding_max_pos_count)
    # Exponents run from log_theta(1) = 0 up to log_theta(theta) = 1.
    exponents = np.linspace(
        np.log(1) / np.log(theta),
        np.log(theta) / np.log(theta),
        num_freqs,
        dtype=np.float64,
    )
    pow_indices = np.power(theta, exponents)
    return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)
@functools.lru_cache(maxsize=5)
def generate_freq_grid_pytorch(
    positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int
) -> torch.Tensor:
    """Pure-PyTorch (float32) counterpart of `generate_freq_grid_np`.

    Produces `inner_dim // (2 * positional_embedding_max_pos_count)` frequencies
    log-spaced from 1 to theta, scaled by pi/2. Cached on the scalar arguments.
    """
    theta = positional_embedding_theta
    num_freqs = inner_dim // (2 * positional_embedding_max_pos_count)
    # Exponents run from log_theta(1) = 0 up to log_theta(theta) = 1.
    exponents = torch.linspace(
        math.log(1, theta),
        math.log(theta, theta),
        num_freqs,
        dtype=torch.float32,
    )
    freqs = (theta ** exponents).to(dtype=torch.float32)
    return freqs * math.pi / 2
def get_fractional_positions(indices_grid: torch.Tensor, max_pos: list[int]) -> torch.Tensor:
    """Normalize each position axis of `indices_grid` by its maximum extent.

    `indices_grid` carries axes along dim 1; the result stacks the normalized
    axes into the last dimension instead.
    """
    n_pos_dims = indices_grid.shape[1]
    assert n_pos_dims == len(max_pos), (
        f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
    )
    normalized_axes = [indices_grid[:, axis] / max_pos[axis] for axis in range(n_pos_dims)]
    return torch.stack(normalized_axes, dim=-1)
def generate_freqs(
    indices: torch.Tensor, indices_grid: torch.Tensor, max_pos: list[int], use_middle_indices_grid: bool
) -> torch.Tensor:
    """Combine base frequencies with normalized token positions.

    When `use_middle_indices_grid` is set, `indices_grid` is expected to be 4-D
    with (start, end) position pairs in its last axis, and their midpoint is
    used; otherwise a 4-D grid falls back to the start position only.
    """
    if use_middle_indices_grid:
        assert len(indices_grid.shape) == 4
        assert indices_grid.shape[-1] == 2
        starts, ends = indices_grid[..., 0], indices_grid[..., 1]
        indices_grid = (starts + ends) / 2.0
    elif len(indices_grid.shape) == 4:
        indices_grid = indices_grid[..., 0]

    fractional_positions = get_fractional_positions(indices_grid, max_pos)
    indices = indices.to(device=fractional_positions.device)

    # Map fractional positions from [0, 1] into [-1, 1], scale by the base
    # frequencies, then fold the per-axis frequencies into the last dimension.
    scaled = indices * (fractional_positions.unsqueeze(-1) * 2 - 1)
    return scaled.transpose(-1, -2).flatten(2)
def split_freqs_cis(freqs: torch.Tensor, pad_size: int, num_attention_heads: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Build (cos, sin) tables for split-layout RoPE, reshaped for multi-head use.

    Padding (identity rotation: cos=1, sin=0) is prepended along the last dim
    when `pad_size` is nonzero. Output shape is (B, H, T, D_head/2).
    """
    cos_freq = freqs.cos()
    sin_freq = freqs.sin()

    if pad_size != 0:
        # Identity-rotation padding at the front of the channel dim.
        cos_freq = torch.concatenate([torch.ones_like(cos_freq[:, :, :pad_size]), cos_freq], axis=-1)
        sin_freq = torch.concatenate([torch.zeros_like(sin_freq[:, :, :pad_size]), sin_freq], axis=-1)

    # Reshape freqs to be compatible with multi-head attention.
    batch, tokens = cos_freq.shape[0], cos_freq.shape[1]
    cos_freq = cos_freq.reshape(batch, tokens, num_attention_heads, -1).swapaxes(1, 2)  # (B,H,T,D//2)
    sin_freq = sin_freq.reshape(batch, tokens, num_attention_heads, -1).swapaxes(1, 2)  # (B,H,T,D//2)
    return cos_freq, sin_freq
def interleaved_freqs_cis(freqs: torch.Tensor, pad_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Build (cos, sin) tables for interleaved-layout RoPE.

    Each frequency is duplicated pairwise along the channel dim; identity
    padding (cos=1, sin=0) is prepended when `pad_size` is nonzero.
    """
    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
    if pad_size != 0:
        ones_pad = torch.ones_like(cos_freq[:, :, :pad_size])
        zeros_pad = torch.zeros_like(cos_freq[:, :, :pad_size])
        cos_freq = torch.cat([ones_pad, cos_freq], dim=-1)
        sin_freq = torch.cat([zeros_pad, sin_freq], dim=-1)
    return cos_freq, sin_freq
def precompute_freqs_cis(
    indices_grid: torch.Tensor,
    dim: int,
    out_dtype: torch.dtype,
    theta: float = 10000.0,
    max_pos: list[int] | None = None,
    use_middle_indices_grid: bool = False,
    num_attention_heads: int = 32,
    rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
    # Fix: the annotation previously declared a 4th torch.device parameter,
    # but both generators take (theta, max_pos_count, inner_dim) and the call
    # below passes exactly three arguments.
    freq_grid_generator: Callable[[float, int, int], torch.Tensor] = generate_freq_grid_pytorch,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Precompute (cos, sin) RoPE tables for the given position grid.

    Args:
        indices_grid: Positions with axes along dim 1 (e.g. t/x/y).
        dim: Inner (channel) dimension of the attention layer.
        out_dtype: dtype of the returned tables.
        theta: RoPE base frequency.
        max_pos: Per-axis maximum positions; defaults to [20, 2048, 2048].
        use_middle_indices_grid: Use midpoints of (start, end) position pairs.
        num_attention_heads: Head count (used only for SPLIT layout reshaping).
        rope_type: Frequency layout (INTERLEAVED or SPLIT).
        freq_grid_generator: Base-frequency builder (NumPy or PyTorch variant).

    Returns:
        (cos, sin) frequency tables cast to `out_dtype`.
    """
    if max_pos is None:
        max_pos = [20, 2048, 2048]

    indices = freq_grid_generator(theta, indices_grid.shape[1], dim)
    freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)

    if rope_type == LTXRopeType.SPLIT:
        # Pad up to dim // 2 frequencies with identity rotations.
        expected_freqs = dim // 2
        current_freqs = freqs.shape[-1]
        pad_size = expected_freqs - current_freqs
        cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
    else:
        # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only
        n_elem = 2 * indices_grid.shape[1]
        cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)

    return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
class Attention(torch.nn.Module):
    """Multi-head attention with RMS-normed Q/K, optional RoPE and optional per-head gating."""

    def __init__(
        self,
        query_dim: int,
        context_dim: int | None = None,
        heads: int = 8,
        dim_head: int = 64,
        norm_eps: float = 1e-6,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
        apply_gated_attention: bool = False,
    ) -> None:
        super().__init__()
        self.rope_type = rope_type
        self.heads = heads
        self.dim_head = dim_head

        inner_dim = dim_head * heads
        # Self-attention when no separate context dim is given.
        kv_dim = query_dim if context_dim is None else context_dim

        self.q_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps)
        self.k_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps)

        self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True)
        self.to_k = torch.nn.Linear(kv_dim, inner_dim, bias=True)
        self.to_v = torch.nn.Linear(kv_dim, inner_dim, bias=True)

        # Optional per-head gating on the attention output.
        if apply_gated_attention:
            self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
        else:
            self.to_gate_logits = None

        self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity())

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor | None = None,
        mask: torch.Tensor | None = None,
        pe: torch.Tensor | None = None,
        k_pe: torch.Tensor | None = None,
        perturbation_mask: torch.Tensor | None = None,
        all_perturbed: bool = False,
    ) -> torch.Tensor:
        # NOTE: perturbation_mask / all_perturbed are accepted for API
        # compatibility but not used by this implementation.
        kv_source = x if context is None else context

        q = self.q_norm(self.to_q(x))
        k = self.k_norm(self.to_k(kv_source))
        v = self.to_v(kv_source)

        if pe is not None:
            q = apply_rotary_emb(q, pe, self.rope_type)
            k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type)

        # Split channels into heads for attention_forward.
        head_shape = (self.heads, self.dim_head)
        out = attention_forward(
            q=q.unflatten(-1, head_shape),
            k=k.unflatten(-1, head_shape),
            v=v.unflatten(-1, head_shape),
            q_pattern="b s n d",
            k_pattern="b s n d",
            v_pattern="b s n d",
            out_pattern="b s n d",
            attn_mask=mask,
        )
        # Merge heads back into the channel dimension.
        out = out.flatten(2, 3)

        if self.to_gate_logits is not None:
            batch, tokens, _ = out.shape
            # 2 * sigmoid(x) so that a zero-initialized gate is the identity (2 * 0.5 = 1).
            gates = 2.0 * torch.sigmoid(self.to_gate_logits(x))  # (B, T, H)
            out = out.view(batch, tokens, self.heads, self.dim_head) * gates.unsqueeze(-1)
            out = out.view(batch, tokens, self.heads * self.dim_head)

        return self.to_out(out)
class PixArtAlphaTextProjection(torch.nn.Module):
|
|
"""
|
|
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
|
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
|
"""
|
|
|
|
def __init__(self, in_features: int, hidden_size: int, out_features: int | None = None, act_fn: str = "gelu_tanh"):
|
|
super().__init__()
|
|
if out_features is None:
|
|
out_features = hidden_size
|
|
self.linear_1 = torch.nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
|
|
if act_fn == "gelu_tanh":
|
|
self.act_1 = torch.nn.GELU(approximate="tanh")
|
|
elif act_fn == "silu":
|
|
self.act_1 = torch.nn.SiLU()
|
|
else:
|
|
raise ValueError(f"Unknown activation function: {act_fn}")
|
|
self.linear_2 = torch.nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
|
|
|
|
def forward(self, caption: torch.Tensor) -> torch.Tensor:
|
|
hidden_states = self.linear_1(caption)
|
|
hidden_states = self.act_1(hidden_states)
|
|
hidden_states = self.linear_2(hidden_states)
|
|
return hidden_states
|
|
|
|
@dataclass(frozen=True)
class TransformerArgs:
    """Preprocessed per-modality inputs for the transformer blocks."""

    x: torch.Tensor  # patchified latent tokens (output of patchify_proj)
    context: torch.Tensor  # caption/context embeddings after optional projection
    context_mask: torch.Tensor  # cross-attention mask, converted to additive-bias form
    timesteps: torch.Tensor  # AdaLN parameter vectors derived from the timestep
    embedded_timestep: torch.Tensor  # raw embedded timestep (before the AdaLN linear)
    positional_embeddings: torch.Tensor  # RoPE (cos, sin) tables for self-attention
    cross_positional_embeddings: torch.Tensor | None  # RoPE tables for A/V cross-attention, if any
    cross_scale_shift_timestep: torch.Tensor | None  # AdaLN scale/shift params for cross-attention
    cross_gate_timestep: torch.Tensor | None  # AdaLN gate params for cross-attention
    enabled: bool  # copied from Modality.enabled
    prompt_timestep: torch.Tensor | None = None  # AdaLN params from prompt_adaln, when configured
    self_attention_mask: torch.Tensor | None = (
        None  # Additive log-space self-attention bias (B, 1, T, T), None = full attention
    )
class TransformerArgsPreprocessor:
    """Builds `TransformerArgs` for a single modality.

    Patchifies the latent, computes AdaLN timestep embeddings, optionally
    projects the caption context, converts masks to additive-bias form, and
    precomputes RoPE tables. The modules passed in are owned by the parent
    transformer; this class only orchestrates them.
    """

    def __init__(  # noqa: PLR0913
        self,
        patchify_proj: torch.nn.Linear,
        adaln: AdaLayerNormSingle,
        inner_dim: int,
        max_pos: list[int],
        num_attention_heads: int,
        use_middle_indices_grid: bool,
        timestep_scale_multiplier: int,
        double_precision_rope: bool,
        positional_embedding_theta: float,
        rope_type: LTXRopeType,
        caption_projection: torch.nn.Module | None = None,
        prompt_adaln: AdaLayerNormSingle | None = None,
    ) -> None:
        self.patchify_proj = patchify_proj
        self.adaln = adaln
        self.inner_dim = inner_dim
        self.max_pos = max_pos
        self.num_attention_heads = num_attention_heads
        self.use_middle_indices_grid = use_middle_indices_grid
        self.timestep_scale_multiplier = timestep_scale_multiplier
        # When True, RoPE base frequencies are computed in float64 via NumPy.
        self.double_precision_rope = double_precision_rope
        self.positional_embedding_theta = positional_embedding_theta
        self.rope_type = rope_type
        self.caption_projection = caption_projection
        self.prompt_adaln = prompt_adaln

    def _prepare_timestep(
        self, timestep: torch.Tensor, adaln: AdaLayerNormSingle, batch_size: int, hidden_dtype: torch.dtype
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Prepare timestep embeddings."""
        timestep_scaled = timestep * self.timestep_scale_multiplier
        timestep, embedded_timestep = adaln(
            timestep_scaled.flatten(),
            hidden_dtype=hidden_dtype,
        )
        # Second dimension is 1 or number of tokens (if timestep_per_token)
        timestep = timestep.view(batch_size, -1, timestep.shape[-1])
        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
        return timestep, embedded_timestep

    def _prepare_context(
        self,
        context: torch.Tensor,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Prepare context for transformer blocks."""
        if self.caption_projection is not None:
            context = self.caption_projection(context)
        batch_size = x.shape[0]
        # Reshape to (B, tokens, inner_dim) to match the patchified latent.
        return context.view(batch_size, -1, x.shape[-1])

    def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None:
        """Prepare attention mask."""
        # Float masks are passed through unchanged (assumed already additive).
        if attention_mask is None or torch.is_floating_point(attention_mask):
            return attention_mask

        # Integer/bool {0, 1} mask -> additive bias: (mask - 1) is 0 for kept
        # positions and -1 for masked ones; scaling by finfo.max pushes masked
        # logits to the dtype's most negative representable magnitude.
        return (attention_mask - 1).to(x_dtype).reshape(
            (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
        ) * torch.finfo(x_dtype).max

    def _prepare_self_attention_mask(
        self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype
    ) -> torch.Tensor | None:
        """Prepare self-attention mask by converting [0,1] values to additive log-space bias.

        Input shape: (B, T, T) with values in [0, 1].
        Output shape: (B, 1, T, T) with 0.0 for full attention and a large negative value
        for masked positions.

        Positions with attention_mask <= 0 are fully masked (mapped to the dtype's minimum
        representable value). Strictly positive entries are converted via log-space for
        smooth attenuation, with small values clamped for numerical stability.

        Returns None if input is None (no masking).
        """
        if attention_mask is None:
            return None

        # Convert [0, 1] attention mask to additive log-space bias:
        # 1.0 -> log(1.0) = 0.0 (no bias, full attention)
        # 0.0 -> finfo.min (fully masked)
        finfo = torch.finfo(x_dtype)
        eps = finfo.tiny

        bias = torch.full_like(attention_mask, finfo.min, dtype=x_dtype)
        positive = attention_mask > 0
        if positive.any():
            bias[positive] = torch.log(attention_mask[positive].clamp(min=eps)).to(x_dtype)

        return bias.unsqueeze(1)  # (B, 1, T, T) for head broadcast

    def _prepare_positional_embeddings(
        self,
        positions: torch.Tensor,
        inner_dim: int,
        max_pos: list[int],
        use_middle_indices_grid: bool,
        num_attention_heads: int,
        x_dtype: torch.dtype,
    ) -> torch.Tensor:
        """Prepare positional embeddings."""
        # NumPy path gives float64 intermediate precision for the frequencies.
        freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
        pe = precompute_freqs_cis(
            positions,
            dim=inner_dim,
            out_dtype=x_dtype,
            theta=self.positional_embedding_theta,
            max_pos=max_pos,
            use_middle_indices_grid=use_middle_indices_grid,
            num_attention_heads=num_attention_heads,
            rope_type=self.rope_type,
            freq_grid_generator=freq_grid_generator,
        )
        return pe

    def prepare(
        self,
        modality: Modality,
        cross_modality: Modality | None = None,  # noqa: ARG002
    ) -> TransformerArgs:
        """Assemble TransformerArgs from a Modality (cross_modality is unused here)."""
        x = self.patchify_proj(modality.latent)
        batch_size = x.shape[0]
        timestep, embedded_timestep = self._prepare_timestep(
            modality.timesteps, self.adaln, batch_size, modality.latent.dtype
        )
        prompt_timestep = None
        if self.prompt_adaln is not None:
            # Separate AdaLN for prompt tokens, driven by sigma rather than timesteps.
            prompt_timestep, _ = self._prepare_timestep(
                modality.sigma, self.prompt_adaln, batch_size, modality.latent.dtype
            )
        context = self._prepare_context(modality.context, x)
        attention_mask = self._prepare_attention_mask(modality.context_mask, modality.latent.dtype)
        pe = self._prepare_positional_embeddings(
            positions=modality.positions,
            inner_dim=self.inner_dim,
            max_pos=self.max_pos,
            use_middle_indices_grid=self.use_middle_indices_grid,
            num_attention_heads=self.num_attention_heads,
            x_dtype=modality.latent.dtype,
        )
        self_attention_mask = self._prepare_self_attention_mask(modality.attention_mask, modality.latent.dtype)
        return TransformerArgs(
            x=x,
            context=context,
            context_mask=attention_mask,
            timesteps=timestep,
            embedded_timestep=embedded_timestep,
            positional_embeddings=pe,
            cross_positional_embeddings=None,
            cross_scale_shift_timestep=None,
            cross_gate_timestep=None,
            enabled=modality.enabled,
            prompt_timestep=prompt_timestep,
            self_attention_mask=self_attention_mask,
        )
class MultiModalTransformerArgsPreprocessor:
    """Extends `TransformerArgsPreprocessor` with audio/video cross-modal extras.

    Delegates single-modality preparation to a `TransformerArgsPreprocessor`
    and, when a cross modality is present, adds cross-attention RoPE tables and
    cross-attention AdaLN timesteps derived from the other modality's sigma.
    """

    def __init__(  # noqa: PLR0913
        self,
        patchify_proj: torch.nn.Linear,
        adaln: AdaLayerNormSingle,
        cross_scale_shift_adaln: AdaLayerNormSingle,
        cross_gate_adaln: AdaLayerNormSingle,
        inner_dim: int,
        max_pos: list[int],
        num_attention_heads: int,
        cross_pe_max_pos: int,
        use_middle_indices_grid: bool,
        audio_cross_attention_dim: int,
        timestep_scale_multiplier: int,
        double_precision_rope: bool,
        positional_embedding_theta: float,
        rope_type: LTXRopeType,
        av_ca_timestep_scale_multiplier: int,
        caption_projection: torch.nn.Module | None = None,
        prompt_adaln: AdaLayerNormSingle | None = None,
    ) -> None:
        # All single-modality work is delegated to this preprocessor.
        self.simple_preprocessor = TransformerArgsPreprocessor(
            patchify_proj=patchify_proj,
            adaln=adaln,
            inner_dim=inner_dim,
            max_pos=max_pos,
            num_attention_heads=num_attention_heads,
            use_middle_indices_grid=use_middle_indices_grid,
            timestep_scale_multiplier=timestep_scale_multiplier,
            double_precision_rope=double_precision_rope,
            positional_embedding_theta=positional_embedding_theta,
            rope_type=rope_type,
            caption_projection=caption_projection,
            prompt_adaln=prompt_adaln,
        )
        self.cross_scale_shift_adaln = cross_scale_shift_adaln
        self.cross_gate_adaln = cross_gate_adaln
        self.cross_pe_max_pos = cross_pe_max_pos
        self.audio_cross_attention_dim = audio_cross_attention_dim
        self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier

    def prepare(
        self,
        modality: Modality,
        cross_modality: Modality | None = None,
    ) -> TransformerArgs:
        """Prepare args; adds cross-modal fields when `cross_modality` is given."""
        transformer_args = self.simple_preprocessor.prepare(modality)
        if cross_modality is None:
            return transformer_args

        # Per-sample sigma must align with this modality's batch.
        if cross_modality.sigma.numel() > 1:
            if cross_modality.sigma.shape[0] != modality.timesteps.shape[0]:
                raise ValueError("Cross modality sigma must have the same batch size as the modality")
            if cross_modality.sigma.ndim != 1:
                raise ValueError("Cross modality sigma must be a 1D tensor")

        # Broadcast the cross modality's sigma to this modality's timestep layout.
        cross_timestep = cross_modality.sigma.view(
            modality.timesteps.shape[0], 1, *[1] * len(modality.timesteps.shape[2:])
        )

        # Cross-attention RoPE uses only the first position axis
        # (presumably the temporal one — see max_pos ordering; confirm upstream).
        cross_pe = self.simple_preprocessor._prepare_positional_embeddings(
            positions=modality.positions[:, 0:1, :],
            inner_dim=self.audio_cross_attention_dim,
            max_pos=[self.cross_pe_max_pos],
            use_middle_indices_grid=True,
            num_attention_heads=self.simple_preprocessor.num_attention_heads,
            x_dtype=modality.latent.dtype,
        )

        cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
            timestep=cross_timestep,
            timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
            batch_size=transformer_args.x.shape[0],
            hidden_dtype=modality.latent.dtype,
        )

        return replace(
            transformer_args,
            cross_positional_embeddings=cross_pe,
            cross_scale_shift_timestep=cross_scale_shift_timestep,
            cross_gate_timestep=cross_gate_timestep,
        )

    def _prepare_cross_attention_timestep(
        self,
        timestep: torch.Tensor | None,
        timestep_scale_multiplier: int,
        batch_size: int,
        hidden_dtype: torch.dtype,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Prepare cross attention timestep embeddings."""
        timestep = timestep * timestep_scale_multiplier

        # The gate AdaLN runs at a different timestep scale; this factor
        # rescales relative to the base multiplier already applied above.
        av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier

        scale_shift_timestep, _ = self.cross_scale_shift_adaln(
            timestep.flatten(),
            hidden_dtype=hidden_dtype,
        )
        scale_shift_timestep = scale_shift_timestep.view(batch_size, -1, scale_shift_timestep.shape[-1])
        gate_noise_timestep, _ = self.cross_gate_adaln(
            timestep.flatten() * av_ca_factor,
            hidden_dtype=hidden_dtype,
        )
        gate_noise_timestep = gate_noise_timestep.view(batch_size, -1, gate_noise_timestep.shape[-1])

        return scale_shift_timestep, gate_noise_timestep
@dataclass
class TransformerConfig:
    """Per-modality hyperparameters for a BasicAVTransformerBlock branch."""

    dim: int  # hidden (model) dimension
    heads: int  # number of attention heads
    d_head: int  # dimension per attention head
    context_dim: int  # dimension of the cross-attention context
    apply_gated_attention: bool = False  # enable per-head output gating in Attention
    cross_attention_adaln: bool = False  # add 3 extra AdaLN params for the cross-attn norm
class BasicAVTransformerBlock(torch.nn.Module):
    """One audio-video transformer block.

    Each modality (video / audio) gets: self-attention, text cross-attention,
    and a feed-forward, all modulated by AdaLN scale/shift/gate values derived
    from the timestep. When both modalities are present, bidirectional
    audio<->video cross-attention runs between the text cross-attention and
    the feed-forward stages.
    """

    def __init__(
        self,
        idx: int,
        video: TransformerConfig | None = None,
        audio: TransformerConfig | None = None,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
        norm_eps: float = 1e-6,
    ):
        super().__init__()

        # Block index; used to look up per-layer perturbation settings.
        self.idx = idx
        if video is not None:
            # Video self-attention (no context).
            self.attn1 = Attention(
                query_dim=video.dim,
                heads=video.heads,
                dim_head=video.d_head,
                context_dim=None,
                rope_type=rope_type,
                norm_eps=norm_eps,
                apply_gated_attention=video.apply_gated_attention,
            )
            # Video -> text cross-attention.
            self.attn2 = Attention(
                query_dim=video.dim,
                context_dim=video.context_dim,
                heads=video.heads,
                dim_head=video.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
                apply_gated_attention=video.apply_gated_attention,
            )
            self.ff = FeedForward(video.dim, dim_out=video.dim)
            # Row count of the AdaLN table depends on whether cross-attention
            # is AdaLN-modulated as well.
            video_sst_size = adaln_embedding_coefficient(video.cross_attention_adaln)
            self.scale_shift_table = torch.nn.Parameter(torch.empty(video_sst_size, video.dim))

        if audio is not None:
            # Audio self-attention (no context).
            self.audio_attn1 = Attention(
                query_dim=audio.dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                context_dim=None,
                rope_type=rope_type,
                norm_eps=norm_eps,
                apply_gated_attention=audio.apply_gated_attention,
            )
            # Audio -> text cross-attention.
            self.audio_attn2 = Attention(
                query_dim=audio.dim,
                context_dim=audio.context_dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
                apply_gated_attention=audio.apply_gated_attention,
            )
            self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
            audio_sst_size = adaln_embedding_coefficient(audio.cross_attention_adaln)
            self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(audio_sst_size, audio.dim))

        if audio is not None and video is not None:
            # Q: Video, K,V: Audio
            self.audio_to_video_attn = Attention(
                query_dim=video.dim,
                context_dim=audio.dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
                apply_gated_attention=video.apply_gated_attention,
            )

            # Q: Audio, K,V: Video
            self.video_to_audio_attn = Attention(
                query_dim=audio.dim,
                context_dim=video.dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
                apply_gated_attention=audio.apply_gated_attention,
            )

            # 5 rows each: 4 scale/shift values (2 per direction) + 1 gate.
            self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim))
            self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim))

        self.cross_attention_adaln = (video is not None and video.cross_attention_adaln) or (
            audio is not None and audio.cross_attention_adaln
        )

        # Extra 2-row (shift, scale) tables modulating the text prompt tokens.
        if self.cross_attention_adaln and video is not None:
            self.prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, video.dim))
        if self.cross_attention_adaln and audio is not None:
            self.audio_prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, audio.dim))

        self.norm_eps = norm_eps

    def get_ada_values(
        self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice
    ) -> tuple[torch.Tensor, ...]:
        """Combine a learned AdaLN table with timestep embeddings.

        Returns one tensor per selected table row (e.g. shift, scale, gate),
        broadcastable against the token sequence.
        """
        num_ada_params = scale_shift_table.shape[0]

        ada_values = (
            scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
            + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
        ).unbind(dim=2)
        return ada_values

    def get_av_ca_ada_values(
        self,
        scale_shift_table: torch.Tensor,
        batch_size: int,
        scale_shift_timestep: torch.Tensor,
        gate_timestep: torch.Tensor,
        scale_shift_indices: slice,
        num_scale_shift_values: int = 4,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """AdaLN values for audio-video cross-attention.

        The first `num_scale_shift_values` table rows hold scale/shift pairs
        (driven by `scale_shift_timestep`); the remaining row(s) hold the gate
        (driven by `gate_timestep`, which may use a different timestep scale).
        """
        scale_shift_ada_values = self.get_ada_values(
            scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, scale_shift_indices
        )
        gate_ada_values = self.get_ada_values(
            scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None)
        )

        scale, shift = (t.squeeze(2) for t in scale_shift_ada_values)
        (gate,) = (t.squeeze(2) for t in gate_ada_values)

        return scale, shift, gate

    def _apply_text_cross_attention(
        self,
        x: torch.Tensor,
        context: torch.Tensor,
        attn: Attention,
        scale_shift_table: torch.Tensor,
        prompt_scale_shift_table: torch.Tensor | None,
        timestep: torch.Tensor,
        prompt_timestep: torch.Tensor | None,
        context_mask: torch.Tensor | None,
        cross_attention_adaln: bool = False,
    ) -> torch.Tensor:
        """Apply text cross-attention, with optional AdaLN modulation."""
        if cross_attention_adaln:
            # Rows 6-8 of the table hold (shift, scale, gate) for this stage.
            shift_q, scale_q, gate = self.get_ada_values(scale_shift_table, x.shape[0], timestep, slice(6, 9))
            return apply_cross_attention_adaln(
                x,
                context,
                attn,
                shift_q,
                scale_q,
                gate,
                prompt_scale_shift_table,
                prompt_timestep,
                context_mask,
                self.norm_eps,
            )
        return attn(rms_norm(x, eps=self.norm_eps), context=context, mask=context_mask)

    def forward(  # noqa: PLR0915
        self,
        video: TransformerArgs | None,
        audio: TransformerArgs | None,
        perturbations: BatchedPerturbationConfig | None = None,
    ) -> tuple[TransformerArgs | None, TransformerArgs | None]:
        """Run the block on one or both modalities.

        Returns the inputs with `x` replaced by the updated hidden states;
        a modality passed as None comes back as None.
        """
        if video is None and audio is None:
            raise ValueError("At least one of video or audio must be provided")

        batch_size = (video or audio).x.shape[0]

        if perturbations is None:
            perturbations = BatchedPerturbationConfig.empty(batch_size)

        vx = video.x if video is not None else None
        ax = audio.x if audio is not None else None

        # Skip a modality's own branches when disabled or empty, but still
        # allow it to serve as context for the other modality.
        run_vx = video is not None and video.enabled and vx.numel() > 0
        run_ax = audio is not None and audio.enabled and ax.numel() > 0

        run_a2v = run_vx and (audio is not None and ax.numel() > 0)
        run_v2a = run_ax and (video is not None and vx.numel() > 0)

        if run_vx:
            # Video self-attention with AdaLN (rows 0-2: shift, scale, gate).
            vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
                self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
            )
            norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
            del vshift_msa, vscale_msa

            # A per-sample mask is only needed when the batch mixes perturbed
            # and unperturbed samples.
            all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)
            none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)
            v_mask = (
                perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx)
                if not all_perturbed and not none_perturbed
                else None
            )
            vx = (
                vx
                + self.attn1(
                    norm_vx,
                    pe=video.positional_embeddings,
                    mask=video.self_attention_mask,
                    perturbation_mask=v_mask,
                    all_perturbed=all_perturbed,
                )
                * vgate_msa
            )
            del vgate_msa, norm_vx, v_mask
            # Video -> text cross-attention.
            vx = vx + self._apply_text_cross_attention(
                vx,
                video.context,
                self.attn2,
                self.scale_shift_table,
                getattr(self, "prompt_scale_shift_table", None),
                video.timesteps,
                video.prompt_timestep,
                video.context_mask,
                cross_attention_adaln=self.cross_attention_adaln,
            )

        if run_ax:
            # Audio self-attention with AdaLN (rows 0-2: shift, scale, gate).
            ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
                self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
            )

            norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
            del ashift_msa, ascale_msa
            all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)
            none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)
            a_mask = (
                perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax)
                if not all_perturbed and not none_perturbed
                else None
            )
            ax = (
                ax
                + self.audio_attn1(
                    norm_ax,
                    pe=audio.positional_embeddings,
                    mask=audio.self_attention_mask,
                    perturbation_mask=a_mask,
                    all_perturbed=all_perturbed,
                )
                * agate_msa
            )
            del agate_msa, norm_ax, a_mask
            # Audio -> text cross-attention.
            ax = ax + self._apply_text_cross_attention(
                ax,
                audio.context,
                self.audio_attn2,
                self.audio_scale_shift_table,
                getattr(self, "audio_prompt_scale_shift_table", None),
                audio.timesteps,
                audio.prompt_timestep,
                audio.context_mask,
                cross_attention_adaln=self.cross_attention_adaln,
            )

        # Audio - Video cross attention.
        if run_a2v or run_v2a:
            # Shared normalized inputs for both cross-attention directions;
            # computed once, before either direction updates vx/ax.
            vx_norm3 = rms_norm(vx, eps=self.norm_eps)
            ax_norm3 = rms_norm(ax, eps=self.norm_eps)

            if run_a2v and not perturbations.all_in_batch(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx):
                # Audio -> video: rows 0-1 of each table modulate this direction.
                scale_ca_video_a2v, shift_ca_video_a2v, gate_out_a2v = self.get_av_ca_ada_values(
                    self.scale_shift_table_a2v_ca_video,
                    vx.shape[0],
                    video.cross_scale_shift_timestep,
                    video.cross_gate_timestep,
                    slice(0, 2),
                )
                vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v
                del scale_ca_video_a2v, shift_ca_video_a2v

                scale_ca_audio_a2v, shift_ca_audio_a2v, _ = self.get_av_ca_ada_values(
                    self.scale_shift_table_a2v_ca_audio,
                    ax.shape[0],
                    audio.cross_scale_shift_timestep,
                    audio.cross_gate_timestep,
                    slice(0, 2),
                )
                ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v
                del scale_ca_audio_a2v, shift_ca_audio_a2v
                a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx)
                vx = vx + (
                    self.audio_to_video_attn(
                        vx_scaled,
                        context=ax_scaled,
                        pe=video.cross_positional_embeddings,
                        k_pe=audio.cross_positional_embeddings,
                    )
                    * gate_out_a2v
                    * a2v_mask
                )
                del gate_out_a2v, a2v_mask, vx_scaled, ax_scaled

            if run_v2a and not perturbations.all_in_batch(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx):
                # Video -> audio: rows 2-3 of each table modulate this direction.
                scale_ca_audio_v2a, shift_ca_audio_v2a, gate_out_v2a = self.get_av_ca_ada_values(
                    self.scale_shift_table_a2v_ca_audio,
                    ax.shape[0],
                    audio.cross_scale_shift_timestep,
                    audio.cross_gate_timestep,
                    slice(2, 4),
                )
                ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a
                del scale_ca_audio_v2a, shift_ca_audio_v2a
                scale_ca_video_v2a, shift_ca_video_v2a, _ = self.get_av_ca_ada_values(
                    self.scale_shift_table_a2v_ca_video,
                    vx.shape[0],
                    video.cross_scale_shift_timestep,
                    video.cross_gate_timestep,
                    slice(2, 4),
                )
                vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a
                del scale_ca_video_v2a, shift_ca_video_v2a
                v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax)
                ax = ax + (
                    self.video_to_audio_attn(
                        ax_scaled,
                        context=vx_scaled,
                        pe=audio.cross_positional_embeddings,
                        k_pe=video.cross_positional_embeddings,
                    )
                    * gate_out_v2a
                    * v2a_mask
                )
                del gate_out_v2a, v2a_mask, ax_scaled, vx_scaled

            del vx_norm3, ax_norm3

        if run_vx:
            # Video feed-forward with AdaLN (rows 3-5: shift, scale, gate).
            vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
                self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6)
            )
            vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
            vx = vx + self.ff(vx_scaled) * vgate_mlp

            del vshift_mlp, vscale_mlp, vgate_mlp, vx_scaled

        if run_ax:
            # Audio feed-forward with AdaLN (rows 3-5: shift, scale, gate).
            ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
                self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6)
            )
            ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
            ax = ax + self.audio_ff(ax_scaled) * agate_mlp

            del ashift_mlp, ascale_mlp, agate_mlp, ax_scaled

        return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None
|
|
|
|
|
|
def apply_cross_attention_adaln(
    x: torch.Tensor,
    context: torch.Tensor,
    attn: Attention,
    q_shift: torch.Tensor,
    q_scale: torch.Tensor,
    q_gate: torch.Tensor,
    prompt_scale_shift_table: torch.Tensor,
    prompt_timestep: torch.Tensor,
    context_mask: torch.Tensor | None = None,
    norm_eps: float = 1e-6,
) -> torch.Tensor:
    """Text cross-attention with AdaLN modulation on both sides.

    The query stream is RMS-normalized and modulated with the supplied
    shift/scale, the prompt context is modulated with values derived from the
    learned prompt table plus the prompt timestep embedding, and the attention
    output is gated by `q_gate`.
    """
    bsz = x.shape[0]
    # Fold the learned (shift, scale) table together with the per-prompt
    # timestep embedding, then split into the two modulation tensors.
    prompt_mod = prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
    prompt_mod = prompt_mod + prompt_timestep.reshape(bsz, prompt_timestep.shape[1], 2, -1)
    shift_kv, scale_kv = prompt_mod.unbind(dim=2)

    query_states = rms_norm(x, eps=norm_eps) * (q_scale + 1) + q_shift
    modulated_context = context * (scale_kv + 1) + shift_kv
    attn_out = attn(query_states, context=modulated_context, mask=context_mask)
    return attn_out * q_gate
|
|
|
|
|
|
class GELUApprox(torch.nn.Module):
    """Linear projection followed by a tanh-approximated GELU."""

    def __init__(self, dim_in: int, dim_out: int) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.proj(x)
        return torch.nn.functional.gelu(projected, approximate="tanh")
|
|
|
|
|
|
class FeedForward(torch.nn.Module):
    """Two-layer MLP: GELU-activated projection up, then a linear projection out.

    The `torch.nn.Identity` placeholder keeps the `net.<idx>` parameter names
    stable (e.g. for checkpoints that index the output linear as `net.2`).
    """

    def __init__(self, dim: int, dim_out: int, mult: int = 4) -> None:
        super().__init__()
        hidden_dim = int(dim * mult)
        self.net = torch.nn.Sequential(
            GELUApprox(dim, hidden_dim),
            torch.nn.Identity(),
            torch.nn.Linear(hidden_dim, dim_out),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
|
|
|
|
|
|
class LTXModelType(Enum):
    """Which modality streams an LTX model instance carries."""

    AudioVideo = "ltx av model"
    VideoOnly = "ltx video only model"
    AudioOnly = "ltx audio only model"

    def is_video_enabled(self) -> bool:
        """True for every variant except the audio-only one."""
        return self is not LTXModelType.AudioOnly

    def is_audio_enabled(self) -> bool:
        """True for every variant except the video-only one."""
        return self is not LTXModelType.VideoOnly
|
|
|
|
|
|
class LTXModel(torch.nn.Module):
|
|
"""
|
|
LTX model transformer implementation.
|
|
This class implements the transformer blocks for the LTX model.
|
|
"""
|
|
|
|
def __init__( # noqa: PLR0913
|
|
self,
|
|
*,
|
|
model_type: LTXModelType = LTXModelType.AudioVideo,
|
|
num_attention_heads: int = 32,
|
|
attention_head_dim: int = 128,
|
|
in_channels: int = 128,
|
|
out_channels: int = 128,
|
|
num_layers: int = 48,
|
|
cross_attention_dim: int = 4096,
|
|
norm_eps: float = 1e-06,
|
|
caption_channels: int = 3840,
|
|
positional_embedding_theta: float = 10000.0,
|
|
positional_embedding_max_pos: list[int] | None = [20, 2048, 2048],
|
|
timestep_scale_multiplier: int = 1000,
|
|
use_middle_indices_grid: bool = True,
|
|
audio_num_attention_heads: int = 32,
|
|
audio_attention_head_dim: int = 64,
|
|
audio_in_channels: int = 128,
|
|
audio_out_channels: int = 128,
|
|
audio_cross_attention_dim: int = 2048,
|
|
audio_positional_embedding_max_pos: list[int] | None = [20],
|
|
av_ca_timestep_scale_multiplier: int = 1000,
|
|
rope_type: LTXRopeType = LTXRopeType.SPLIT,
|
|
double_precision_rope: bool = True,
|
|
apply_gated_attention: bool = False,
|
|
cross_attention_adaln: bool = False,
|
|
):
|
|
super().__init__()
|
|
self._enable_gradient_checkpointing = False
|
|
self.use_middle_indices_grid = use_middle_indices_grid
|
|
self.rope_type = rope_type
|
|
self.double_precision_rope = double_precision_rope
|
|
self.timestep_scale_multiplier = timestep_scale_multiplier
|
|
self.positional_embedding_theta = positional_embedding_theta
|
|
self.model_type = model_type
|
|
self.cross_attention_adaln = cross_attention_adaln
|
|
cross_pe_max_pos = None
|
|
if model_type.is_video_enabled():
|
|
if positional_embedding_max_pos is None:
|
|
positional_embedding_max_pos = [20, 2048, 2048]
|
|
self.positional_embedding_max_pos = positional_embedding_max_pos
|
|
self.num_attention_heads = num_attention_heads
|
|
self.inner_dim = num_attention_heads * attention_head_dim
|
|
self._init_video(
|
|
in_channels=in_channels,
|
|
out_channels=out_channels,
|
|
caption_channels=caption_channels,
|
|
norm_eps=norm_eps,
|
|
)
|
|
|
|
if model_type.is_audio_enabled():
|
|
if audio_positional_embedding_max_pos is None:
|
|
audio_positional_embedding_max_pos = [20]
|
|
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
|
|
self.audio_num_attention_heads = audio_num_attention_heads
|
|
self.audio_inner_dim = self.audio_num_attention_heads * audio_attention_head_dim
|
|
self._init_audio(
|
|
in_channels=audio_in_channels,
|
|
out_channels=audio_out_channels,
|
|
caption_channels=caption_channels,
|
|
norm_eps=norm_eps,
|
|
)
|
|
|
|
if model_type.is_video_enabled() and model_type.is_audio_enabled():
|
|
cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
|
|
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
|
|
self.audio_cross_attention_dim = audio_cross_attention_dim
|
|
self._init_audio_video(num_scale_shift_values=4)
|
|
|
|
self._init_preprocessors(cross_pe_max_pos)
|
|
# Initialize transformer blocks
|
|
self._init_transformer_blocks(
|
|
num_layers=num_layers,
|
|
attention_head_dim=attention_head_dim if model_type.is_video_enabled() else 0,
|
|
cross_attention_dim=cross_attention_dim,
|
|
audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0,
|
|
audio_cross_attention_dim=audio_cross_attention_dim,
|
|
norm_eps=norm_eps,
|
|
apply_gated_attention=apply_gated_attention,
|
|
)
|
|
|
|
    @property
    def _adaln_embedding_coefficient(self) -> int:
        # Number of AdaLN parameter groups per token; larger when text
        # cross-attention is AdaLN-modulated too.
        return adaln_embedding_coefficient(self.cross_attention_adaln)
|
|
|
|
    def _init_video(
        self,
        in_channels: int,
        out_channels: int,
        caption_channels: int,
        norm_eps: float,
    ) -> None:
        """Initialize video-specific components.

        Creates the video patch embedding, timestep AdaLN modules, optional
        caption projection, and the output norm/projection head.
        """
        # Video input components
        self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True)
        self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=self._adaln_embedding_coefficient)
        # Prompt AdaLN only exists when text cross-attention is AdaLN-modulated.
        self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None

        # Video caption projection
        if caption_channels is not None:
            self.caption_projection = PixArtAlphaTextProjection(
                in_features=caption_channels,
                hidden_size=self.inner_dim,
            )

        # Video output components
        self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim))
        self.norm_out = torch.nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=norm_eps)
        self.proj_out = torch.nn.Linear(self.inner_dim, out_channels)
|
|
|
|
    def _init_audio(
        self,
        in_channels: int,
        out_channels: int,
        caption_channels: int,
        norm_eps: float,
    ) -> None:
        """Initialize audio-specific components.

        Mirrors `_init_video` for the audio stream: patch embedding, timestep
        AdaLN modules, optional caption projection, and output head.
        """

        # Audio input components
        self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True)

        self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=self._adaln_embedding_coefficient)
        # Prompt AdaLN only exists when text cross-attention is AdaLN-modulated.
        self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None

        # Audio caption projection
        if caption_channels is not None:
            self.audio_caption_projection = PixArtAlphaTextProjection(
                in_features=caption_channels,
                hidden_size=self.audio_inner_dim,
            )

        # Audio output components
        self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim))
        self.audio_norm_out = torch.nn.LayerNorm(self.audio_inner_dim, elementwise_affine=False, eps=norm_eps)
        self.audio_proj_out = torch.nn.Linear(self.audio_inner_dim, out_channels)
|
|
|
|
    def _init_audio_video(
        self,
        num_scale_shift_values: int,
    ) -> None:
        """Initialize audio-video cross-attention components.

        Creates the AdaLN modules that emit scale/shift and gate embeddings
        for the bidirectional audio<->video cross-attention in each block.

        Args:
            num_scale_shift_values: Number of scale/shift values per modality
                (2 per cross-attention direction).
        """
        # Per-modality scale/shift embeddings for AV cross-attention.
        self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
            embedding_coefficient=num_scale_shift_values,
        )

        self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            embedding_coefficient=num_scale_shift_values,
        )

        # One gate embedding per direction (audio->video and video->audio).
        self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
            embedding_coefficient=1,
        )

        self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            embedding_coefficient=1,
        )
|
|
|
|
    def _init_preprocessors(
        self,
        cross_pe_max_pos: int | None = None,
    ) -> None:
        """Initialize preprocessors for LTX.

        For an AV model, both modalities get `MultiModalTransformerArgsPreprocessor`
        (which also prepares AV cross-attention embeddings); single-modality
        models use the simpler `TransformerArgsPreprocessor`.

        Args:
            cross_pe_max_pos: Max positional range shared by the AV
                cross-attention PEs; None for single-modality models.
        """

        if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():
            self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
                patchify_proj=self.patchify_proj,
                adaln=self.adaln_single,
                cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
                cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
                inner_dim=self.inner_dim,
                max_pos=self.positional_embedding_max_pos,
                num_attention_heads=self.num_attention_heads,
                cross_pe_max_pos=cross_pe_max_pos,
                use_middle_indices_grid=self.use_middle_indices_grid,
                audio_cross_attention_dim=self.audio_cross_attention_dim,
                timestep_scale_multiplier=self.timestep_scale_multiplier,
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
                av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
                caption_projection=getattr(self, "caption_projection", None),
                prompt_adaln=getattr(self, "prompt_adaln_single", None),
            )
            self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
                patchify_proj=self.audio_patchify_proj,
                adaln=self.audio_adaln_single,
                cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
                cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
                inner_dim=self.audio_inner_dim,
                max_pos=self.audio_positional_embedding_max_pos,
                num_attention_heads=self.audio_num_attention_heads,
                cross_pe_max_pos=cross_pe_max_pos,
                use_middle_indices_grid=self.use_middle_indices_grid,
                audio_cross_attention_dim=self.audio_cross_attention_dim,
                timestep_scale_multiplier=self.timestep_scale_multiplier,
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
                av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
                caption_projection=getattr(self, "audio_caption_projection", None),
                prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
            )
        elif self.model_type.is_video_enabled():
            self.video_args_preprocessor = TransformerArgsPreprocessor(
                patchify_proj=self.patchify_proj,
                adaln=self.adaln_single,
                inner_dim=self.inner_dim,
                max_pos=self.positional_embedding_max_pos,
                num_attention_heads=self.num_attention_heads,
                use_middle_indices_grid=self.use_middle_indices_grid,
                timestep_scale_multiplier=self.timestep_scale_multiplier,
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
                caption_projection=getattr(self, "caption_projection", None),
                prompt_adaln=getattr(self, "prompt_adaln_single", None),
            )
        elif self.model_type.is_audio_enabled():
            self.audio_args_preprocessor = TransformerArgsPreprocessor(
                patchify_proj=self.audio_patchify_proj,
                adaln=self.audio_adaln_single,
                inner_dim=self.audio_inner_dim,
                max_pos=self.audio_positional_embedding_max_pos,
                num_attention_heads=self.audio_num_attention_heads,
                use_middle_indices_grid=self.use_middle_indices_grid,
                timestep_scale_multiplier=self.timestep_scale_multiplier,
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
                caption_projection=getattr(self, "audio_caption_projection", None),
                prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
            )
|
|
|
|
    def _init_transformer_blocks(
        self,
        num_layers: int,
        attention_head_dim: int,
        cross_attention_dim: int,
        audio_attention_head_dim: int,
        audio_cross_attention_dim: int,
        norm_eps: float,
        apply_gated_attention: bool,
    ) -> None:
        """Initialize transformer blocks for LTX.

        Builds per-modality `TransformerConfig`s (None for a disabled
        modality) and stacks `num_layers` `BasicAVTransformerBlock`s.
        """
        video_config = (
            TransformerConfig(
                dim=self.inner_dim,
                heads=self.num_attention_heads,
                d_head=attention_head_dim,
                context_dim=cross_attention_dim,
                apply_gated_attention=apply_gated_attention,
                cross_attention_adaln=self.cross_attention_adaln,
            )
            if self.model_type.is_video_enabled()
            else None
        )
        audio_config = (
            TransformerConfig(
                dim=self.audio_inner_dim,
                heads=self.audio_num_attention_heads,
                d_head=audio_attention_head_dim,
                context_dim=audio_cross_attention_dim,
                apply_gated_attention=apply_gated_attention,
                cross_attention_adaln=self.cross_attention_adaln,
            )
            if self.model_type.is_audio_enabled()
            else None
        )
        self.transformer_blocks = torch.nn.ModuleList(
            [
                BasicAVTransformerBlock(
                    idx=idx,
                    video=video_config,
                    audio=audio_config,
                    rope_type=self.rope_type,
                    norm_eps=norm_eps,
                )
                for idx in range(num_layers)
            ]
        )
|
|
|
|
def set_gradient_checkpointing(self, enable: bool) -> None:
|
|
"""Enable or disable gradient checkpointing for transformer blocks.
|
|
Gradient checkpointing trades compute for memory by recomputing activations
|
|
during the backward pass instead of storing them. This can significantly
|
|
reduce memory usage at the cost of ~20-30% slower training.
|
|
Args:
|
|
enable: Whether to enable gradient checkpointing
|
|
"""
|
|
self._enable_gradient_checkpointing = enable
|
|
|
|
    def _process_transformer_blocks(
        self,
        video: TransformerArgs | None,
        audio: TransformerArgs | None,
        perturbations: BatchedPerturbationConfig,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
    ) -> tuple[TransformerArgs, TransformerArgs]:
        """Run all transformer blocks sequentially over both modality streams.

        Each block is optionally wrapped in gradient checkpointing (with
        optional CPU offload of the stored inputs).
        """

        # Process transformer blocks
        for block in self.transformer_blocks:
            video, audio = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing,
                use_gradient_checkpointing_offload,
                video=video,
                audio=audio,
                perturbations=perturbations,
            )

        return video, audio
|
|
|
|
def _process_output(
|
|
self,
|
|
scale_shift_table: torch.Tensor,
|
|
norm_out: torch.nn.LayerNorm,
|
|
proj_out: torch.nn.Linear,
|
|
x: torch.Tensor,
|
|
embedded_timestep: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
"""Process output for LTXV."""
|
|
# Apply scale-shift modulation
|
|
scale_shift_values = (
|
|
scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
|
|
)
|
|
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
|
|
|
x = norm_out(x)
|
|
x = x * (1 + scale) + shift
|
|
x = proj_out(x)
|
|
return x
|
|
|
|
    def _forward(
        self,
        video: Modality | None,
        audio: Modality | None,
        perturbations: BatchedPerturbationConfig,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for LTX models.

        Validates that the supplied modalities match the model type, runs the
        per-modality preprocessors (each also sees the other modality for AV
        cross-attention setup), the transformer blocks, and the output heads.

        Returns:
            Tuple of (video output, audio output); an element is None when the
            corresponding modality was not provided.
        """
        if not self.model_type.is_video_enabled() and video is not None:
            raise ValueError("Video is not enabled for this model")
        if not self.model_type.is_audio_enabled() and audio is not None:
            raise ValueError("Audio is not enabled for this model")

        # Each preprocessor receives the other modality as secondary input.
        video_args = self.video_args_preprocessor.prepare(video, audio) if video is not None else None
        audio_args = self.audio_args_preprocessor.prepare(audio, video) if audio is not None else None
        # Process transformer blocks
        video_out, audio_out = self._process_transformer_blocks(
            video=video_args,
            audio=audio_args,
            perturbations=perturbations,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
        )

        # Process output
        vx = (
            self._process_output(
                self.scale_shift_table, self.norm_out, self.proj_out, video_out.x, video_out.embedded_timestep
            )
            if video_out is not None
            else None
        )
        ax = (
            self._process_output(
                self.audio_scale_shift_table,
                self.audio_norm_out,
                self.audio_proj_out,
                audio_out.x,
                audio_out.embedded_timestep,
            )
            if audio_out is not None
            else None
        )
        return vx, ax
|
|
|
|
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps, sigma, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
|
|
cross_pe_max_pos = None
|
|
if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():
|
|
cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
|
|
self._init_preprocessors(cross_pe_max_pos)
|
|
video = Modality(video_latents, sigma, video_timesteps, video_positions, video_context)
|
|
audio = Modality(audio_latents, sigma, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None
|
|
vx, ax = self._forward(video=video, audio=audio, perturbations=None, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload)
|
|
return vx, ax
|