From 00da4b6c4f8db0aa6d639366aae98370f20b2250 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 27 Jan 2026 19:34:09 +0800 Subject: [PATCH] add video_vae and dit for ltx-2 --- diffsynth/configs/model_configs.py | 26 +- diffsynth/core/attention/attention.py | 2 +- diffsynth/models/ltx2_common.py | 253 +++ diffsynth/models/ltx2_dit.py | 1442 ++++++++++++ diffsynth/models/ltx2_video_vae.py | 1969 +++++++++++++++++ .../utils/state_dict_converters/ltx2_dit.py | 9 + .../state_dict_converters/ltx2_video_vae.py | 22 + diffsynth/utils/test/load_model.py | 22 + 8 files changed, 3743 insertions(+), 2 deletions(-) create mode 100644 diffsynth/models/ltx2_common.py create mode 100644 diffsynth/models/ltx2_dit.py create mode 100644 diffsynth/models/ltx2_video_vae.py create mode 100644 diffsynth/utils/state_dict_converters/ltx2_dit.py create mode 100644 diffsynth/utils/state_dict_converters/ltx2_video_vae.py create mode 100644 diffsynth/utils/test/load_model.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index c93f5e9..7809ee2 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -591,4 +591,28 @@ z_image_series = [ }, ] -MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series +ltx2_series = [ + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_dit", + "model_class": "diffsynth.models.ltx2_dit.LTXModel", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_video_vae_encoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_video_vae_decoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter", + }, +] + +MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series diff --git a/diffsynth/core/attention/attention.py b/diffsynth/core/attention/attention.py index 15b55a4..630d375 100644 --- a/diffsynth/core/attention/attention.py +++ b/diffsynth/core/attention/attention.py @@ -52,7 +52,7 @@ def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern=" if k_pattern != required_in_pattern: k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims) if v_pattern != required_in_pattern: - v = rearrange(v, f"{q_pattern} -> {required_in_pattern}", **dims) + v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims) return q, k, v diff --git a/diffsynth/models/ltx2_common.py b/diffsynth/models/ltx2_common.py new file mode 100644 index 0000000..fbcae4b --- /dev/null +++ b/diffsynth/models/ltx2_common.py @@ -0,0 +1,253 @@ +from dataclasses import dataclass +from typing import NamedTuple +import torch +from torch import nn + +class VideoPixelShape(NamedTuple): + """ + Shape of the tensor representing the video pixel array. Assumes BGR channel format. + """ + + batch: int + frames: int + height: int + width: int + fps: float + + +class SpatioTemporalScaleFactors(NamedTuple): + """ + Describes the spatiotemporal downscaling between decoded video space and + the corresponding VAE latent grid. + """ + + time: int + width: int + height: int + + @classmethod + def default(cls) -> "SpatioTemporalScaleFactors": + return cls(time=8, width=32, height=32) + + +VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default() + + +class VideoLatentShape(NamedTuple): + """ + Shape of the tensor representing video in VAE latent space. + The latent representation is a 5D tensor with dimensions ordered as + (batch, channels, frames, height, width). Spatial and temporal dimensions + are downscaled relative to pixel space according to the VAE's scale factors. + """ + + batch: int + channels: int + frames: int + height: int + width: int + + def to_torch_shape(self) -> torch.Size: + return torch.Size([self.batch, self.channels, self.frames, self.height, self.width]) + + @staticmethod + def from_torch_shape(shape: torch.Size) -> "VideoLatentShape": + return VideoLatentShape( + batch=shape[0], + channels=shape[1], + frames=shape[2], + height=shape[3], + width=shape[4], + ) + + def mask_shape(self) -> "VideoLatentShape": + return self._replace(channels=1) + + @staticmethod + def from_pixel_shape( + shape: VideoPixelShape, + latent_channels: int = 128, + scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS, + ) -> "VideoLatentShape": + frames = (shape.frames - 1) // scale_factors[0] + 1 + height = shape.height // scale_factors[1] + width = shape.width // scale_factors[2] + + return VideoLatentShape( + batch=shape.batch, + channels=latent_channels, + frames=frames, + height=height, + width=width, + ) + + def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape": + return self._replace( + channels=3, + frames=(self.frames - 1) * scale_factors.time + 1, + height=self.height * scale_factors.height, + width=self.width * scale_factors.width, + ) + + +class AudioLatentShape(NamedTuple): + """ + Shape of audio in VAE latent space: (batch, channels, frames, mel_bins). + mel_bins is the number of frequency bins from the mel-spectrogram encoding. + """ + + batch: int + channels: int + frames: int + mel_bins: int + + def to_torch_shape(self) -> torch.Size: + return torch.Size([self.batch, self.channels, self.frames, self.mel_bins]) + + def mask_shape(self) -> "AudioLatentShape": + return self._replace(channels=1, mel_bins=1) + + @staticmethod + def from_torch_shape(shape: torch.Size) -> "AudioLatentShape": + return AudioLatentShape( + batch=shape[0], + channels=shape[1], + frames=shape[2], + mel_bins=shape[3], + ) + + @staticmethod + def from_duration( + batch: int, + duration: float, + channels: int = 8, + mel_bins: int = 16, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + ) -> "AudioLatentShape": + latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor) + + return AudioLatentShape( + batch=batch, + channels=channels, + frames=round(duration * latents_per_second), + mel_bins=mel_bins, + ) + + @staticmethod + def from_video_pixel_shape( + shape: VideoPixelShape, + channels: int = 8, + mel_bins: int = 16, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + ) -> "AudioLatentShape": + return AudioLatentShape.from_duration( + batch=shape.batch, + duration=float(shape.frames) / float(shape.fps), + channels=channels, + mel_bins=mel_bins, + sample_rate=sample_rate, + hop_length=hop_length, + audio_latent_downsample_factor=audio_latent_downsample_factor, + ) + + +@dataclass(frozen=True) +class LatentState: + """ + State of latents during the diffusion denoising process. + Attributes: + latent: The current noisy latent tensor being denoised. + denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising). + positions: Positional indices for each latent element, used for positional embeddings. + clean_latent: Initial state of the latent before denoising, may include conditioning latents. + """ + + latent: torch.Tensor + denoise_mask: torch.Tensor + positions: torch.Tensor + clean_latent: torch.Tensor + + def clone(self) -> "LatentState": + return LatentState( + latent=self.latent.clone(), + denoise_mask=self.denoise_mask.clone(), + positions=self.positions.clone(), + clean_latent=self.clean_latent.clone(), + ) + + +class PixelNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + For each element along the chosen dimension, this layer normalizes the tensor + by the root-mean-square of its values across that dimension: + y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps) + """ + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + """ + Args: + dim: Dimension along which to compute the RMS (typically channels). + eps: Small constant added for numerical stability. + """ + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply RMS normalization along the configured dimension. + """ + # Compute mean of squared values along `dim`, keep dimensions for broadcasting. + mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True) + # Normalize by the root-mean-square (RMS). + rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor: + """Root-mean-square (RMS) normalize `x` over its last dimension. + Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized + shape and forwards `weight` and `eps`. + """ + return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps) + + +@dataclass(frozen=True) +class Modality: + """ + Input data for a single modality (video or audio) in the transformer. + Bundles the latent tokens, timestep embeddings, positional information, + and text conditioning context for processing by the diffusion transformer. + """ + + latent: ( + torch.Tensor + ) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension + timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps + positions: ( + torch.Tensor + ) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens + context: torch.Tensor + enabled: bool = True + context_mask: torch.Tensor | None = None + + +def to_denoised( + sample: torch.Tensor, + velocity: torch.Tensor, + sigma: float | torch.Tensor, + calc_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Convert the sample and its denoising velocity to denoised sample. + Returns: + Denoised sample + """ + if isinstance(sigma, torch.Tensor): + sigma = sigma.to(calc_dtype) + return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype) diff --git a/diffsynth/models/ltx2_dit.py b/diffsynth/models/ltx2_dit.py new file mode 100644 index 0000000..a4cfd5c --- /dev/null +++ b/diffsynth/models/ltx2_dit.py @@ -0,0 +1,1442 @@ +import math +import functools +from dataclasses import dataclass, replace +from enum import Enum +from typing import Optional, Tuple, Callable +import numpy as np +import torch +from torch._prims_common import DeviceLikeType +from einops import rearrange +from .ltx2_common import rms_norm, Modality +from ..core.attention.attention import attention_forward + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(torch.nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + out_dim: int | None = None, + post_act_fn: str | None = None, + cond_proj_dim: int | None = None, + sample_proj_bias: bool = True, + ): + super().__init__() + + self.linear_1 = torch.nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = torch.nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = torch.nn.SiLU() + time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim + + self.linear_2 = torch.nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + + def forward(self, sample: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor: + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(torch.nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class PixArtAlphaCombinedTimestepSizeEmbeddings(torch.nn.Module): + """ + For PixArt-Alpha. + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__( + self, + embedding_dim: int, + size_emb_dim: int, + ): + super().__init__() + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward( + self, + timestep: torch.Tensor, + hidden_dtype: torch.dtype, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + return timesteps_emb + + +class PerturbationType(Enum): + """Types of attention perturbations for STG (Spatio-Temporal Guidance).""" + + SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn" + SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn" + SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn" + SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn" + + +@dataclass(frozen=True) +class Perturbation: + """A single perturbation specifying which attention type to skip and in which blocks.""" + + type: PerturbationType + blocks: list[int] | None # None means all blocks + + def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool: + if self.type != perturbation_type: + return False + + if self.blocks is None: + return True + + return block in self.blocks + + +@dataclass(frozen=True) +class PerturbationConfig: + """Configuration holding a list of perturbations for a single sample.""" + + perturbations: list[Perturbation] | None + + def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool: + if self.perturbations is None: + return False + + return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + @staticmethod + def empty() -> "PerturbationConfig": + return PerturbationConfig([]) + + +@dataclass(frozen=True) +class BatchedPerturbationConfig: + """Perturbation configurations for a batch, with utilities for generating attention masks.""" + + perturbations: list[PerturbationConfig] + + def mask( + self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype + ) -> torch.Tensor: + mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype) + for batch_idx, perturbation in enumerate(self.perturbations): + if perturbation.is_perturbed(perturbation_type, block): + mask[batch_idx] = 0 + + return mask + + def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor: + mask = self.mask(perturbation_type, block, values.device, values.dtype) + return mask.view(mask.numel(), *([1] * len(values.shape[1:]))) + + def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool: + return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool: + return all(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + @staticmethod + def empty(batch_size: int) -> "BatchedPerturbationConfig": + return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)]) + + +class AdaLayerNormSingle(torch.nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, embedding_coefficient: int = 6): + super().__init__() + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, + size_emb_dim=embedding_dim // 3, + ) + + self.silu = torch.nn.SiLU() + self.linear = torch.nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +class LTXRopeType(Enum): + INTERLEAVED = "interleaved" + SPLIT = "split" + + +def apply_rotary_emb( + input_tensor: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, +) -> torch.Tensor: + if rope_type == LTXRopeType.INTERLEAVED: + return apply_interleaved_rotary_emb(input_tensor, *freqs_cis) + elif rope_type == LTXRopeType.SPLIT: + return apply_split_rotary_emb(input_tensor, *freqs_cis) + else: + raise ValueError(f"Invalid rope type: {rope_type}") + + + +def apply_interleaved_rotary_emb( + input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor +) -> torch.Tensor: + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + + +def apply_split_rotary_emb( + input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor +) -> torch.Tensor: + needs_reshape = False + if input_tensor.ndim != 4 and cos_freqs.ndim == 4: + b, h, t, _ = cos_freqs.shape + input_tensor = input_tensor.reshape(b, t, h, -1).swapaxes(1, 2) + needs_reshape = True + + split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2) + first_half_input = split_input[..., :1, :] + second_half_input = split_input[..., 1:, :] + + output = split_input * cos_freqs.unsqueeze(-2) + first_half_output = output[..., :1, :] + second_half_output = output[..., 1:, :] + + first_half_output.addcmul_(-sin_freqs.unsqueeze(-2), second_half_input) + second_half_output.addcmul_(sin_freqs.unsqueeze(-2), first_half_input) + + output = rearrange(output, "... d r -> ... (d r)") + if needs_reshape: + output = output.swapaxes(1, 2).reshape(b, t, -1) + + return output + + +@functools.lru_cache(maxsize=5) +def generate_freq_grid_np( + positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int +) -> torch.Tensor: + theta = positional_embedding_theta + start = 1 + end = theta + + n_elem = 2 * positional_embedding_max_pos_count + pow_indices = np.power( + theta, + np.linspace( + np.log(start) / np.log(theta), + np.log(end) / np.log(theta), + inner_dim // n_elem, + dtype=np.float64, + ), + ) + return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32) + + +@functools.lru_cache(maxsize=5) +def generate_freq_grid_pytorch( + positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int +) -> torch.Tensor: + theta = positional_embedding_theta + start = 1 + end = theta + n_elem = 2 * positional_embedding_max_pos_count + + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + inner_dim // n_elem, + dtype=torch.float32, + ) + ) + indices = indices.to(dtype=torch.float32) + + indices = indices * math.pi / 2 + + return indices + + +def get_fractional_positions(indices_grid: torch.Tensor, max_pos: list[int]) -> torch.Tensor: + n_pos_dims = indices_grid.shape[1] + assert n_pos_dims == len(max_pos), ( + f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})" + ) + fractional_positions = torch.stack( + [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)], + dim=-1, + ) + return fractional_positions + + +def generate_freqs( + indices: torch.Tensor, indices_grid: torch.Tensor, max_pos: list[int], use_middle_indices_grid: bool +) -> torch.Tensor: + if use_middle_indices_grid: + assert len(indices_grid.shape) == 4 + assert indices_grid.shape[-1] == 2 + indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1] + indices_grid = (indices_grid_start + indices_grid_end) / 2.0 + elif len(indices_grid.shape) == 4: + indices_grid = indices_grid[..., 0] + + fractional_positions = get_fractional_positions(indices_grid, max_pos) + indices = indices.to(device=fractional_positions.device) + + freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) + return freqs + + +def split_freqs_cis(freqs: torch.Tensor, pad_size: int, num_attention_heads: int) -> tuple[torch.Tensor, torch.Tensor]: + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + # Reshape freqs to be compatible with multi-head attention + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1) + + cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + return cos_freq, sin_freq + + +def interleaved_freqs_cis(freqs: torch.Tensor, pad_size: int) -> tuple[torch.Tensor, torch.Tensor]: + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(cos_freq[:, :, :pad_size]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq, sin_freq + + +def precompute_freqs_cis( + indices_grid: torch.Tensor, + dim: int, + out_dtype: torch.dtype, + theta: float = 10000.0, + max_pos: list[int] | None = None, + use_middle_indices_grid: bool = False, + num_attention_heads: int = 32, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + freq_grid_generator: Callable[[float, int, int, torch.device], torch.Tensor] = generate_freq_grid_pytorch, +) -> tuple[torch.Tensor, torch.Tensor]: + if max_pos is None: + max_pos = [20, 2048, 2048] + + indices = freq_grid_generator(theta, indices_grid.shape[1], dim) + freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) + + if rope_type == LTXRopeType.SPLIT: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads) + else: + # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only + n_elem = 2 * indices_grid.shape[1] + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(out_dtype), sin_freq.to(out_dtype) + + +class Attention(torch.nn.Module): + def __init__( + self, + query_dim: int, + context_dim: int | None = None, + heads: int = 8, + dim_head: int = 64, + norm_eps: float = 1e-6, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + ) -> None: + super().__init__() + self.rope_type = rope_type + + inner_dim = dim_head * heads + context_dim = query_dim if context_dim is None else context_dim + + self.heads = heads + self.dim_head = dim_head + + self.q_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) + self.k_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) + + self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True) + self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True) + self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True) + + self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity()) + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + pe: torch.Tensor | None = None, + k_pe: torch.Tensor | None = None, + ) -> torch.Tensor: + q = self.to_q(x) + context = x if context is None else context + k = self.to_k(context) + v = self.to_v(context) + + q = self.q_norm(q) + k = self.k_norm(k) + + if pe is not None: + q = apply_rotary_emb(q, pe, self.rope_type) + k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type) + + # Reshape for attention_forward using unflatten + q = q.unflatten(-1, (self.heads, self.dim_head)) + k = k.unflatten(-1, (self.heads, self.dim_head)) + v = v.unflatten(-1, (self.heads, self.dim_head)) + + out = attention_forward( + q=q, + k=k, + v=v, + q_pattern="b s n d", + k_pattern="b s n d", + v_pattern="b s n d", + out_pattern="b s n d", + attn_mask=mask + ) + + # Reshape back to original format + out = out.flatten(2, 3) + return self.to_out(out) + + +class PixArtAlphaTextProjection(torch.nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_features: int, hidden_size: int, out_features: int | None = None, act_fn: str = "gelu_tanh"): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = torch.nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + if act_fn == "gelu_tanh": + self.act_1 = torch.nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = torch.nn.SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = torch.nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) + + def forward(self, caption: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@dataclass(frozen=True) +class TransformerArgs: + x: torch.Tensor + context: torch.Tensor + context_mask: torch.Tensor + timesteps: torch.Tensor + embedded_timestep: torch.Tensor + positional_embeddings: torch.Tensor + cross_positional_embeddings: torch.Tensor | None + cross_scale_shift_timestep: torch.Tensor | None + cross_gate_timestep: torch.Tensor | None + enabled: bool + + + +class TransformerArgsPreprocessor: + def __init__( # noqa: PLR0913 + self, + patchify_proj: torch.nn.Linear, + adaln: AdaLayerNormSingle, + caption_projection: PixArtAlphaTextProjection, + inner_dim: int, + max_pos: list[int], + num_attention_heads: int, + use_middle_indices_grid: bool, + timestep_scale_multiplier: int, + double_precision_rope: bool, + positional_embedding_theta: float, + rope_type: LTXRopeType, + ) -> None: + self.patchify_proj = patchify_proj + self.adaln = adaln + self.caption_projection = caption_projection + self.inner_dim = inner_dim + self.max_pos = max_pos + self.num_attention_heads = num_attention_heads + self.use_middle_indices_grid = use_middle_indices_grid + self.timestep_scale_multiplier = timestep_scale_multiplier + self.double_precision_rope = double_precision_rope + self.positional_embedding_theta = positional_embedding_theta + self.rope_type = rope_type + + def _prepare_timestep( + self, timestep: torch.Tensor, batch_size: int, hidden_dtype: torch.dtype + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare timestep embeddings.""" + + timestep = timestep * self.timestep_scale_multiplier + timestep, embedded_timestep = self.adaln( + timestep.flatten(), + hidden_dtype=hidden_dtype, + ) + + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) + return timestep, embedded_timestep + + def _prepare_context( + self, + context: torch.Tensor, + x: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Prepare context for transformer blocks.""" + batch_size = x.shape[0] + context = self.caption_projection(context) + context = context.view(batch_size, -1, x.shape[-1]) + + return context, attention_mask + + def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None: + """Prepare attention mask.""" + if attention_mask is None or torch.is_floating_point(attention_mask): + return attention_mask + + return (attention_mask - 1).to(x_dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + ) * torch.finfo(x_dtype).max + + def _prepare_positional_embeddings( + self, + positions: torch.Tensor, + inner_dim: int, + max_pos: list[int], + use_middle_indices_grid: bool, + num_attention_heads: int, + x_dtype: torch.dtype, + ) -> torch.Tensor: + """Prepare positional embeddings.""" + freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch + pe = precompute_freqs_cis( + positions, + dim=inner_dim, + out_dtype=x_dtype, + theta=self.positional_embedding_theta, + max_pos=max_pos, + use_middle_indices_grid=use_middle_indices_grid, + num_attention_heads=num_attention_heads, + rope_type=self.rope_type, + freq_grid_generator=freq_grid_generator, + ) + return pe + + def prepare( + self, + modality: Modality, + ) -> TransformerArgs: + x = self.patchify_proj(modality.latent) + timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], modality.latent.dtype) + context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask) + attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype) + pe = self._prepare_positional_embeddings( + positions=modality.positions, + inner_dim=self.inner_dim, + max_pos=self.max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.num_attention_heads, + x_dtype=modality.latent.dtype, + ) + return TransformerArgs( + x=x, + context=context, + context_mask=attention_mask, + timesteps=timestep, + embedded_timestep=embedded_timestep, + positional_embeddings=pe, + cross_positional_embeddings=None, + cross_scale_shift_timestep=None, + cross_gate_timestep=None, + enabled=modality.enabled, + ) + + +class MultiModalTransformerArgsPreprocessor: + def __init__( # noqa: PLR0913 + self, + patchify_proj: torch.nn.Linear, + adaln: AdaLayerNormSingle, + caption_projection: PixArtAlphaTextProjection, + cross_scale_shift_adaln: AdaLayerNormSingle, + cross_gate_adaln: AdaLayerNormSingle, + inner_dim: int, + max_pos: list[int], + num_attention_heads: int, + cross_pe_max_pos: int, + use_middle_indices_grid: bool, + audio_cross_attention_dim: int, + timestep_scale_multiplier: int, + double_precision_rope: bool, + positional_embedding_theta: float, + rope_type: LTXRopeType, + av_ca_timestep_scale_multiplier: int, + ) -> None: + self.simple_preprocessor = TransformerArgsPreprocessor( + patchify_proj=patchify_proj, + adaln=adaln, + caption_projection=caption_projection, + inner_dim=inner_dim, + max_pos=max_pos, + num_attention_heads=num_attention_heads, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + double_precision_rope=double_precision_rope, + positional_embedding_theta=positional_embedding_theta, + rope_type=rope_type, + ) + self.cross_scale_shift_adaln = cross_scale_shift_adaln + self.cross_gate_adaln = cross_gate_adaln + self.cross_pe_max_pos = cross_pe_max_pos + self.audio_cross_attention_dim = audio_cross_attention_dim + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + + def prepare( + self, + modality: Modality, + ) -> TransformerArgs: + transformer_args = self.simple_preprocessor.prepare(modality) + cross_pe = self.simple_preprocessor._prepare_positional_embeddings( + positions=modality.positions[:, 0:1, :], + inner_dim=self.audio_cross_attention_dim, + max_pos=[self.cross_pe_max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.simple_preprocessor.num_attention_heads, + x_dtype=modality.latent.dtype, + ) + + cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep( + timestep=modality.timesteps, + timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, + batch_size=transformer_args.x.shape[0], + hidden_dtype=modality.latent.dtype, + ) + + return replace( + transformer_args, + cross_positional_embeddings=cross_pe, + cross_scale_shift_timestep=cross_scale_shift_timestep, + cross_gate_timestep=cross_gate_timestep, + ) + + def _prepare_cross_attention_timestep( + self, + timestep: torch.Tensor, + timestep_scale_multiplier: int, + batch_size: int, + hidden_dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare cross attention timestep embeddings.""" + timestep = timestep * timestep_scale_multiplier + + av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier + + scale_shift_timestep, _ = self.cross_scale_shift_adaln( + timestep.flatten(), + hidden_dtype=hidden_dtype, + ) + scale_shift_timestep = scale_shift_timestep.view(batch_size, -1, scale_shift_timestep.shape[-1]) + gate_noise_timestep, _ = self.cross_gate_adaln( + timestep.flatten() * av_ca_factor, + hidden_dtype=hidden_dtype, + ) + gate_noise_timestep = gate_noise_timestep.view(batch_size, -1, gate_noise_timestep.shape[-1]) + + return scale_shift_timestep, gate_noise_timestep + + +@dataclass +class TransformerConfig: + dim: int + heads: int + d_head: int + context_dim: int + + +class BasicAVTransformerBlock(torch.nn.Module): + def __init__( + self, + idx: int, + video: TransformerConfig | None = None, + audio: TransformerConfig | None = None, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + norm_eps: float = 1e-6, + ): + super().__init__() + + self.idx = idx + if video is not None: + self.attn1 = Attention( + query_dim=video.dim, + heads=video.heads, + dim_head=video.d_head, + context_dim=None, + rope_type=rope_type, + norm_eps=norm_eps, + ) + self.attn2 = Attention( + query_dim=video.dim, + context_dim=video.context_dim, + heads=video.heads, + dim_head=video.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + ) + self.ff = FeedForward(video.dim, dim_out=video.dim) + self.scale_shift_table = torch.nn.Parameter(torch.empty(6, video.dim)) + + if audio is not None: + self.audio_attn1 = Attention( + query_dim=audio.dim, + heads=audio.heads, + dim_head=audio.d_head, + context_dim=None, + rope_type=rope_type, + norm_eps=norm_eps, + ) + self.audio_attn2 = Attention( + query_dim=audio.dim, + context_dim=audio.context_dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + ) + self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim) + self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(6, audio.dim)) + + if audio is not None and video is not None: + # Q: Video, K,V: Audio + self.audio_to_video_attn = Attention( + query_dim=video.dim, + context_dim=audio.dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + ) + + # Q: Audio, K,V: Video + self.video_to_audio_attn = Attention( + query_dim=audio.dim, + context_dim=video.dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + ) + + self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim)) + self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim)) + + self.norm_eps = norm_eps + + def get_ada_values( + self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice + ) -> tuple[torch.Tensor, ...]: + num_ada_params = scale_shift_table.shape[0] + + ada_values = ( + scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype) + + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :] + ).unbind(dim=2) + return ada_values + + def get_av_ca_ada_values( + self, + scale_shift_table: torch.Tensor, + batch_size: int, + scale_shift_timestep: torch.Tensor, + gate_timestep: torch.Tensor, + num_scale_shift_values: int = 4, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + scale_shift_ada_values = self.get_ada_values( + scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, slice(None, None) + ) + gate_ada_values = self.get_ada_values( + scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None) + ) + + scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values] + gate_ada_values = [t.squeeze(2) for t in gate_ada_values] + + return (*scale_shift_chunks, *gate_ada_values) + + def forward( # noqa: PLR0915 + self, + video: TransformerArgs | None, + audio: TransformerArgs | None, + perturbations: BatchedPerturbationConfig | None = None, + ) -> tuple[TransformerArgs | None, TransformerArgs | None]: + batch_size = video.x.shape[0] + if perturbations is None: + perturbations = BatchedPerturbationConfig.empty(batch_size) + + vx = video.x if video is not None else None + ax = audio.x if audio is not None else None + + run_vx = video is not None and video.enabled and vx.numel() > 0 + run_ax = audio is not None and audio.enabled and ax.numel() > 0 + + run_a2v = run_vx and (audio is not None and ax.numel() > 0) + run_v2a = run_ax and (video is not None and vx.numel() > 0) + + if run_vx: + vshift_msa, vscale_msa, vgate_msa = self.get_ada_values( + self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3) + ) + if not perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx): + norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa + v_mask = perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx) + vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa * v_mask + + vx = vx + self.attn2(rms_norm(vx, eps=self.norm_eps), context=video.context, mask=video.context_mask) + + del vshift_msa, vscale_msa, vgate_msa + + if run_ax: + ashift_msa, ascale_msa, agate_msa = self.get_ada_values( + self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3) + ) + + if not perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx): + norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa + a_mask = perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax) + ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa * a_mask + + ax = ax + self.audio_attn2(rms_norm(ax, eps=self.norm_eps), context=audio.context, mask=audio.context_mask) + + del ashift_msa, ascale_msa, agate_msa + + # Audio - Video cross attention. + if run_a2v or run_v2a: + vx_norm3 = rms_norm(vx, eps=self.norm_eps) + ax_norm3 = rms_norm(ax, eps=self.norm_eps) + + ( + scale_ca_audio_hidden_states_a2v, + shift_ca_audio_hidden_states_a2v, + scale_ca_audio_hidden_states_v2a, + shift_ca_audio_hidden_states_v2a, + gate_out_v2a, + ) = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_audio, + ax.shape[0], + audio.cross_scale_shift_timestep, + audio.cross_gate_timestep, + ) + + ( + scale_ca_video_hidden_states_a2v, + shift_ca_video_hidden_states_a2v, + scale_ca_video_hidden_states_v2a, + shift_ca_video_hidden_states_v2a, + gate_out_a2v, + ) = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_video, + vx.shape[0], + video.cross_scale_shift_timestep, + video.cross_gate_timestep, + ) + + if run_a2v: + vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) + shift_ca_video_hidden_states_a2v + ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v + a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx) + vx = vx + ( + self.audio_to_video_attn( + vx_scaled, + context=ax_scaled, + pe=video.cross_positional_embeddings, + k_pe=audio.cross_positional_embeddings, + ) + * gate_out_a2v + * a2v_mask + ) + + if run_v2a: + ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a + vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a + v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax) + ax = ax + ( + self.video_to_audio_attn( + ax_scaled, + context=vx_scaled, + pe=audio.cross_positional_embeddings, + k_pe=video.cross_positional_embeddings, + ) + * gate_out_v2a + * v2a_mask + ) + + del gate_out_a2v, gate_out_v2a + del ( + scale_ca_video_hidden_states_a2v, + shift_ca_video_hidden_states_a2v, + scale_ca_audio_hidden_states_a2v, + shift_ca_audio_hidden_states_a2v, + scale_ca_video_hidden_states_v2a, + shift_ca_video_hidden_states_v2a, + scale_ca_audio_hidden_states_v2a, + shift_ca_audio_hidden_states_v2a, + ) + + if run_vx: + vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values( + self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None) + ) + vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp + vx = vx + self.ff(vx_scaled) * vgate_mlp + + del vshift_mlp, vscale_mlp, vgate_mlp + + if run_ax: + ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values( + self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None) + ) + ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp + ax = ax + self.audio_ff(ax_scaled) * agate_mlp + + del ashift_mlp, ascale_mlp, agate_mlp + + return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None + + +class GELUApprox(torch.nn.Module): + def __init__(self, dim_in: int, dim_out: int) -> None: + super().__init__() + self.proj = torch.nn.Linear(dim_in, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(self.proj(x), approximate="tanh") + + +class FeedForward(torch.nn.Module): + def __init__(self, dim: int, dim_out: int, mult: int = 4) -> None: + super().__init__() + inner_dim = int(dim * mult) + project_in = GELUApprox(dim, inner_dim) + + self.net = torch.nn.Sequential(project_in, torch.nn.Identity(), torch.nn.Linear(inner_dim, dim_out)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class LTXModelType(Enum): + AudioVideo = "ltx av model" + VideoOnly = "ltx video only model" + AudioOnly = "ltx audio only model" + + def is_video_enabled(self) -> bool: + return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly) + + def is_audio_enabled(self) -> bool: + return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly) + + +class LTXModel(torch.nn.Module): + """ + LTX model transformer implementation. + This class implements the transformer blocks for the LTX model. + """ + + def __init__( # noqa: PLR0913 + self, + *, + model_type: LTXModelType = LTXModelType.AudioVideo, + num_attention_heads: int = 32, + attention_head_dim: int = 128, + in_channels: int = 128, + out_channels: int = 128, + num_layers: int = 48, + cross_attention_dim: int = 4096, + norm_eps: float = 1e-06, + caption_channels: int = 3840, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list[int] | None = [20, 2048, 2048], + timestep_scale_multiplier: int = 1000, + use_middle_indices_grid: bool = True, + audio_num_attention_heads: int = 32, + audio_attention_head_dim: int = 64, + audio_in_channels: int = 128, + audio_out_channels: int = 128, + audio_cross_attention_dim: int = 2048, + audio_positional_embedding_max_pos: list[int] | None = [20], + av_ca_timestep_scale_multiplier: int = 1000, + rope_type: LTXRopeType = LTXRopeType.SPLIT, + double_precision_rope: bool = True, + ): + super().__init__() + self._enable_gradient_checkpointing = False + self.use_middle_indices_grid = use_middle_indices_grid + self.rope_type = rope_type + self.double_precision_rope = double_precision_rope + self.timestep_scale_multiplier = timestep_scale_multiplier + self.positional_embedding_theta = positional_embedding_theta + self.model_type = model_type + cross_pe_max_pos = None + if model_type.is_video_enabled(): + if positional_embedding_max_pos is None: + positional_embedding_max_pos = [20, 2048, 2048] + self.positional_embedding_max_pos = positional_embedding_max_pos + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self._init_video( + in_channels=in_channels, + out_channels=out_channels, + caption_channels=caption_channels, + norm_eps=norm_eps, + ) + + if model_type.is_audio_enabled(): + if audio_positional_embedding_max_pos is None: + audio_positional_embedding_max_pos = [20] + self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + self.audio_num_attention_heads = audio_num_attention_heads + self.audio_inner_dim = self.audio_num_attention_heads * audio_attention_head_dim + self._init_audio( + in_channels=audio_in_channels, + out_channels=audio_out_channels, + caption_channels=caption_channels, + norm_eps=norm_eps, + ) + + if model_type.is_video_enabled() and model_type.is_audio_enabled(): + cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]) + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + self.audio_cross_attention_dim = audio_cross_attention_dim + self._init_audio_video(num_scale_shift_values=4) + + self._init_preprocessors(cross_pe_max_pos) + # Initialize transformer blocks + self._init_transformer_blocks( + num_layers=num_layers, + attention_head_dim=attention_head_dim if model_type.is_video_enabled() else 0, + cross_attention_dim=cross_attention_dim, + audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0, + audio_cross_attention_dim=audio_cross_attention_dim, + norm_eps=norm_eps, + ) + + def _init_video( + self, + in_channels: int, + out_channels: int, + caption_channels: int, + norm_eps: float, + ) -> None: + """Initialize video-specific components.""" + # Video input components + self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True) + + self.adaln_single = AdaLayerNormSingle(self.inner_dim) + + # Video caption projection + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, + hidden_size=self.inner_dim, + ) + + # Video output components + self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim)) + self.norm_out = torch.nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=norm_eps) + self.proj_out = torch.nn.Linear(self.inner_dim, out_channels) + + def _init_audio( + self, + in_channels: int, + out_channels: int, + caption_channels: int, + norm_eps: float, + ) -> None: + """Initialize audio-specific components.""" + + # Audio input components + self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True) + + self.audio_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + ) + + # Audio caption projection + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, + hidden_size=self.audio_inner_dim, + ) + + # Audio output components + self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim)) + self.audio_norm_out = torch.nn.LayerNorm(self.audio_inner_dim, elementwise_affine=False, eps=norm_eps) + self.audio_proj_out = torch.nn.Linear(self.audio_inner_dim, out_channels) + + def _init_audio_video( + self, + num_scale_shift_values: int, + ) -> None: + """Initialize audio-video cross-attention components.""" + self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=num_scale_shift_values, + ) + + self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=num_scale_shift_values, + ) + + self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=1, + ) + + self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=1, + ) + + def _init_preprocessors( + self, + cross_pe_max_pos: int | None = None, + ) -> None: + """Initialize preprocessors for LTX.""" + + if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled(): + self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor( + patchify_proj=self.patchify_proj, + adaln=self.adaln_single, + caption_projection=self.caption_projection, + cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single, + cross_gate_adaln=self.av_ca_a2v_gate_adaln_single, + inner_dim=self.inner_dim, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + cross_pe_max_pos=cross_pe_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + audio_cross_attention_dim=self.audio_cross_attention_dim, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier, + ) + self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor( + patchify_proj=self.audio_patchify_proj, + adaln=self.audio_adaln_single, + caption_projection=self.audio_caption_projection, + cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single, + cross_gate_adaln=self.av_ca_v2a_gate_adaln_single, + inner_dim=self.audio_inner_dim, + max_pos=self.audio_positional_embedding_max_pos, + num_attention_heads=self.audio_num_attention_heads, + cross_pe_max_pos=cross_pe_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + audio_cross_attention_dim=self.audio_cross_attention_dim, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier, + ) + elif self.model_type.is_video_enabled(): + self.video_args_preprocessor = TransformerArgsPreprocessor( + patchify_proj=self.patchify_proj, + adaln=self.adaln_single, + caption_projection=self.caption_projection, + inner_dim=self.inner_dim, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + use_middle_indices_grid=self.use_middle_indices_grid, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + ) + elif self.model_type.is_audio_enabled(): + self.audio_args_preprocessor = TransformerArgsPreprocessor( + patchify_proj=self.audio_patchify_proj, + adaln=self.audio_adaln_single, + caption_projection=self.audio_caption_projection, + inner_dim=self.audio_inner_dim, + max_pos=self.audio_positional_embedding_max_pos, + num_attention_heads=self.audio_num_attention_heads, + use_middle_indices_grid=self.use_middle_indices_grid, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + ) + + def _init_transformer_blocks( + self, + num_layers: int, + attention_head_dim: int, + cross_attention_dim: int, + audio_attention_head_dim: int, + audio_cross_attention_dim: int, + norm_eps: float, + ) -> None: + """Initialize transformer blocks for LTX.""" + video_config = ( + TransformerConfig( + dim=self.inner_dim, + heads=self.num_attention_heads, + d_head=attention_head_dim, + context_dim=cross_attention_dim, + ) + if self.model_type.is_video_enabled() + else None + ) + audio_config = ( + TransformerConfig( + dim=self.audio_inner_dim, + heads=self.audio_num_attention_heads, + d_head=audio_attention_head_dim, + context_dim=audio_cross_attention_dim, + ) + if self.model_type.is_audio_enabled() + else None + ) + self.transformer_blocks = torch.nn.ModuleList( + [ + BasicAVTransformerBlock( + idx=idx, + video=video_config, + audio=audio_config, + rope_type=self.rope_type, + norm_eps=norm_eps, + ) + for idx in range(num_layers) + ] + ) + + def set_gradient_checkpointing(self, enable: bool) -> None: + """Enable or disable gradient checkpointing for transformer blocks. + Gradient checkpointing trades compute for memory by recomputing activations + during the backward pass instead of storing them. This can significantly + reduce memory usage at the cost of ~20-30% slower training. + Args: + enable: Whether to enable gradient checkpointing + """ + self._enable_gradient_checkpointing = enable + + def _process_transformer_blocks( + self, + video: TransformerArgs | None, + audio: TransformerArgs | None, + perturbations: BatchedPerturbationConfig, + ) -> tuple[TransformerArgs, TransformerArgs]: + """Process transformer blocks for LTXAV.""" + + # Process transformer blocks + for block in self.transformer_blocks: + if self._enable_gradient_checkpointing and self.training: + # Use gradient checkpointing to save memory during training. + # With use_reentrant=False, we can pass dataclasses directly - + # PyTorch will track all tensor leaves in the computation graph. + video, audio = torch.utils.checkpoint.checkpoint( + block, + video, + audio, + perturbations, + use_reentrant=False, + ) + else: + video, audio = block( + video=video, + audio=audio, + perturbations=perturbations, + ) + + return video, audio + + def _process_output( + self, + scale_shift_table: torch.Tensor, + norm_out: torch.nn.LayerNorm, + proj_out: torch.nn.Linear, + x: torch.Tensor, + embedded_timestep: torch.Tensor, + ) -> torch.Tensor: + """Process output for LTXV.""" + # Apply scale-shift modulation + scale_shift_values = ( + scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + x = norm_out(x) + x = x * (1 + scale) + shift + x = proj_out(x) + return x + + def forward( + self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for LTX models. + Returns: + Processed output tensors + """ + if not self.model_type.is_video_enabled() and video is not None: + raise ValueError("Video is not enabled for this model") + if not self.model_type.is_audio_enabled() and audio is not None: + raise ValueError("Audio is not enabled for this model") + + video_args = self.video_args_preprocessor.prepare(video) if video is not None else None + audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None + # Process transformer blocks + video_out, audio_out = self._process_transformer_blocks( + video=video_args, + audio=audio_args, + perturbations=perturbations, + ) + + # Process output + vx = ( + self._process_output( + self.scale_shift_table, self.norm_out, self.proj_out, video_out.x, video_out.embedded_timestep + ) + if video_out is not None + else None + ) + ax = ( + self._process_output( + self.audio_scale_shift_table, + self.audio_norm_out, + self.audio_proj_out, + audio_out.x, + audio_out.embedded_timestep, + ) + if audio_out is not None + else None + ) + return vx, ax diff --git a/diffsynth/models/ltx2_video_vae.py b/diffsynth/models/ltx2_video_vae.py new file mode 100644 index 0000000..34db4fe --- /dev/null +++ b/diffsynth/models/ltx2_video_vae.py @@ -0,0 +1,1969 @@ +import itertools +import math +from dataclasses import replace, dataclass +from typing import Any, Callable, Iterator, List, NamedTuple, Tuple, Union, Optional +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional as F +from enum import Enum +from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape +from .ltx2_dit import PixArtAlphaCombinedTimestepSizeEmbeddings + + +class NormLayerType(Enum): + GROUP_NORM = "group_norm" + PIXEL_NORM = "pixel_norm" + + +class LogVarianceType(Enum): + PER_CHANNEL = "per_channel" + UNIFORM = "uniform" + CONSTANT = "constant" + NONE = "none" + + +class PaddingModeType(Enum): + ZEROS = "zeros" + REFLECT = "reflect" + REPLICATE = "replicate" + CIRCULAR = "circular" + + +class DualConv3d(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + ) -> None: + super(DualConv3d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.padding_mode = padding_mode + # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if kernel_size == (1, 1, 1): + raise ValueError("kernel_size must be greater than 1. Use make_linear_nd instead.") + if isinstance(stride, int): + stride = (stride, stride, stride) + if isinstance(padding, int): + padding = (padding, padding, padding) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + + # Set parameters for convolutions + self.groups = groups + self.bias = bias + + # Define the size of the channels after the first convolution + intermediate_channels = out_channels if in_channels < out_channels else in_channels + + # Define parameters for the first convolution + self.weight1 = nn.Parameter( + torch.Tensor( + intermediate_channels, + in_channels // groups, + 1, + kernel_size[1], + kernel_size[2], + )) + self.stride1 = (1, stride[1], stride[2]) + self.padding1 = (0, padding[1], padding[2]) + self.dilation1 = (1, dilation[1], dilation[2]) + if bias: + self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) + else: + self.register_parameter("bias1", None) + + # Define parameters for the second convolution + self.weight2 = nn.Parameter(torch.Tensor(out_channels, intermediate_channels // groups, kernel_size[0], 1, 1)) + self.stride2 = (stride[0], 1, 1) + self.padding2 = (padding[0], 0, 0) + self.dilation2 = (dilation[0], 1, 1) + if bias: + self.bias2 = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias2", None) + + # Initialize weights and biases + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.kaiming_uniform_(self.weight1, a=torch.sqrt(5)) + nn.init.kaiming_uniform_(self.weight2, a=torch.sqrt(5)) + if self.bias: + fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) + bound1 = 1 / torch.sqrt(fan_in1) + nn.init.uniform_(self.bias1, -bound1, bound1) + fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) + bound2 = 1 / torch.sqrt(fan_in2) + nn.init.uniform_(self.bias2, -bound2, bound2) + + def forward( + self, + x: torch.Tensor, + use_conv3d: bool = False, + skip_time_conv: bool = False, + ) -> torch.Tensor: + if use_conv3d: + return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) + else: + return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) + + def forward_with_3d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor: + # First convolution + x = F.conv3d( + x, + self.weight1, + self.bias1, + self.stride1, + self.padding1, + self.dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + if skip_time_conv: + return x + + # Second convolution + x = F.conv3d( + x, + self.weight2, + self.bias2, + self.stride2, + self.padding2, + self.dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + + return x + + def forward_with_2d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor: + b, _, _, h, w = x.shape + + # First 2D convolution + x = rearrange(x, "b c d h w -> (b d) c h w") + # Squeeze the depth dimension out of weight1 since it's 1 + weight1 = self.weight1.squeeze(2) + # Select stride, padding, and dilation for the 2D convolution + stride1 = (self.stride1[1], self.stride1[2]) + padding1 = (self.padding1[1], self.padding1[2]) + dilation1 = (self.dilation1[1], self.dilation1[2]) + x = F.conv2d( + x, + weight1, + self.bias1, + stride1, + padding1, + dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + _, _, h, w = x.shape + + if skip_time_conv: + x = rearrange(x, "(b d) c h w -> b c d h w", b=b) + return x + + # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) + + # Reshape weight2 to match the expected dimensions for conv1d + weight2 = self.weight2.squeeze(-1).squeeze(-1) + # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution + stride2 = self.stride2[0] + padding2 = self.padding2[0] + dilation2 = self.dilation2[0] + x = F.conv1d( + x, + weight2, + self.bias2, + stride2, + padding2, + dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) + + return x + + @property + def weight(self) -> torch.Tensor: + return self.weight2 + + +class CausalConv3d(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: Union[int, Tuple[int]] = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + kernel_size = (kernel_size, kernel_size, kernel_size) + self.time_kernel_size = kernel_size[0] + + dilation = (dilation, 1, 1) + + height_pad = kernel_size[1] // 2 + width_pad = kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + padding_mode=spatial_padding_mode.value, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor, causal: bool = True) -> torch.Tensor: + if causal: + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) + x = torch.concatenate((first_frame_pad, x), dim=2) + else: + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)) + last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)) + x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) + x = self.conv(x) + return x + + @property + def weight(self) -> torch.Tensor: + return self.conv.weight + + +def make_conv_nd( # noqa: PLR0913 + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + temporal_padding_mode: PaddingModeType = PaddingModeType.ZEROS, +) -> nn.Module: + if not (spatial_padding_mode == temporal_padding_mode or causal): + raise NotImplementedError("spatial and temporal padding modes must be equal") + if dims == 2: + return nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode.value, + ) + elif dims == 3: + if causal: + return CausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + bias=bias, + spatial_padding_mode=spatial_padding_mode, + ) + return nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode.value, + ) + elif dims == (2, 1): + return DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + padding_mode=spatial_padding_mode.value, + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def make_linear_nd( + dims: int, + in_channels: int, + out_channels: int, + bias: bool = True, +) -> nn.Module: + if dims == 2: + return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) + elif dims in (3, (2, 1)): + return nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def patchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor: + """ + Rearrange spatial dimensions into channels. Divides image into patch_size x patch_size blocks + and moves pixels from each block into separate channels (space-to-depth). + Args: + x: Input tensor (4D or 5D) + patch_size_hw: Spatial patch size for height and width. With patch_size_hw=4, divides HxW into 4x4 blocks. + patch_size_t: Temporal patch size for frames. Default=1 (no temporal patching). + For 5D: (B, C, F, H, W) -> (B, Cx(patch_size_hw^2)x(patch_size_t), F/patch_size_t, H/patch_size_hw, W/patch_size_hw) + Example: (B, 3, 33, 512, 512) with patch_size_hw=4, patch_size_t=1 -> (B, 48, 33, 128, 128) + """ + if patch_size_hw == 1 and patch_size_t == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor: + """ + Rearrange channels back into spatial dimensions. Inverse of patchify - moves pixels from + channels back into patch_size x patch_size blocks (depth-to-space). + Args: + x: Input tensor (4D or 5D) + patch_size_hw: Spatial patch size for height and width. With patch_size_hw=4, expands HxW by 4x. + patch_size_t: Temporal patch size for frames. Default=1 (no temporal expansion). + For 5D: (B, Cx(patch_size_hw^2)x(patch_size_t), F, H, W) -> (B, C, Fxpatch_size_t, Hxpatch_size_hw, Wxpatch_size_hw) + Example: (B, 48, 33, 128, 128) with patch_size_hw=4, patch_size_t=1 -> (B, 3, 33, 512, 512) + """ + if patch_size_hw == 1 and patch_size_t == 1: + return x + + if x.dim() == 4: + x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + + return x + + +class PerChannelStatistics(nn.Module): + """ + Per-channel statistics for normalizing and denormalizing the latent representation. + This statics is computed over the entire dataset and stored in model's checkpoint under VAE state_dict. + """ + + def __init__(self, latent_channels: int = 128): + super().__init__() + self.register_buffer("std-of-means", torch.empty(latent_channels)) + self.register_buffer("mean-of-means", torch.empty(latent_channels)) + self.register_buffer("mean-of-stds", torch.empty(latent_channels)) + self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(latent_channels)) + self.register_buffer("channel", torch.empty(latent_channels)) + + def un_normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view( + 1, -1, 1, 1, 1).to(x) + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view( + 1, -1, 1, 1, 1).to(x) + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.inject_noise = inject_noise + + if norm_layer == NormLayerType.GROUP_NORM: + self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.norm1 = PixelNorm() + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + if norm_layer == NormLayerType.GROUP_NORM: + self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.norm2 = PixelNorm() + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + self.conv_shortcut = (make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels) + if in_channels != out_channels else nn.Identity()) + + # Using GroupNorm with 1 group is equivalent to LayerNorm but works with (B, C, ...) layout + # avoiding the need for dimension rearrangement used in standard nn.LayerNorm + self.norm3 = (nn.GroupNorm(num_groups=1, num_channels=in_channels, eps=eps, affine=True) + if in_channels != out_channels else nn.Identity()) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.zeros(4, in_channels)) + + def _feed_spatial_noise( + self, + hidden_states: torch.Tensor, + per_channel_scale: torch.Tensor, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + spatial_shape = hidden_states.shape[-2:] + device = hidden_states.device + dtype = hidden_states.dtype + + # similar to the "explicit noise inputs" method in style-gan + spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype, generator=generator)[None] + scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] + hidden_states = hidden_states + scaled_noise + + return hidden_states + + def forward( + self, + input_tensor: torch.Tensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + hidden_states = input_tensor + batch_size = hidden_states.shape[0] + + hidden_states = self.norm1(hidden_states) + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + ada_values = self.scale_shift_table[None, ..., None, None, None].to( + device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape( + batch_size, + 4, + -1, + timestep.shape[-3], + timestep.shape[-2], + timestep.shape[-1], + ) + shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) + + hidden_states = hidden_states * (1 + scale1) + shift1 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, + self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype), + generator=generator, + ) + + hidden_states = self.norm2(hidden_states) + + if self.timestep_conditioning: + hidden_states = hidden_states * (1 + scale2) + shift2 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, + self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype), + generator=generator, + ) + + input_tensor = self.norm3(input_tensor) + + batch_size = input_tensor.shape[0] + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + inject_noise (`bool`, *optional*, defaults to `False`): + Whether to inject noise into the hidden states. + timestep_conditioning (`bool`, *optional*, defaults to `False`): + Whether to condition the hidden states on the timestep. + Returns: + `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: NormLayerType = NormLayerType.GROUP_NORM, + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim=in_channels * 4, + size_emb_dim=0) + + self.res_blocks = nn.ModuleList([ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) for _ in range(num_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + timestep_embed = None + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + batch_size = hidden_states.shape[0] + timestep_embed = self.time_embedder( + timestep=timestep.flatten(), + hidden_dtype=hidden_states.dtype, + ) + timestep_embed = timestep_embed.view(batch_size, timestep_embed.shape[-1], 1, 1, 1) + + for resnet in self.res_blocks: + hidden_states = resnet( + hidden_states, + causal=causal, + timestep=timestep_embed, + generator=generator, + ) + + return hidden_states + + +class SpaceToDepthDownsample(nn.Module): + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + stride: Tuple[int, int, int], + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.stride = stride + self.group_size = in_channels * math.prod(stride) // out_channels + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels // math.prod(stride), + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward( + self, + x: torch.Tensor, + causal: bool = True, + ) -> torch.Tensor: + if self.stride[0] == 2: + x = torch.cat([x[:, :, :1, :, :], x], dim=2) # duplicate first frames for padding + + # skip connection + x_in = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) + x_in = x_in.mean(dim=2) + + # conv + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + x = x + x_in + + return x + + +class DepthToSpaceUpsample(nn.Module): + + def __init__( + self, + dims: int | Tuple[int, int], + in_channels: int, + stride: Tuple[int, int, int], + residual: bool = False, + out_channels_reduction_factor: int = 1, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.stride = stride + self.out_channels = math.prod(stride) * in_channels // out_channels_reduction_factor + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + self.residual = residual + self.out_channels_reduction_factor = out_channels_reduction_factor + + def forward( + self, + x: torch.Tensor, + causal: bool = True, + ) -> torch.Tensor: + if self.residual: + # Reshape and duplicate the input to match the output shape + x_in = rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor + x_in = x_in.repeat(1, num_repeat, 1, 1, 1) + if self.stride[0] == 2: + x_in = x_in[:, :, 1:, :, :] + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + if self.stride[0] == 2: + x = x[:, :, 1:, :, :] + if self.residual: + x = x + x_in + return x + + +def compute_trapezoidal_mask_1d( + length: int, + ramp_left: int, + ramp_right: int, + left_starts_from_0: bool = False, +) -> torch.Tensor: + """ + Generate a 1D trapezoidal blending mask with linear ramps. + Args: + length: Output length of the mask. + ramp_left: Fade-in length on the left. + ramp_right: Fade-out length on the right. + left_starts_from_0: Whether the ramp starts from 0 or first non-zero value. + Useful for temporal tiles where the first tile is causal. + Returns: + A 1D tensor of shape `(length,)` with values in [0, 1]. + """ + if length <= 0: + raise ValueError("Mask length must be positive.") + + ramp_left = max(0, min(ramp_left, length)) + ramp_right = max(0, min(ramp_right, length)) + + mask = torch.ones(length) + + if ramp_left > 0: + interval_length = ramp_left + 1 if left_starts_from_0 else ramp_left + 2 + fade_in = torch.linspace(0.0, 1.0, interval_length)[:-1] + if not left_starts_from_0: + fade_in = fade_in[1:] + mask[:ramp_left] *= fade_in + + if ramp_right > 0: + fade_out = torch.linspace(1.0, 0.0, steps=ramp_right + 2)[1:-1] + mask[-ramp_right:] *= fade_out + + return mask.clamp_(0, 1) + + +@dataclass(frozen=True) +class SpatialTilingConfig: + """Configuration for dividing each frame into spatial tiles with optional overlap. + Args: + tile_size_in_pixels (int): Size of each tile in pixels. Must be at least 64 and divisible by 32. + tile_overlap_in_pixels (int, optional): Overlap between tiles in pixels. Must be divisible by 32. Defaults to 0. + """ + + tile_size_in_pixels: int + tile_overlap_in_pixels: int = 0 + + def __post_init__(self) -> None: + if self.tile_size_in_pixels < 64: + raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}") + if self.tile_size_in_pixels % 32 != 0: + raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}") + if self.tile_overlap_in_pixels % 32 != 0: + raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}") + if self.tile_overlap_in_pixels >= self.tile_size_in_pixels: + raise ValueError( + f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}" + ) + + +@dataclass(frozen=True) +class TemporalTilingConfig: + """Configuration for dividing a video into temporal tiles (chunks of frames) with optional overlap. + Args: + tile_size_in_frames (int): Number of frames in each tile. Must be at least 16 and divisible by 8. + tile_overlap_in_frames (int, optional): Number of overlapping frames between consecutive tiles. + Must be divisible by 8. Defaults to 0. + """ + + tile_size_in_frames: int + tile_overlap_in_frames: int = 0 + + def __post_init__(self) -> None: + if self.tile_size_in_frames < 16: + raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}") + if self.tile_size_in_frames % 8 != 0: + raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}") + if self.tile_overlap_in_frames % 8 != 0: + raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}") + if self.tile_overlap_in_frames >= self.tile_size_in_frames: + raise ValueError( + f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}" + ) + + +@dataclass(frozen=True) +class TilingConfig: + """Configuration for splitting video into tiles with optional overlap. + Attributes: + spatial_config: Configuration for splitting spatial dimensions into tiles. + temporal_config: Configuration for splitting temporal dimension into tiles. + """ + + spatial_config: SpatialTilingConfig | None = None + temporal_config: TemporalTilingConfig | None = None + + @classmethod + def default(cls) -> "TilingConfig": + return cls( + spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64), + temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24), + ) + + +@dataclass(frozen=True) +class DimensionIntervals: + """Intervals which a single dimension of the latent space is split into. + Each interval is defined by its start, end, left ramp, and right ramp. + The start and end are the indices of the first and last element (exclusive) in the interval. + Ramps are regions of the interval where the value of the mask tensor is + interpolated between 0 and 1 for blending with neighboring intervals. + The left ramp and right ramp values are the lengths of the left and right ramps. + """ + + starts: List[int] + ends: List[int] + left_ramps: List[int] + right_ramps: List[int] + + +@dataclass(frozen=True) +class LatentIntervals: + """Intervals which the latent tensor of given shape is split into. + Each dimension of the latent space is split into intervals based on the length along said dimension. + """ + + original_shape: torch.Size + dimension_intervals: Tuple[DimensionIntervals, ...] + + +# Operation to split a single dimension of the tensor into intervals based on the length along the dimension. +SplitOperation = Callable[[int], DimensionIntervals] +# Operation to map the intervals in input dimension to slices and masks along a corresponding output dimension. +MappingOperation = Callable[[DimensionIntervals], tuple[list[slice], list[torch.Tensor | None]]] + + +def default_split_operation(length: int) -> DimensionIntervals: + return DimensionIntervals(starts=[0], ends=[length], left_ramps=[0], right_ramps=[0]) + + +DEFAULT_SPLIT_OPERATION: SplitOperation = default_split_operation + + +def default_mapping_operation(_intervals: DimensionIntervals,) -> tuple[list[slice], list[torch.Tensor | None]]: + return [slice(0, None)], [None] + + +DEFAULT_MAPPING_OPERATION: MappingOperation = default_mapping_operation + + +class Tile(NamedTuple): + """ + Represents a single tile. + Attributes: + in_coords: + Tuple of slices specifying where to cut the tile from the INPUT tensor. + out_coords: + Tuple of slices specifying where this tile's OUTPUT should be placed in the reconstructed OUTPUT tensor. + masks_1d: + Per-dimension masks in OUTPUT units. + These are used to create all-dimensional blending mask. + Methods: + blend_mask: + Create a single N-D mask from the per-dimension masks. + """ + + in_coords: Tuple[slice, ...] + out_coords: Tuple[slice, ...] + masks_1d: Tuple[Tuple[torch.Tensor, ...]] + + @property + def blend_mask(self) -> torch.Tensor: + num_dims = len(self.out_coords) + per_dimension_masks: List[torch.Tensor] = [] + + for dim_idx in range(num_dims): + mask_1d = self.masks_1d[dim_idx] + view_shape = [1] * num_dims + if mask_1d is None: + # Broadcast mask along this dimension (length 1). + one = torch.ones(1) + + view_shape[dim_idx] = 1 + per_dimension_masks.append(one.view(*view_shape)) + continue + + # Reshape (L,) -> (1, ..., L, ..., 1) so masks across dimensions broadcast-multiply. + view_shape[dim_idx] = mask_1d.shape[0] + per_dimension_masks.append(mask_1d.view(*view_shape)) + + # Multiply per-dimension masks to form the full N-D mask (separable blending window). + combined_mask = per_dimension_masks[0] + for mask in per_dimension_masks[1:]: + combined_mask = combined_mask * mask + + return combined_mask + + +def create_tiles_from_intervals_and_mappers( + intervals: LatentIntervals, + mappers: List[MappingOperation], +) -> List[Tile]: + full_dim_input_slices = [] + full_dim_output_slices = [] + full_dim_masks_1d = [] + for axis_index in range(len(intervals.original_shape)): + dimension_intervals = intervals.dimension_intervals[axis_index] + starts = dimension_intervals.starts + ends = dimension_intervals.ends + input_slices = [slice(s, e) for s, e in zip(starts, ends, strict=True)] + output_slices, masks_1d = mappers[axis_index](dimension_intervals) + full_dim_input_slices.append(input_slices) + full_dim_output_slices.append(output_slices) + full_dim_masks_1d.append(masks_1d) + + tiles = [] + tile_in_coords = list(itertools.product(*full_dim_input_slices)) + tile_out_coords = list(itertools.product(*full_dim_output_slices)) + tile_mask_1ds = list(itertools.product(*full_dim_masks_1d)) + for in_coord, out_coord, mask_1d in zip(tile_in_coords, tile_out_coords, tile_mask_1ds, strict=True): + tiles.append(Tile( + in_coords=in_coord, + out_coords=out_coord, + masks_1d=mask_1d, + )) + return tiles + + +def create_tiles( + latent_shape: torch.Size, + splitters: List[SplitOperation], + mappers: List[MappingOperation], +) -> List[Tile]: + if len(splitters) != len(latent_shape): + raise ValueError(f"Number of splitters must be equal to number of dimensions in latent shape, " + f"got {len(splitters)} and {len(latent_shape)}") + if len(mappers) != len(latent_shape): + raise ValueError(f"Number of mappers must be equal to number of dimensions in latent shape, " + f"got {len(mappers)} and {len(latent_shape)}") + intervals = [splitter(length) for splitter, length in zip(splitters, latent_shape, strict=True)] + latent_intervals = LatentIntervals(original_shape=latent_shape, dimension_intervals=tuple(intervals)) + return create_tiles_from_intervals_and_mappers(latent_intervals, mappers) + + +def _make_encoder_block( + block_name: str, + block_config: dict[str, Any], + in_channels: int, + convolution_dimensions: int, + norm_layer: NormLayerType, + norm_num_groups: int, + spatial_padding_mode: PaddingModeType, +) -> Tuple[nn.Module, int]: + out_channels = in_channels + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + out_channels = in_channels * block_config.get("multiplier", 2) + block = ResnetBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 1, 1), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(1, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_x_y": + out_channels = in_channels * block_config.get("multiplier", 2) + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown block: {block_name}") + + return block, out_channels + + +class LTX2VideoEncoder(nn.Module): + _DEFAULT_NORM_NUM_GROUPS = 32 + """ + Variational Autoencoder Encoder. Encodes video frames into a latent representation. + The encoder compresses the input video through a series of downsampling operations controlled by + patch_size and encoder_blocks. The output is a normalized latent tensor with shape (B, 128, F', H', W'). + Compression Behavior: + The total compression is determined by: + 1. Initial spatial compression via patchify: H -> H/4, W -> W/4 (patch_size=4) + 2. Sequential compression through encoder_blocks based on their stride patterns + Compression blocks apply 2x compression in specified dimensions: + - "compress_time" / "compress_time_res": temporal only + - "compress_space" / "compress_space_res": spatial only (H and W) + - "compress_all" / "compress_all_res": all dimensions (F, H, W) + - "res_x" / "res_x_y": no compression + Standard LTX Video configuration: + - patch_size=4 + - encoder_blocks: 1x compress_space_res, 1x compress_time_res, 2x compress_all_res + - Final dimensions: F' = 1 + (F-1)/8, H' = H/32, W' = W/32 + - Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16) + - Note: Input must have 1 + 8*k frames (e.g., 1, 9, 17, 25, 33...) + Args: + convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D). + in_channels: The number of input channels. For RGB images, this is 3. + out_channels: The number of output channels (latent channels). For latent channels, this is 128. + encoder_blocks: The list of blocks to construct the encoder. Each block is a tuple of (block_name, params) + where params is either an int (num_layers) or a dict with configuration. + patch_size: The patch size for initial spatial compression. Should be a power of 2. + norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var: The log variance mode. Can be either `per_channel`, `uniform`, `constant` or `none`. + """ + + def __init__( + self, + convolution_dimensions: int = 3, + in_channels: int = 3, + out_channels: int = 128, + patch_size: int = 4, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + latent_log_var: LogVarianceType = LogVarianceType.UNIFORM, + encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + encoder_blocks = [['res_x', { + 'num_layers': 4 + }], ['compress_space_res', { + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 6 + }], ['compress_time_res', { + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 6 + }], ['compress_all_res', { + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 2 + }], ['compress_all_res', { + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 2 + }]] + self.patch_size = patch_size + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS + + # Per-channel statistics for normalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels) + + in_channels = in_channels * patch_size**2 + feature_channels = out_channels + + self.conv_in = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=feature_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + self.down_blocks = nn.ModuleList([]) + + for block_name, block_params in encoder_blocks: + # Convert int to dict format for uniform handling + block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params + + block, feature_channels = _make_encoder_block( + block_name=block_name, + block_config=block_config, + in_channels=feature_channels, + convolution_dimensions=convolution_dimensions, + norm_layer=norm_layer, + norm_num_groups=self._norm_num_groups, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + self.down_blocks.append(block) + + # out + if norm_layer == NormLayerType.GROUP_NORM: + self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == LogVarianceType.PER_CHANNEL: + conv_out_channels *= 2 + elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}: + conv_out_channels += 1 + elif latent_log_var != LogVarianceType.NONE: + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + + self.conv_out = make_conv_nd( + dims=convolution_dimensions, + in_channels=feature_channels, + out_channels=conv_out_channels, + kernel_size=3, + padding=1, + causal=True, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + r""" + Encode video frames into normalized latent representation. + Args: + sample: Input video (B, C, F, H, W). F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...). + Returns: + Normalized latent means (B, 128, F', H', W') where F' = 1+(F-1)/8, H' = H/32, W' = W/32. + Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16). + """ + # Validate frame count + frames_count = sample.shape[2] + if ((frames_count - 1) % 8) != 0: + raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames " + "(e.g., 1, 9, 17, ...). Please check your input.") + + # Initial spatial compression: trade spatial resolution for channel depth + # This reduces H,W by patch_size and increases channels, making convolutions more efficient + # Example: (B, 3, F, 512, 512) -> (B, 48, F, 128, 128) with patch_size=4 + sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + sample = self.conv_in(sample) + + for down_block in self.down_blocks: + sample = down_block(sample) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == LogVarianceType.UNIFORM: + # Uniform Variance: model outputs N means and 1 shared log-variance channel. + # We need to expand the single logvar to match the number of means channels + # to create a format compatible with PER_CHANNEL (means + logvar, each with N channels). + # Sample shape: (B, N+1, ...) where N = latent_channels (e.g., 128 means + 1 logvar = 129) + # Target shape: (B, 2*N, ...) where first N are means, last N are logvar + + if sample.shape[1] < 2: + raise ValueError(f"Invalid channel count for UNIFORM mode: expected at least 2 channels " + f"(N means + 1 logvar), got {sample.shape[1]}") + + # Extract means (first N channels) and logvar (last 1 channel) + means = sample[:, :-1, ...] # (B, N, ...) + logvar = sample[:, -1:, ...] # (B, 1, ...) + + # Repeat logvar N times to match means channels + # Use expand/repeat pattern that works for both 4D and 5D tensors + num_channels = means.shape[1] + repeat_shape = [1, num_channels] + [1] * (sample.ndim - 2) + repeated_logvar = logvar.repeat(*repeat_shape) # (B, N, ...) + + # Concatenate to create (B, 2*N, ...) format: [means, repeated_logvar] + sample = torch.cat([means, repeated_logvar], dim=1) + elif self.latent_log_var == LogVarianceType.CONSTANT: + sample = sample[:, :-1, ...] + approx_ln_0 = -30 # this is the minimal clamp value in DiagonalGaussianDistribution objects + sample = torch.cat( + [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], + dim=1, + ) + + # Split into means and logvar, then normalize means + means, _ = torch.chunk(sample, 2, dim=1) + return self.per_channel_statistics.normalize(means) + + +def _make_decoder_block( + block_name: str, + block_config: dict[str, Any], + in_channels: int, + convolution_dimensions: int, + norm_layer: NormLayerType, + timestep_conditioning: bool, + norm_num_groups: int, + spatial_padding_mode: PaddingModeType, +) -> Tuple[nn.Module, int]: + out_channels = in_channels + if block_name == "res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "attn_res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + attention_head_dim=block_config["attention_head_dim"], + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + out_channels = in_channels // block_config.get("multiplier", 2) + block = ResnetBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=False, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + out_channels = in_channels // block_config.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(2, 2, 2), + residual=block_config.get("residual", False), + out_channels_reduction_factor=block_config.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown layer: {block_name}") + + return block, out_channels + + +class LTX2VideoDecoder(nn.Module): + _DEFAULT_NORM_NUM_GROUPS = 32 + """ + Variational Autoencoder Decoder. Decodes latent representation into video frames. + The decoder upsamples latents through a series of upsampling operations (inverse of encoder). + Output dimensions: F = 8x(F'-1) + 1, H = 32xH', W = 32xW' for standard LTX Video configuration. + Upsampling blocks expand dimensions by 2x in specified dimensions: + - "compress_time": temporal only + - "compress_space": spatial only (H and W) + - "compress_all": all dimensions (F, H, W) + - "res_x" / "res_x_y" / "attn_res_x": no upsampling + Causal Mode: + causal=False (standard): Symmetric padding, allows future frame dependencies. + causal=True: Causal padding, each frame depends only on past/current frames. + First frame removed after temporal upsampling in both modes. Output shape unchanged. + Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512) for both modes. + Args: + convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D). + in_channels: The number of input channels (latent channels). Default is 128. + out_channels: The number of output channels. For RGB images, this is 3. + decoder_blocks: The list of blocks to construct the decoder. Each block is a tuple of (block_name, params) + where params is either an int (num_layers) or a dict with configuration. + patch_size: Final spatial expansion factor. For standard LTX Video, use 4 for 4x spatial expansion: + H -> Hx4, W -> Wx4. Should be a power of 2. + norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + causal: Whether to use causal convolutions. For standard LTX Video, use False for symmetric padding. + When True, uses causal padding (past/current frames only). + timestep_conditioning: Whether to condition the decoder on timestep for denoising. + """ + + def __init__( + self, + convolution_dimensions: int = 3, + in_channels: int = 128, + out_channels: int = 3, + decoder_blocks: List[Tuple[str, int | dict]] = [], # noqa: B006 + patch_size: int = 4, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + causal: bool = False, + timestep_conditioning: bool = False, + decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT, + ): + super().__init__() + + # Spatiotemporal downscaling between decoded video space and VAE latents. + # According to the LTXV paper, the standard configuration downsamples + # video inputs by a factor of 8 in the temporal dimension and 32 in + # each spatial dimension (height and width). This parameter determines how + # many video frames and pixels correspond to a single latent cell. + decoder_blocks = [['res_x', { + 'num_layers': 5, + 'inject_noise': False + }], ['compress_all', { + 'residual': True, + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 5, + 'inject_noise': False + }], ['compress_all', { + 'residual': True, + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 5, + 'inject_noise': False + }], ['compress_all', { + 'residual': True, + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 5, + 'inject_noise': False + }]] + self.video_downscale_factors = SpatioTemporalScaleFactors( + time=8, + width=32, + height=32, + ) + + self.patch_size = patch_size + out_channels = out_channels * patch_size**2 + self.causal = causal + self.timestep_conditioning = timestep_conditioning + self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS + + # Per-channel statistics for denormalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels) + + # Noise and timestep parameters for decoder conditioning + self.decode_noise_scale = 0.025 + self.decode_timestep = 0.05 + + # Compute initial feature_channels by going through blocks in reverse + # This determines the channel width at the start of the decoder + feature_channels = in_channels + for block_name, block_params in list(reversed(decoder_blocks)): + block_config = block_params if isinstance(block_params, dict) else {} + if block_name == "res_x_y": + feature_channels = feature_channels * block_config.get("multiplier", 2) + if block_name == "compress_all": + feature_channels = feature_channels * block_config.get("multiplier", 1) + + self.conv_in = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=feature_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + self.up_blocks = nn.ModuleList([]) + + for block_name, block_params in list(reversed(decoder_blocks)): + # Convert int to dict format for uniform handling + block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params + + block, feature_channels = _make_decoder_block( + block_name=block_name, + block_config=block_config, + in_channels=feature_channels, + convolution_dimensions=convolution_dimensions, + norm_layer=norm_layer, + timestep_conditioning=timestep_conditioning, + norm_num_groups=self._norm_num_groups, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + self.up_blocks.append(block) + + if norm_layer == NormLayerType.GROUP_NORM: + self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims=convolution_dimensions, + in_channels=feature_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + causal=True, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0)) + self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim=feature_channels * 2, + size_emb_dim=0) + self.last_scale_shift_table = nn.Parameter(torch.empty(2, feature_channels)) + + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + r""" + Decode latent representation into video frames. + Args: + sample: Latent tensor (B, 128, F', H', W'). + timestep: Timestep for conditioning (if timestep_conditioning=True). Uses default 0.05 if None. + generator: Random generator for deterministic noise injection (if inject_noise=True in blocks). + Returns: + Decoded video (B, 3, F, H, W) where F = 8x(F'-1) + 1, H = 32xH', W = 32xW'. + Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512). + Note: First frame is removed after temporal upsampling regardless of causal mode. + When causal=False, allows future frame dependencies in convolutions but maintains same output shape. + """ + batch_size = sample.shape[0] + + # Add noise if timestep conditioning is enabled + if self.timestep_conditioning: + noise = (torch.randn( + sample.size(), + generator=generator, + dtype=sample.dtype, + device=sample.device, + ) * self.decode_noise_scale) + + sample = noise + (1.0 - self.decode_noise_scale) * sample + + # Denormalize latents + sample = self.per_channel_statistics.un_normalize(sample) + + # Use default decode_timestep if timestep not provided + if timestep is None and self.timestep_conditioning: + timestep = torch.full((batch_size,), self.decode_timestep, device=sample.device, dtype=sample.dtype) + + sample = self.conv_in(sample, causal=self.causal) + + scaled_timestep = None + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + scaled_timestep = timestep * self.timestep_scale_multiplier.to(sample) + + for up_block in self.up_blocks: + if isinstance(up_block, UNetMidBlock3D): + block_kwargs = { + "causal": self.causal, + "timestep": scaled_timestep if self.timestep_conditioning else None, + "generator": generator, + } + sample = up_block(sample, **block_kwargs) + elif isinstance(up_block, ResnetBlock3D): + sample = up_block(sample, causal=self.causal, generator=generator) + else: + sample = up_block(sample, causal=self.causal) + + sample = self.conv_norm_out(sample) + + if self.timestep_conditioning: + embedded_timestep = self.last_time_embedder( + timestep=scaled_timestep.flatten(), + hidden_dtype=sample.dtype, + ) + embedded_timestep = embedded_timestep.view(batch_size, embedded_timestep.shape[-1], 1, 1, 1) + ada_values = self.last_scale_shift_table[None, ..., None, None, None].to( + device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape( + batch_size, + 2, + -1, + embedded_timestep.shape[-3], + embedded_timestep.shape[-2], + embedded_timestep.shape[-1], + ) + shift, scale = ada_values.unbind(dim=1) + sample = sample * (1 + scale) + shift + + sample = self.conv_act(sample) + sample = self.conv_out(sample, causal=self.causal) + + # Final spatial expansion: reverse the initial patchify from encoder + # Moves pixels from channels back to spatial dimensions + # Example: (B, 48, F, 128, 128) -> (B, 3, F, 512, 512) with patch_size=4 + sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + + return sample + + def _prepare_tiles( + self, + latent: torch.Tensor, + tiling_config: TilingConfig | None = None, + ) -> List[Tile]: + splitters = [DEFAULT_SPLIT_OPERATION] * len(latent.shape) + mappers = [DEFAULT_MAPPING_OPERATION] * len(latent.shape) + if tiling_config is not None and tiling_config.spatial_config is not None: + cfg = tiling_config.spatial_config + long_side = max(latent.shape[3], latent.shape[4]) + + def enable_on_axis(axis_idx: int, factor: int) -> None: + size = cfg.tile_size_in_pixels // factor + overlap = cfg.tile_overlap_in_pixels // factor + axis_length = latent.shape[axis_idx] + lower_threshold = max(2, overlap + 1) + tile_size = max(lower_threshold, round(size * axis_length / long_side)) + splitters[axis_idx] = split_in_spatial(tile_size, overlap) + mappers[axis_idx] = to_mapping_operation(map_spatial_slice, factor) + + enable_on_axis(3, self.video_downscale_factors.height) + enable_on_axis(4, self.video_downscale_factors.width) + + if tiling_config is not None and tiling_config.temporal_config is not None: + cfg = tiling_config.temporal_config + tile_size = cfg.tile_size_in_frames // self.video_downscale_factors.time + overlap = cfg.tile_overlap_in_frames // self.video_downscale_factors.time + splitters[2] = split_in_temporal(tile_size, overlap) + mappers[2] = to_mapping_operation(map_temporal_slice, self.video_downscale_factors.time) + + return create_tiles(latent.shape, splitters, mappers) + + def tiled_decode( + self, + latent: torch.Tensor, + tiling_config: TilingConfig | None = None, + timestep: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ) -> Iterator[torch.Tensor]: + """ + Decode a latent tensor into video frames using tiled processing. + Splits the latent tensor into tiles, decodes each tile individually, + and yields video chunks as they become available. + Args: + latent: Input latent tensor (B, C, F', H', W'). + tiling_config: Tiling configuration for the latent tensor. + timestep: Optional timestep for decoder conditioning. + generator: Optional random generator for deterministic decoding. + Yields: + Video chunks (B, C, T, H, W) by temporal slices; + """ + + # Calculate full video shape from latent shape to get spatial dimensions + full_video_shape = VideoLatentShape.from_torch_shape(latent.shape).upscale(self.video_downscale_factors) + tiles = self._prepare_tiles(latent, tiling_config) + + temporal_groups = self._group_tiles_by_temporal_slice(tiles) + + # State for temporal overlap handling + previous_chunk = None + previous_weights = None + previous_temporal_slice = None + + for temporal_group_tiles in temporal_groups: + curr_temporal_slice = temporal_group_tiles[0].out_coords[2] + + # Calculate the shape of the temporal buffer for this group of tiles. + # The temporal length depends on whether this is the first tile (starts at 0) or not. + # - First tile: (frames - 1) * scale + 1 + # - Subsequent tiles: frames * scale + # This logic is handled by TemporalAxisMapping and reflected in out_coords. + temporal_tile_buffer_shape = full_video_shape._replace(frames=curr_temporal_slice.stop - + curr_temporal_slice.start,) + + buffer = torch.zeros( + temporal_tile_buffer_shape.to_torch_shape(), + device=latent.device, + dtype=latent.dtype, + ) + + curr_weights = self._accumulate_temporal_group_into_buffer( + group_tiles=temporal_group_tiles, + buffer=buffer, + latent=latent, + timestep=timestep, + generator=generator, + ) + + # Blend with previous temporal chunk if it exists + if previous_chunk is not None: + # Check if current temporal slice overlaps with previous temporal slice + if previous_temporal_slice.stop > curr_temporal_slice.start: + overlap_len = previous_temporal_slice.stop - curr_temporal_slice.start + temporal_overlap_slice = slice(curr_temporal_slice.start - previous_temporal_slice.start, None) + + # The overlap is already masked before it reaches this step. Each tile is accumulated into buffer + # with its trapezoidal mask, and curr_weights accumulates the same mask. In the overlap blend we add + # the masked values (buffer[...]) and the corresponding weights (curr_weights[...]) into the + # previous buffers, then later normalize by weights. + previous_chunk[:, :, temporal_overlap_slice, :, :] += buffer[:, :, slice(0, overlap_len), :, :] + previous_weights[:, :, temporal_overlap_slice, :, :] += curr_weights[:, :, + slice(0, overlap_len), :, :] + + buffer[:, :, slice(0, overlap_len), :, :] = previous_chunk[:, :, temporal_overlap_slice, :, :] + curr_weights[:, :, slice(0, overlap_len), :, :] = previous_weights[:, :, + temporal_overlap_slice, :, :] + + # Yield the non-overlapping part of the previous chunk + previous_weights = previous_weights.clamp(min=1e-8) + yield_len = curr_temporal_slice.start - previous_temporal_slice.start + yield (previous_chunk / previous_weights)[:, :, :yield_len, :, :] + + # Update state for next iteration + previous_chunk = buffer + previous_weights = curr_weights + previous_temporal_slice = curr_temporal_slice + + # Yield any remaining chunk + if previous_chunk is not None: + previous_weights = previous_weights.clamp(min=1e-8) + yield previous_chunk / previous_weights + + def _group_tiles_by_temporal_slice(self, tiles: List[Tile]) -> List[List[Tile]]: + """Group tiles by their temporal output slice.""" + if not tiles: + return [] + + groups = [] + current_slice = tiles[0].out_coords[2] + current_group = [] + + for tile in tiles: + tile_slice = tile.out_coords[2] + if tile_slice == current_slice: + current_group.append(tile) + else: + groups.append(current_group) + current_slice = tile_slice + current_group = [tile] + + # Add the final group + if current_group: + groups.append(current_group) + + return groups + + def _accumulate_temporal_group_into_buffer( + self, + group_tiles: List[Tile], + buffer: torch.Tensor, + latent: torch.Tensor, + timestep: torch.Tensor | None, + generator: torch.Generator | None, + ) -> torch.Tensor: + """ + Decode and accumulate all tiles of a temporal group into a local buffer. + The buffer is local to the group and always starts at time 0; temporal coordinates + are rebased by subtracting temporal_slice.start. + """ + temporal_slice = group_tiles[0].out_coords[2] + + weights = torch.zeros_like(buffer) + + for tile in group_tiles: + decoded_tile = self.forward(latent[tile.in_coords], timestep, generator) + mask = tile.blend_mask.to(device=buffer.device, dtype=buffer.dtype) + temporal_offset = tile.out_coords[2].start - temporal_slice.start + # Use the tile's output coordinate length, not the decoded tile's length, + # as the decoder may produce a different number of frames than expected + expected_temporal_len = tile.out_coords[2].stop - tile.out_coords[2].start + decoded_temporal_len = decoded_tile.shape[2] + + # Ensure we don't exceed the buffer or decoded tile bounds + actual_temporal_len = min(expected_temporal_len, decoded_temporal_len, buffer.shape[2] - temporal_offset) + + chunk_coords = ( + slice(None), # batch + slice(None), # channels + slice(temporal_offset, temporal_offset + actual_temporal_len), + tile.out_coords[3], # height + tile.out_coords[4], # width + ) + + # Slice decoded_tile and mask to match the actual length we're writing + decoded_slice = decoded_tile[:, :, :actual_temporal_len, :, :] + mask_slice = mask[:, :, :actual_temporal_len, :, :] if mask.shape[2] > 1 else mask + + buffer[chunk_coords] += decoded_slice * mask_slice + weights[chunk_coords] += mask_slice + + return weights + + +def decode_video( + latent: torch.Tensor, + video_decoder: LTX2VideoDecoder, + tiling_config: TilingConfig | None = None, + generator: torch.Generator | None = None, +) -> Iterator[torch.Tensor]: + """ + Decode a video latent tensor with the given decoder. + Args: + latent: Tensor [c, f, h, w] + video_decoder: Decoder module. + tiling_config: Optional tiling settings. + generator: Optional random generator for deterministic decoding. + Yields: + Decoded chunk [f, h, w, c], uint8 in [0, 255]. + """ + + def convert_to_uint8(frames: torch.Tensor) -> torch.Tensor: + frames = (((frames + 1.0) / 2.0).clamp(0.0, 1.0) * 255.0).to(torch.uint8) + frames = rearrange(frames[0], "c f h w -> f h w c") + return frames + + if tiling_config is not None: + for frames in video_decoder.tiled_decode(latent, tiling_config, generator=generator): + yield convert_to_uint8(frames) + else: + decoded_video = video_decoder(latent, generator=generator) + yield convert_to_uint8(decoded_video) + + +def get_video_chunks_number(num_frames: int, tiling_config: TilingConfig | None = None) -> int: + """ + Get the number of video chunks for a given number of frames and tiling configuration. + Args: + num_frames: Number of frames in the video. + tiling_config: Tiling configuration. + Returns: + Number of video chunks. + """ + if not tiling_config or not tiling_config.temporal_config: + return 1 + cfg = tiling_config.temporal_config + frame_stride = cfg.tile_size_in_frames - cfg.tile_overlap_in_frames + return (num_frames - 1 + frame_stride - 1) // frame_stride + + +def split_in_spatial(size: int, overlap: int) -> SplitOperation: + + def split(dimension_size: int) -> DimensionIntervals: + if dimension_size <= size: + return DEFAULT_SPLIT_OPERATION(dimension_size) + amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap) + starts = [i * (size - overlap) for i in range(amount)] + ends = [start + size for start in starts] + ends[-1] = dimension_size + left_ramps = [0] + [overlap] * (amount - 1) + right_ramps = [overlap] * (amount - 1) + [0] + return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps) + + return split + + +def split_in_temporal(size: int, overlap: int) -> SplitOperation: + non_causal_split = split_in_spatial(size, overlap) + + def split(dimension_size: int) -> DimensionIntervals: + if dimension_size <= size: + return DEFAULT_SPLIT_OPERATION(dimension_size) + intervals = non_causal_split(dimension_size) + starts = intervals.starts + starts[1:] = [s - 1 for s in starts[1:]] + left_ramps = intervals.left_ramps + left_ramps[1:] = [r + 1 for r in left_ramps[1:]] + return replace(intervals, starts=starts, left_ramps=left_ramps) + + return split + + +def to_mapping_operation( + map_func: Callable[[int, int, int, int, int], Tuple[slice, torch.Tensor]], + scale: int, +) -> MappingOperation: + + def map_op(intervals: DimensionIntervals) -> tuple[list[slice], list[torch.Tensor | None]]: + output_slices: list[slice] = [] + masks_1d: list[torch.Tensor | None] = [] + number_of_slices = len(intervals.starts) + for i in range(number_of_slices): + start = intervals.starts[i] + end = intervals.ends[i] + left_ramp = intervals.left_ramps[i] + right_ramp = intervals.right_ramps[i] + output_slice, mask_1d = map_func(start, end, left_ramp, right_ramp, scale) + output_slices.append(output_slice) + masks_1d.append(mask_1d) + return output_slices, masks_1d + + return map_op + + +def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, torch.Tensor]: + start = begin * scale + stop = 1 + (end - 1) * scale + left_ramp = 1 + (left_ramp - 1) * scale + right_ramp = right_ramp * scale + + return slice(start, stop), compute_trapezoidal_mask_1d(stop - start, left_ramp, right_ramp, True) + + +def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, torch.Tensor]: + start = begin * scale + stop = end * scale + left_ramp = left_ramp * scale + right_ramp = right_ramp * scale + + return slice(start, stop), compute_trapezoidal_mask_1d(stop - start, left_ramp, right_ramp, False) diff --git a/diffsynth/utils/state_dict_converters/ltx2_dit.py b/diffsynth/utils/state_dict_converters/ltx2_dit.py new file mode 100644 index 0000000..baffb9a --- /dev/null +++ b/diffsynth/utils/state_dict_converters/ltx2_dit.py @@ -0,0 +1,9 @@ +def LTXModelStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("model.diffusion_model."): + new_name = name.replace("model.diffusion_model.", "") + if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."): + continue + state_dict_[new_name] = state_dict[name] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/ltx2_video_vae.py b/diffsynth/utils/state_dict_converters/ltx2_video_vae.py new file mode 100644 index 0000000..132897d --- /dev/null +++ b/diffsynth/utils/state_dict_converters/ltx2_video_vae.py @@ -0,0 +1,22 @@ +def LTX2VideoEncoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("vae.encoder."): + new_name = name.replace("vae.encoder.", "") + state_dict_[new_name] = state_dict[name] + elif name.startswith("vae.per_channel_statistics."): + new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.") + state_dict_[new_name] = state_dict[name] + return state_dict_ + + +def LTX2VideoDecoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("vae.decoder."): + new_name = name.replace("vae.decoder.", "") + state_dict_[new_name] = state_dict[name] + elif name.startswith("vae.per_channel_statistics."): + new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.") + state_dict_[new_name] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/test/load_model.py b/diffsynth/utils/test/load_model.py new file mode 100644 index 0000000..f21fa60 --- /dev/null +++ b/diffsynth/utils/test/load_model.py @@ -0,0 +1,22 @@ +import torch +from diffsynth.models.model_loader import ModelPool +from diffsynth.core.loader import ModelConfig + + +def test_model_loading(model_name, + model_config: ModelConfig, + vram_limit: float = None, + device="cpu", + torch_dtype=torch.bfloat16): + model_pool = ModelPool() + model_config.download_if_necessary() + vram_config = model_config.vram_config() + vram_config["computation_dtype"] = torch_dtype + vram_config["computation_device"] = device + model_pool.auto_load_model( + model_config.path, + vram_config=vram_config, + vram_limit=vram_limit, + clear_parameters=model_config.clear_parameters, + ) + return model_pool.fetch_model(model_name)