# Source file: DiffSynth-Studio/diffsynth/models/ltx2_video_vae.py
# Snapshot: 2026-01-29 16:30:15 +08:00 — 2303 lines, 90 KiB, Python.
import itertools
import math
import einops
from dataclasses import replace, dataclass
from typing import Any, Callable, Iterator, List, NamedTuple, Tuple, Union, Optional
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from enum import Enum
from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape, Patchifier, AudioLatentShape
from .ltx2_dit import PixArtAlphaCombinedTimestepSizeEmbeddings
VAE_SPATIAL_FACTOR = 32
VAE_TEMPORAL_FACTOR = 8
class VideoLatentPatchifier(Patchifier):
    """Patchifier for video latents that tiles only the spatial axes.

    The temporal patch size is fixed to 1, so tokens are formed from
    `patch_size x patch_size` spatial blocks of every latent frame.
    """

    def __init__(self, patch_size: int):
        # (temporal, height, width) patch sizes; time is never patched.
        self._patch_size = (1, patch_size, patch_size)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        return self._patch_size

    def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
        """Number of tokens a latent of `tgt_shape` yields after patchification."""
        grid_numel = math.prod(tgt_shape.to_torch_shape()[2:])
        return grid_numel // math.prod(self._patch_size)

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """Fold each patch into the channel axis and flatten (f, h, w) into a token axis."""
        time_patch, height_patch, width_patch = self._patch_size
        return einops.rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=time_patch,
            p2=height_patch,
            p3=width_patch,
        )

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: VideoLatentShape,
    ) -> torch.Tensor:
        """Inverse of `patchify`: restore the (b, c, f, h, w) latent layout."""
        assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier"
        grid_frames = output_shape.frames // self._patch_size[0]
        grid_height = output_shape.height // self._patch_size[1]
        grid_width = output_shape.width // self._patch_size[2]
        return einops.rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=grid_frames,
            h=grid_height,
            w=grid_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Compute per-dimension `[start, end)` bounds for every patch `patchify` emits.

        Bounds are expressed in the original latent grid coordinates
        (frame/time, height, width). The result is shaped
        `[batch_size, 3, num_patches, 2]`:
        - axis 1 (size 3) enumerates the (frame/time, height, width) dimensions
        - axis 3 (size 2) stores the `[start, end)` indices within each dimension

        Args:
            output_shape: Video grid description containing frames, height, and width.
            device: Device on which the coordinate tensors are allocated.
        """
        if not isinstance(output_shape, VideoLatentShape):
            raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")
        num_frames = output_shape.frames
        num_rows = output_shape.height
        num_cols = output_shape.width
        batch = output_shape.batch
        # Guard against degenerate grids before building coordinates.
        assert num_frames > 0, f"frames must be positive, got {num_frames}"
        assert num_rows > 0, f"height must be positive, got {num_rows}"
        assert num_cols > 0, f"width must be positive, got {num_cols}"
        assert batch > 0, f"batch_size must be positive, got {batch}"
        # One arange per dimension gives the start coordinate of every patch;
        # indexing='ij' keeps the (frame, height, width) ordering.
        start_grid = torch.meshgrid(
            torch.arange(start=0, end=num_frames, step=self._patch_size[0], device=device),
            torch.arange(start=0, end=num_rows, step=self._patch_size[1], device=device),
            torch.arange(start=0, end=num_cols, step=self._patch_size[2], device=device),
            indexing="ij",
        )
        # (3, grid_f, grid_h, grid_w): the start coordinate of each patch per dimension.
        starts = torch.stack(start_grid, dim=0)
        # Per-dimension patch extents, shaped (3, 1, 1, 1) so they broadcast over the grid.
        extents = torch.tensor(
            self._patch_size,
            device=starts.device,
            dtype=starts.dtype,
        ).view(3, 1, 1, 1)
        ends = starts + extents
        # (3, grid_f, grid_h, grid_w, 2): [start, end) packed into the last axis.
        bounds = torch.stack((starts, ends), dim=-1)
        # Broadcast over the batch and flatten the patch grid into one token sequence.
        return einops.repeat(
            bounds,
            "c f h w bounds -> b c (f h w) bounds",
            b=batch,
            bounds=2,
        )
class NormLayerType(Enum):
    """Normalization layer variants selectable in the VAE blocks."""
    GROUP_NORM = "group_norm"  # nn.GroupNorm
    PIXEL_NORM = "pixel_norm"  # PixelNorm (from ltx2_common)
class LogVarianceType(Enum):
    """Log-variance parameterization modes (consumed elsewhere in this module)."""
    PER_CHANNEL = "per_channel"
    UNIFORM = "uniform"
    CONSTANT = "constant"
    NONE = "none"
class PaddingModeType(Enum):
    """Padding modes; values match the `padding_mode` strings of `torch.nn.Conv*`."""
    ZEROS = "zeros"
    REFLECT = "reflect"
    REPLICATE = "replicate"
    CIRCULAR = "circular"
class DualConv3d(nn.Module):
    """Factorized (2+1)D convolution: a spatial (1, kH, kW) conv followed by a
    temporal (kT, 1, 1) conv.

    The pair can be evaluated either with two `F.conv3d` calls
    (`use_conv3d=True`) or with an equivalent conv2d + conv1d decomposition.

    NOTE: the functional conv ops (`F.conv3d`/`F.conv2d`/`F.conv1d`) only
    implement zero padding and accept no `padding_mode` argument, so
    `padding_mode` is stored for interface compatibility but values other
    than "zeros" are not applied here.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
    ) -> None:
        super(DualConv3d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding_mode = padding_mode
        # Ensure kernel_size, stride, padding, and dilation are tuples of length 3.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if kernel_size == (1, 1, 1):
            raise ValueError("kernel_size must be greater than 1. Use make_linear_nd instead.")
        if isinstance(stride, int):
            stride = (stride, stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding, padding)
        if isinstance(dilation, int):
            dilation = (dilation, dilation, dilation)
        self.groups = groups
        self.bias = bias
        # The wider of in/out channels is carried between the two convolutions.
        intermediate_channels = out_channels if in_channels < out_channels else in_channels
        # First (spatial-only) convolution: kernel (1, kH, kW).
        self.weight1 = nn.Parameter(
            torch.empty(
                intermediate_channels,
                in_channels // groups,
                1,
                kernel_size[1],
                kernel_size[2],
            ))
        self.stride1 = (1, stride[1], stride[2])
        self.padding1 = (0, padding[1], padding[2])
        self.dilation1 = (1, dilation[1], dilation[2])
        if bias:
            self.bias1 = nn.Parameter(torch.empty(intermediate_channels))
        else:
            self.register_parameter("bias1", None)
        # Second (temporal-only) convolution: kernel (kT, 1, 1).
        self.weight2 = nn.Parameter(torch.empty(out_channels, intermediate_channels // groups, kernel_size[0], 1, 1))
        self.stride2 = (stride[0], 1, 1)
        self.padding2 = (padding[0], 0, 0)
        self.dilation2 = (dilation[0], 1, 1)
        if bias:
            self.bias2 = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter("bias2", None)
        # Initialize weights and biases.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Kaiming-uniform init for both weights, uniform fan-in bounds for biases.

        BUGFIX: the previous code called `torch.sqrt(5)` / `torch.sqrt(fan_in)`,
        which raises TypeError on Python scalars; `math.sqrt` is the correct
        (and standard nn.Conv-style) choice.
        """
        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
        if self.bias:
            fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
            bound1 = 1 / math.sqrt(fan_in1)
            nn.init.uniform_(self.bias1, -bound1, bound1)
            fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
            bound2 = 1 / math.sqrt(fan_in2)
            nn.init.uniform_(self.bias2, -bound2, bound2)

    def forward(
        self,
        x: torch.Tensor,
        use_conv3d: bool = False,
        skip_time_conv: bool = False,
    ) -> torch.Tensor:
        """Apply the factorized convolution.

        Args:
            x: Input of shape (B, C, D, H, W).
            use_conv3d: Evaluate both stages with `F.conv3d` instead of the
                conv2d + conv1d decomposition.
            skip_time_conv: Stop after the spatial stage (temporal conv skipped).
        """
        if use_conv3d:
            return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
        else:
            return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)

    def forward_with_3d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor:
        """Both stages via `F.conv3d`.

        BUGFIX: `F.conv3d` has no `padding_mode` parameter; passing one raised
        TypeError. Functional convs always use zero padding.
        """
        # First (spatial) convolution.
        x = F.conv3d(
            x,
            self.weight1,
            self.bias1,
            self.stride1,
            self.padding1,
            self.dilation1,
            self.groups,
        )
        if skip_time_conv:
            return x
        # Second (temporal) convolution.
        x = F.conv3d(
            x,
            self.weight2,
            self.bias2,
            self.stride2,
            self.padding2,
            self.dilation2,
            self.groups,
        )
        return x

    def forward_with_2d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor:
        """Spatial stage as conv2d over flattened frames, temporal stage as conv1d.

        BUGFIX: `F.conv2d`/`F.conv1d` have no `padding_mode` parameter; passing
        one raised TypeError. Functional convs always use zero padding.
        """
        b, _, _, h, w = x.shape
        # First 2D convolution: fold frames into the batch dimension.
        x = rearrange(x, "b c d h w -> (b d) c h w")
        # Squeeze the depth dimension out of weight1 since it's 1.
        weight1 = self.weight1.squeeze(2)
        # Select stride, padding, and dilation for the 2D convolution.
        stride1 = (self.stride1[1], self.stride1[2])
        padding1 = (self.padding1[1], self.padding1[2])
        dilation1 = (self.dilation1[1], self.dilation1[2])
        x = F.conv2d(
            x,
            weight1,
            self.bias1,
            stride1,
            padding1,
            dilation1,
            self.groups,
        )
        _, _, h, w = x.shape
        if skip_time_conv:
            x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
            return x
        # Second convolution treated as a 1D convolution across the 'd' dimension.
        x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
        # Reshape weight2 to match the expected dimensions for conv1d.
        weight2 = self.weight2.squeeze(-1).squeeze(-1)
        # Use only the temporal components of stride, padding, and dilation.
        stride2 = self.stride2[0]
        padding2 = self.padding2[0]
        dilation2 = self.dilation2[0]
        x = F.conv1d(
            x,
            weight2,
            self.bias2,
            stride2,
            padding2,
            dilation2,
            self.groups,
        )
        x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
        return x

    @property
    def weight(self) -> torch.Tensor:
        # Expose the temporal weight as the module's canonical weight.
        return self.weight2
class CausalConv3d(nn.Module):
    """3D convolution with manual temporal padding (causal or symmetric) and
    built-in symmetric spatial padding.

    The wrapped `nn.Conv3d` performs no temporal padding itself; `forward`
    replicates edge frames before the convolution instead.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: Union[int, Tuple[int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        full_kernel = (kernel_size, kernel_size, kernel_size)
        self.time_kernel_size = full_kernel[0]
        # Dilation applies along time only; spatial padding keeps H/W size.
        spatial_pad = (0, full_kernel[1] // 2, full_kernel[2] // 2)
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            full_kernel,
            stride=stride,
            dilation=(dilation, 1, 1),
            padding=spatial_pad,
            padding_mode=spatial_padding_mode.value,
            groups=groups,
            bias=bias,
        )

    def forward(self, x: torch.Tensor, causal: bool = True) -> torch.Tensor:
        """Pad time by replicating edge frames, then convolve.

        causal=True replicates only the first frame (output at t sees no
        future frames); causal=False pads both ends symmetrically.
        """
        pad_total = self.time_kernel_size - 1
        if causal:
            front = x[:, :, :1, :, :].repeat((1, 1, pad_total, 1, 1))
            padded = torch.concatenate((front, x), dim=2)
        else:
            half = pad_total // 2
            front = x[:, :, :1, :, :].repeat((1, 1, half, 1, 1))
            back = x[:, :, -1:, :, :].repeat((1, 1, half, 1, 1))
            padded = torch.concatenate((front, x, back), dim=2)
        return self.conv(padded)

    @property
    def weight(self) -> torch.Tensor:
        # Weight of the wrapped convolution.
        return self.conv.weight
def make_conv_nd(  # noqa: PLR0913
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = True,
    causal: bool = False,
    spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    temporal_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
) -> nn.Module:
    """Build a convolution module for the requested dimensionality.

    Args:
        dims: 2 -> nn.Conv2d, 3 -> nn.Conv3d (or CausalConv3d when `causal`),
            (2, 1) -> factorized DualConv3d.
        causal: Only meaningful for dims == 3; selects causal temporal padding.
        spatial_padding_mode / temporal_padding_mode: Must be equal unless
            `causal` (CausalConv3d handles temporal padding itself).

    Raises:
        NotImplementedError: If the two padding modes differ and `causal` is False.
        ValueError: For any unsupported `dims`.
    """
    if spatial_padding_mode != temporal_padding_mode and not causal:
        raise NotImplementedError("spatial and temporal padding modes must be equal")
    if dims == 2:
        return nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode.value,
        )
    if dims == 3:
        if causal:
            return CausalConv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                groups=groups,
                bias=bias,
                spatial_padding_mode=spatial_padding_mode,
            )
        return nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode.value,
        )
    if dims == (2, 1):
        # NOTE(review): dilation/groups are not forwarded to DualConv3d here —
        # mirrors the previous behavior; confirm whether that is intentional.
        return DualConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
            padding_mode=spatial_padding_mode.value,
        )
    raise ValueError(f"unsupported dimensions: {dims}")
def make_linear_nd(
    dims: int,
    in_channels: int,
    out_channels: int,
    bias: bool = True,
) -> nn.Module:
    """Pointwise (kernel-size-1) convolution acting as a per-position linear map.

    dims == 2 yields nn.Conv2d; dims == 3 or (2, 1) yield nn.Conv3d.

    Raises:
        ValueError: For any other `dims`.
    """
    if dims == 2:
        conv_cls = nn.Conv2d
    elif dims in (3, (2, 1)):
        conv_cls = nn.Conv3d
    else:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_cls(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)
def patchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor:
    """Space-to-depth: move each patch's pixels into the channel axis.

    Args:
        x: 4D (B, C, H, W) or 5D (B, C, F, H, W) tensor.
        patch_size_hw: Spatial patch size; H and W shrink by this factor while
            channels grow by its square.
        patch_size_t: Temporal patch size (5D only); defaults to 1 (no temporal
            patching).

    For 5D: (B, C, F, H, W) ->
    (B, C*(patch_size_hw^2)*patch_size_t, F/patch_size_t, H/patch_size_hw, W/patch_size_hw).
    Example: (B, 3, 33, 512, 512) with patch_size_hw=4, patch_size_t=1 -> (B, 48, 33, 128, 128).

    Raises:
        ValueError: If `x` is neither 4D nor 5D.
    """
    # Identity when no patching is requested.
    if patch_size_hw == 1 and patch_size_t == 1:
        return x
    ndim = x.dim()
    if ndim == 4:
        return rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw)
    if ndim == 5:
        return rearrange(
            x,
            "b c (f p) (h q) (w r) -> b (c p r q) f h w",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    raise ValueError(f"Invalid input shape: {x.shape}")
def unpatchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor:
    """Depth-to-space: inverse of `patchify` — move channel groups back into
    patch_size x patch_size spatial blocks.

    Args:
        x: 4D (B, C*q*q, H, W) or 5D (B, C*q*q*p, F, H, W) tensor.
        patch_size_hw: Spatial patch size; H and W grow by this factor.
        patch_size_t: Temporal patch size (5D only). Default 1 (no temporal expansion).

    For 5D: (B, C*(patch_size_hw^2)*patch_size_t, F, H, W) ->
    (B, C, F*patch_size_t, H*patch_size_hw, W*patch_size_hw).
    Example: (B, 48, 33, 128, 128) with patch_size_hw=4, patch_size_t=1 -> (B, 3, 33, 512, 512).

    Raises:
        ValueError: If `x` is neither 4D nor 5D.
    """
    if patch_size_hw == 1 and patch_size_t == 1:
        return x
    if x.dim() == 4:
        x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw)
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b (c p r q) f h w -> b c (f p) (h q) (w r)",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    else:
        # BUGFIX: previously fell through and returned the input unchanged for
        # invalid ranks; fail loudly, mirroring patchify().
        raise ValueError(f"Invalid input shape: {x.shape}")
    return x
class PerChannelStatistics(nn.Module):
    """Dataset-level per-channel latent statistics for (de)normalization.

    The buffers are precomputed over the training dataset and loaded from the
    VAE checkpoint; the hyphenated buffer names match the checkpoint's
    state_dict keys and therefore must not be renamed.
    """

    def __init__(self, latent_channels: int = 128):
        super().__init__()
        for stat_name in (
                "std-of-means",
                "mean-of-means",
                "mean-of-stds",
                "mean-of-stds_over_std-of-means",
                "channel",
        ):
            self.register_buffer(stat_name, torch.empty(latent_channels))

    def _stat(self, name: str, like: torch.Tensor) -> torch.Tensor:
        # Per-channel stat reshaped to (1, C, 1, 1, 1) on `like`'s device/dtype.
        return self.get_buffer(name).view(1, -1, 1, 1, 1).to(like)

    def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map normalized latents back to the raw latent distribution."""
        return x * self._stat("std-of-means", x) + self._stat("mean-of-means", x)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Standardize raw latents per channel with the dataset mean/std."""
        return (x - self._stat("mean-of-means", x)) / self._stat("std-of-means", x)
class ResnetBlock3D(nn.Module):
    r"""
    A Resnet block operating on (B, C, F, H, W) feature maps.
    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        norm_layer (`NormLayerType`, *optional*, defaults to `PIXEL_NORM`):
            Which normalization to apply before each convolution.
        inject_noise (`bool`, *optional*, defaults to `False`):
            Whether to add learned, per-channel-scaled spatial noise after each convolution.
        timestep_conditioning (`bool`, *optional*, defaults to `False`):
            Whether to modulate hidden states with timestep-derived scale/shift pairs.
        spatial_padding_mode (`PaddingModeType`, *optional*, defaults to `ZEROS`):
            Spatial padding mode used by both convolutions.
    """
    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.inject_noise = inject_noise
        # First norm + conv pair (in_channels -> out_channels).
        if norm_layer == NormLayerType.GROUP_NORM:
            self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        elif norm_layer == NormLayerType.PIXEL_NORM:
            self.norm1 = PixelNorm()
        self.non_linearity = nn.SiLU()
        self.conv1 = make_conv_nd(
            dims,
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
        if inject_noise:
            # Learned per-channel gain for the noise added after conv1.
            self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
        # Second norm + conv pair (out_channels -> out_channels).
        if norm_layer == NormLayerType.GROUP_NORM:
            self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        elif norm_layer == NormLayerType.PIXEL_NORM:
            self.norm2 = PixelNorm()
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = make_conv_nd(
            dims,
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
        if inject_noise:
            # NOTE(review): sized with in_channels but applied to conv2's
            # out_channels output; shapes only line up when
            # in_channels == out_channels — confirm intended usage.
            self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
        # Skip path: 1x1 projection only when the channel count changes.
        self.conv_shortcut = (make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels)
                              if in_channels != out_channels else nn.Identity())
        # Using GroupNorm with 1 group is equivalent to LayerNorm but works with (B, C, ...) layout
        # avoiding the need for dimension rearrangement used in standard nn.LayerNorm
        self.norm3 = (nn.GroupNorm(num_groups=1, num_channels=in_channels, eps=eps, affine=True)
                      if in_channels != out_channels else nn.Identity())
        self.timestep_conditioning = timestep_conditioning
        if timestep_conditioning:
            # Four modulation rows: (shift1, scale1, shift2, scale2) — one
            # (shift, scale) pair per normalization.
            self.scale_shift_table = nn.Parameter(torch.zeros(4, in_channels))
    def _feed_spatial_noise(
        self,
        hidden_states: torch.Tensor,
        per_channel_scale: torch.Tensor,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        """Add spatial noise with a learned per-channel gain to `hidden_states`."""
        spatial_shape = hidden_states.shape[-2:]
        device = hidden_states.device
        dtype = hidden_states.dtype
        # similar to the "explicit noise inputs" method in style-gan
        # One (H, W) noise map, broadcast over batch and frames; per_channel_scale
        # provides the per-channel amplitude.
        spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype, generator=generator)[None]
        scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
        hidden_states = hidden_states + scaled_noise
        return hidden_states
    def forward(
        self,
        input_tensor: torch.Tensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        """Residual transform of `input_tensor`.

        `timestep` is required when timestep conditioning is enabled and must be
        reshapeable to (B, 4, C, *spatial) — presumably the embedding produced
        by the enclosing mid-block; confirm against callers.
        """
        hidden_states = input_tensor
        batch_size = hidden_states.shape[0]
        hidden_states = self.norm1(hidden_states)
        if self.timestep_conditioning:
            if timestep is None:
                raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True")
            # Combine the learned table with the timestep embedding, then split
            # into two (shift, scale) modulation pairs along dim 1.
            ada_values = self.scale_shift_table[None, ..., None, None, None].to(
                device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape(
                    batch_size,
                    4,
                    -1,
                    timestep.shape[-3],
                    timestep.shape[-2],
                    timestep.shape[-1],
                )
            shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
            hidden_states = hidden_states * (1 + scale1) + shift1
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.conv1(hidden_states, causal=causal)
        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states,
                self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype),
                generator=generator,
            )
        hidden_states = self.norm2(hidden_states)
        if self.timestep_conditioning:
            # scale2/shift2 were bound in the conditioning branch above.
            hidden_states = hidden_states * (1 + scale2) + shift2
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, causal=causal)
        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states,
                self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype),
                generator=generator,
            )
        # Skip path: (norm +) optional 1x1 projection, then the residual sum.
        input_tensor = self.norm3(input_tensor)
        batch_size = input_tensor.shape[0]  # NOTE(review): reassigned but unused below.
        input_tensor = self.conv_shortcut(input_tensor)
        output_tensor = input_tensor + hidden_states
        return output_tensor
class UNetMidBlock3D(nn.Module):
    """3D UNet mid-block: a stack of `ResnetBlock3D`s with optional timestep
    conditioning.

    Args:
        dims: Convolution dimensionality forwarded to the resnet blocks.
        in_channels: Channel count of the input (and output) feature map.
        dropout: Dropout probability inside each resnet block.
        num_layers: Number of stacked resnet blocks.
        resnet_eps: Epsilon for the resnet normalization layers.
        resnet_groups: Group count for group norm inside the resnet blocks;
            `None` falls back to `min(in_channels // 4, 32)`.
        norm_layer: Which normalization layer the resnet blocks use.
        inject_noise: Whether the resnet blocks inject spatial noise.
        timestep_conditioning: Whether to embed a timestep and condition each
            resnet block on it.
        spatial_padding_mode: Spatial padding mode for the convolutions.

    Returns (from `forward`): the output of the last residual block, with the
    same shape as the input.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: NormLayerType = NormLayerType.GROUP_NORM,
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        super().__init__()
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)
        self.timestep_conditioning = timestep_conditioning
        if timestep_conditioning:
            # 4x channels: feeds each block's 4-row (shift/scale) modulation table.
            self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim=in_channels * 4,
                                                                           size_emb_dim=0)
        blocks = [
            ResnetBlock3D(
                dims=dims,
                in_channels=in_channels,
                out_channels=in_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                norm_layer=norm_layer,
                inject_noise=inject_noise,
                timestep_conditioning=timestep_conditioning,
                spatial_padding_mode=spatial_padding_mode,
            )
            for _ in range(num_layers)
        ]
        self.res_blocks = nn.ModuleList(blocks)

    def forward(
        self,
        hidden_states: torch.Tensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        """Run the residual stack, optionally conditioned on `timestep`."""
        embedded_timestep = None
        if self.timestep_conditioning:
            if timestep is None:
                raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True")
            batch = hidden_states.shape[0]
            embedded_timestep = self.time_embedder(
                timestep=timestep.flatten(),
                hidden_dtype=hidden_states.dtype,
            )
            # (B, 4*C) -> (B, 4*C, 1, 1, 1) so blocks can reshape/broadcast it.
            embedded_timestep = embedded_timestep.view(batch, embedded_timestep.shape[-1], 1, 1, 1)
        for block in self.res_blocks:
            hidden_states = block(
                hidden_states,
                causal=causal,
                timestep=embedded_timestep,
                generator=generator,
            )
        return hidden_states
class SpaceToDepthDownsample(nn.Module):
    """Downsample by folding space/time into channels, with an averaged skip path.

    The main path convolves then space-to-depth folds; the skip path
    space-to-depth folds the raw input and averages groups of folded channels
    down to the same channel count.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: int,
        stride: Tuple[int, int, int],
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        super().__init__()
        self.stride = stride
        # Folded input channels averaged into each skip-path output channel.
        self.group_size = in_channels * math.prod(stride) // out_channels
        # The conv emits out_channels // prod(stride) channels; the subsequent
        # space-to-depth fold restores exactly out_channels.
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=out_channels // math.prod(stride),
            kernel_size=3,
            stride=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

    def forward(
        self,
        x: torch.Tensor,
        causal: bool = True,
    ) -> torch.Tensor:
        if self.stride[0] == 2:
            # Duplicate the first frame so the temporal fold has an even frame count.
            x = torch.cat([x[:, :, :1, :, :], x], dim=2)
        t_stride, h_stride, w_stride = self.stride
        # Skip path: space-to-depth fold, then average folded-channel groups.
        skip = rearrange(
            x,
            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
            p1=t_stride,
            p2=h_stride,
            p3=w_stride,
        )
        skip = rearrange(skip, "b (c g) d h w -> b c g d h w", g=self.group_size).mean(dim=2)
        # Main path: convolve, then apply the same space-to-depth fold.
        out = self.conv(x, causal=causal)
        out = rearrange(
            out,
            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
            p1=t_stride,
            p2=h_stride,
            p3=w_stride,
        )
        return out + skip
class DepthToSpaceUpsample(nn.Module):
    """Upsample by unfolding channels into space/time (depth-to-space), with an
    optional channel-tiled residual path."""

    def __init__(
        self,
        dims: int | Tuple[int, int],
        in_channels: int,
        stride: Tuple[int, int, int],
        residual: bool = False,
        out_channels_reduction_factor: int = 1,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        super().__init__()
        self.stride = stride
        # Conv widens channels so the depth-to-space unfold yields
        # in_channels // out_channels_reduction_factor channels.
        self.out_channels = math.prod(stride) * in_channels // out_channels_reduction_factor
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
        self.residual = residual
        self.out_channels_reduction_factor = out_channels_reduction_factor

    def forward(
        self,
        x: torch.Tensor,
        causal: bool = True,
    ) -> torch.Tensor:
        t_stride, h_stride, w_stride = self.stride
        if self.residual:
            # Residual path: depth-to-space unfold of the raw input, tiled along
            # channels to match the main path's width.
            skip = rearrange(
                x,
                "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
                p1=t_stride,
                p2=h_stride,
                p3=w_stride,
            )
            repeats = math.prod(self.stride) // self.out_channels_reduction_factor
            skip = skip.repeat(1, repeats, 1, 1, 1)
            if t_stride == 2:
                # Drop the extra leading frame introduced by the temporal unfold.
                skip = skip[:, :, 1:, :, :]
        out = self.conv(x, causal=causal)
        out = rearrange(
            out,
            "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
            p1=t_stride,
            p2=h_stride,
            p3=w_stride,
        )
        if t_stride == 2:
            out = out[:, :, 1:, :, :]
        if self.residual:
            out = out + skip
        return out
def compute_trapezoidal_mask_1d(
    length: int,
    ramp_left: int,
    ramp_right: int,
    left_starts_from_0: bool = False,
) -> torch.Tensor:
    """Build a 1D trapezoidal blending window with linear ramps at both ends.

    Args:
        length: Number of mask elements.
        ramp_left: Fade-in length on the left (clamped to [0, length]).
        ramp_right: Fade-out length on the right (clamped to [0, length]).
        left_starts_from_0: If True, the left ramp begins exactly at 0;
            otherwise it starts at the first non-zero step. Useful for temporal
            tiles where the first tile is causal.

    Returns:
        A `(length,)` tensor with values in [0, 1].

    Raises:
        ValueError: If `length` is not positive.
    """
    if length <= 0:
        raise ValueError("Mask length must be positive.")
    left = min(max(ramp_left, 0), length)
    right = min(max(ramp_right, 0), length)
    mask = torch.ones(length)
    if left > 0:
        if left_starts_from_0:
            # Ramp 0 -> 1 including the exact 0 sample.
            fade_in = torch.linspace(0.0, 1.0, left + 1)[:-1]
        else:
            # Ramp strictly between 0 and 1 (both endpoints excluded).
            fade_in = torch.linspace(0.0, 1.0, left + 2)[1:-1]
        mask[:left] *= fade_in
    if right > 0:
        # Ramp 1 -> 0 with both endpoints excluded.
        mask[-right:] *= torch.linspace(1.0, 0.0, steps=right + 2)[1:-1]
    return mask.clamp_(0, 1)
@dataclass(frozen=True)
class SpatialTilingConfig:
    """Configuration for dividing each frame into spatial tiles with optional overlap.

    Attributes:
        tile_size_in_pixels: Edge length of each tile in pixels. Must be at
            least 64 and divisible by 32.
        tile_overlap_in_pixels: Overlap between neighboring tiles in pixels.
            Must be divisible by 32 and strictly smaller than the tile size.
            Defaults to 0.
    """
    tile_size_in_pixels: int
    tile_overlap_in_pixels: int = 0

    def __post_init__(self) -> None:
        # Validate eagerly so invalid configs fail at construction time.
        size = self.tile_size_in_pixels
        overlap = self.tile_overlap_in_pixels
        if size < 64:
            raise ValueError(f"tile_size_in_pixels must be at least 64, got {size}")
        if size % 32 != 0:
            raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {size}")
        if overlap % 32 != 0:
            raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {overlap}")
        if overlap >= size:
            raise ValueError(
                f"Overlap must be less than tile size, got {overlap} and {size}"
            )
@dataclass(frozen=True)
class TemporalTilingConfig:
    """Configuration for dividing a video into temporal tiles (chunks of frames).

    Attributes:
        tile_size_in_frames: Frames per tile. Must be at least 16 and
            divisible by 8.
        tile_overlap_in_frames: Overlapping frames between consecutive tiles.
            Must be divisible by 8 and strictly smaller than the tile size.
            Defaults to 0.
    """
    tile_size_in_frames: int
    tile_overlap_in_frames: int = 0

    def __post_init__(self) -> None:
        # Validate eagerly so invalid configs fail at construction time.
        size = self.tile_size_in_frames
        overlap = self.tile_overlap_in_frames
        if size < 16:
            raise ValueError(f"tile_size_in_frames must be at least 16, got {size}")
        if size % 8 != 0:
            raise ValueError(f"tile_size_in_frames must be divisible by 8, got {size}")
        if overlap % 8 != 0:
            raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {overlap}")
        if overlap >= size:
            raise ValueError(
                f"Overlap must be less than tile size, got {overlap} and {size}"
            )
@dataclass(frozen=True)
class TilingConfig:
    """Configuration for splitting video into tiles with optional overlap.
    Attributes:
        spatial_config: Configuration for splitting spatial dimensions into tiles.
            `None` presumably disables spatial tiling — confirm in the tiling routines.
        temporal_config: Configuration for splitting temporal dimension into tiles.
            `None` presumably disables temporal tiling — confirm in the tiling routines.
    """
    spatial_config: SpatialTilingConfig | None = None
    temporal_config: TemporalTilingConfig | None = None
    @classmethod
    def default(cls) -> "TilingConfig":
        """Preset: 512px spatial tiles with 64px overlap; 64-frame temporal tiles with 24-frame overlap."""
        return cls(
            spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64),
            temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24),
        )
@dataclass(frozen=True)
class DimensionIntervals:
    """Intervals which a single dimension of the latent space is split into.
    Each interval is defined by its start, end, left ramp, and right ramp.
    The start and end are the indices of the first and last element (exclusive) in the interval.
    Ramps are regions of the interval where the value of the mask tensor is
    interpolated between 0 and 1 for blending with neighboring intervals.
    The left ramp and right ramp values are the lengths of the left and right ramps.
    All four lists are parallel: element i describes interval i.
    """
    # Inclusive start index of each interval.
    starts: List[int]
    # Exclusive end index of each interval.
    ends: List[int]
    # Length of the fade-in region at the left edge of each interval.
    left_ramps: List[int]
    # Length of the fade-out region at the right edge of each interval.
    right_ramps: List[int]
@dataclass(frozen=True)
class LatentIntervals:
    """Intervals which the latent tensor of given shape is split into.
    Each dimension of the latent space is split into intervals based on the length along said dimension.
    `dimension_intervals` is expected to hold one `DimensionIntervals` per axis of `original_shape`.
    """
    # Shape of the tensor being tiled; one entry per axis.
    original_shape: torch.Size
    # Per-axis interval descriptions, aligned with `original_shape`.
    dimension_intervals: Tuple[DimensionIntervals, ...]
# Operation to split a single dimension of the tensor into intervals based on the length along the dimension.
SplitOperation = Callable[[int], DimensionIntervals]
# Operation to map the intervals in input dimension to slices and masks along a corresponding output dimension.
# A mask entry of None means "no blending" along that interval.
MappingOperation = Callable[[DimensionIntervals], tuple[list[slice], list[torch.Tensor | None]]]
def default_split_operation(length: int) -> DimensionIntervals:
    """Trivial splitter: the whole dimension as one interval with no blending ramps."""
    return DimensionIntervals(starts=[0], ends=[length], left_ramps=[0], right_ramps=[0])
DEFAULT_SPLIT_OPERATION: SplitOperation = default_split_operation
def default_mapping_operation(_intervals: DimensionIntervals) -> tuple[list[slice], list[torch.Tensor | None]]:
    """Trivial mapper: one full-dimension output slice with no blending mask."""
    full_slice = slice(0, None)
    return [full_slice], [None]
DEFAULT_MAPPING_OPERATION: MappingOperation = default_mapping_operation
class Tile(NamedTuple):
    """
    A single tile of a tiled tensor operation.

    Attributes:
        in_coords:
            Slices selecting this tile from the INPUT tensor.
        out_coords:
            Slices locating this tile's result inside the reconstructed OUTPUT tensor.
        masks_1d:
            One per-dimension blending mask (in OUTPUT units) per axis; `None`
            entries mean "no blending along that axis".
    """
    in_coords: Tuple[slice, ...]
    out_coords: Tuple[slice, ...]
    masks_1d: Tuple[Tuple[torch.Tensor, ...]]

    @property
    def blend_mask(self) -> torch.Tensor:
        """Separable N-D blending window: the product of the per-axis 1D masks."""
        ndim = len(self.out_coords)
        combined: Optional[torch.Tensor] = None
        for axis in range(ndim):
            axis_mask = self.masks_1d[axis]
            view_shape = [1] * ndim
            if axis_mask is None:
                # Length-1 ones broadcast along unmasked axes.
                factor = torch.ones(1).view(*view_shape)
            else:
                # (L,) -> (1, ..., L, ..., 1) so per-axis masks broadcast-multiply.
                view_shape[axis] = axis_mask.shape[0]
                factor = axis_mask.view(*view_shape)
            combined = factor if combined is None else combined * factor
        return combined
def create_tiles_from_intervals_and_mappers(
    intervals: LatentIntervals,
    mappers: List[MappingOperation],
) -> List[Tile]:
    """Cross per-axis input slices, output slices and masks into Tile objects.

    For each axis, the intervals yield the input slices directly, while the axis
    mapper translates them into output slices plus optional 1-D blending masks.
    The resulting tiles are the cartesian product of the per-axis choices.
    """
    num_axes = len(intervals.original_shape)
    per_axis_in_slices = []
    per_axis_out_slices = []
    per_axis_masks = []
    for axis in range(num_axes):
        axis_intervals = intervals.dimension_intervals[axis]
        in_slices = [
            slice(start, end)
            for start, end in zip(axis_intervals.starts, axis_intervals.ends, strict=True)
        ]
        out_slices, masks = mappers[axis](axis_intervals)
        per_axis_in_slices.append(in_slices)
        per_axis_out_slices.append(out_slices)
        per_axis_masks.append(masks)
    tile_triplets = zip(
        itertools.product(*per_axis_in_slices),
        itertools.product(*per_axis_out_slices),
        itertools.product(*per_axis_masks),
        strict=True,
    )
    return [
        Tile(in_coords=in_coord, out_coords=out_coord, masks_1d=masks)
        for in_coord, out_coord, masks in tile_triplets
    ]
def create_tiles(
    latent_shape: torch.Size,
    splitters: List[SplitOperation],
    mappers: List[MappingOperation],
) -> List[Tile]:
    """Split every dimension of ``latent_shape`` with its splitter and build tiles.

    Raises:
        ValueError: if the number of splitters or mappers does not match the
            number of dimensions in ``latent_shape``.
    """
    num_dims = len(latent_shape)
    if len(splitters) != num_dims:
        raise ValueError(f"Number of splitters must be equal to number of dimensions in latent shape, "
                         f"got {len(splitters)} and {num_dims}")
    if len(mappers) != num_dims:
        raise ValueError(f"Number of mappers must be equal to number of dimensions in latent shape, "
                         f"got {len(mappers)} and {num_dims}")
    dimension_intervals = tuple(
        splitter(length) for splitter, length in zip(splitters, latent_shape, strict=True)
    )
    latent_intervals = LatentIntervals(
        original_shape=latent_shape,
        dimension_intervals=dimension_intervals,
    )
    return create_tiles_from_intervals_and_mappers(latent_intervals, mappers)
def _make_encoder_block(
    block_name: str,
    block_config: dict[str, Any],
    in_channels: int,
    convolution_dimensions: int,
    norm_layer: NormLayerType,
    norm_num_groups: int,
    spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
    """Instantiate a single encoder block by name.

    Returns the block module together with its output channel count; channel-expanding
    block types grow the channel count by the configured ``multiplier``.
    """
    # Plain strided causal convolutions ("compress_all_x_y" additionally expands channels).
    conv_strides = {
        "compress_time": (2, 1, 1),
        "compress_space": (1, 2, 2),
        "compress_all": (2, 2, 2),
        "compress_all_x_y": (2, 2, 2),
    }
    # Space-to-depth downsamplers; every one of these expands channels by the multiplier.
    space_to_depth_strides = {
        "compress_all_res": (2, 2, 2),
        "compress_space_res": (1, 2, 2),
        "compress_time_res": (2, 1, 1),
    }
    out_channels = in_channels
    if block_name == "res_x":
        block = UNetMidBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            num_layers=block_config["num_layers"],
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "res_x_y":
        out_channels = in_channels * block_config.get("multiplier", 2)
        block = ResnetBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            eps=1e-6,
            groups=norm_num_groups,
            norm_layer=norm_layer,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name in conv_strides:
        if block_name == "compress_all_x_y":
            out_channels = in_channels * block_config.get("multiplier", 2)
        block = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=conv_strides[block_name],
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name in space_to_depth_strides:
        out_channels = in_channels * block_config.get("multiplier", 2)
        block = SpaceToDepthDownsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            stride=space_to_depth_strides[block_name],
            spatial_padding_mode=spatial_padding_mode,
        )
    else:
        raise ValueError(f"unknown block: {block_name}")
    return block, out_channels
class LTX2VideoEncoder(nn.Module):
    """
    Variational Autoencoder Encoder. Encodes video frames into a latent representation.
    The encoder compresses the input video through a series of downsampling operations controlled by
    patch_size and the built-in encoder block stack. The output is a normalized latent tensor with
    shape (B, 128, F', H', W').
    Compression Behavior:
        The total compression is determined by:
        1. Initial spatial compression via patchify: H -> H/4, W -> W/4 (patch_size=4)
        2. Sequential compression through encoder blocks based on their stride patterns
        Compression blocks apply 2x compression in specified dimensions:
        - "compress_time" / "compress_time_res": temporal only
        - "compress_space" / "compress_space_res": spatial only (H and W)
        - "compress_all" / "compress_all_res": all dimensions (F, H, W)
        - "res_x" / "res_x_y": no compression
    Standard LTX Video configuration:
        - patch_size=4
        - encoder blocks: 1x compress_space_res, 1x compress_time_res, 2x compress_all_res
        - Final dimensions: F' = 1 + (F-1)/8, H' = H/32, W' = W/32
        - Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16)
        - Note: Input must have 1 + 8*k frames (e.g., 1, 9, 17, 25, 33...)
    Args:
        convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D).
        in_channels: The number of input channels. For RGB images, this is 3.
        out_channels: The number of output channels (latent channels). For latent channels, this is 128.
        patch_size: The patch size for initial spatial compression. Should be a power of 2.
        norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        latent_log_var: The log variance mode. Can be either `per_channel`, `uniform`, `constant` or `none`.
        encoder_spatial_padding_mode: Spatial padding mode used by the encoder convolutions.
    """
    # NOTE: the docstring above deliberately precedes this attribute; a string
    # literal placed after an assignment is a bare statement and would NOT
    # become the class docstring (previously ``__doc__`` was None).
    _DEFAULT_NORM_NUM_GROUPS = 32

    def __init__(
        self,
        convolution_dimensions: int = 3,
        in_channels: int = 3,
        out_channels: int = 128,
        patch_size: int = 4,
        norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
        latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
        encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        super().__init__()
        # Fixed LTX-2 encoder stack: residual stages interleaved with
        # space-to-depth downsamplers (overall 32x spatial / 8x temporal).
        encoder_blocks = [['res_x', {
            'num_layers': 4
        }], ['compress_space_res', {
            'multiplier': 2
        }], ['res_x', {
            'num_layers': 6
        }], ['compress_time_res', {
            'multiplier': 2
        }], ['res_x', {
            'num_layers': 6
        }], ['compress_all_res', {
            'multiplier': 2
        }], ['res_x', {
            'num_layers': 2
        }], ['compress_all_res', {
            'multiplier': 2
        }], ['res_x', {
            'num_layers': 2
        }]]
        self.patch_size = patch_size
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
        # Per-channel statistics for normalizing latents
        self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels)
        # Patchify folds patch_size x patch_size pixels into channels before conv_in.
        in_channels = in_channels * patch_size**2
        feature_channels = out_channels
        self.conv_in = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=feature_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=encoder_spatial_padding_mode,
        )
        self.down_blocks = nn.ModuleList([])
        for block_name, block_params in encoder_blocks:
            # Convert int to dict format for uniform handling
            block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
            block, feature_channels = _make_encoder_block(
                block_name=block_name,
                block_config=block_config,
                in_channels=feature_channels,
                convolution_dimensions=convolution_dimensions,
                norm_layer=norm_layer,
                norm_num_groups=self._norm_num_groups,
                spatial_padding_mode=encoder_spatial_padding_mode,
            )
            self.down_blocks.append(block)
        # Output head: norm -> activation -> conv producing means (+ optional logvar channels).
        if norm_layer == NormLayerType.GROUP_NORM:
            self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6)
        elif norm_layer == NormLayerType.PIXEL_NORM:
            self.conv_norm_out = PixelNorm()
        self.conv_act = nn.SiLU()
        # Extra output channels carry the log-variance, depending on the mode.
        conv_out_channels = out_channels
        if latent_log_var == LogVarianceType.PER_CHANNEL:
            conv_out_channels *= 2
        elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
            conv_out_channels += 1
        elif latent_log_var != LogVarianceType.NONE:
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
        self.conv_out = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=feature_channels,
            out_channels=conv_out_channels,
            kernel_size=3,
            padding=1,
            causal=True,
            spatial_padding_mode=encoder_spatial_padding_mode,
        )

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        r"""
        Encode video frames into normalized latent representation.
        Args:
            sample: Input video (B, C, F, H, W). F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...).
        Returns:
            Normalized latent means (B, 128, F', H', W') where F' = 1+(F-1)/8, H' = H/32, W' = W/32.
            Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16).
        """
        # Validate frame count
        frames_count = sample.shape[2]
        if ((frames_count - 1) % 8) != 0:
            raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames "
                             "(e.g., 1, 9, 17, ...). Please check your input.")
        # Initial spatial compression: trade spatial resolution for channel depth
        # This reduces H,W by patch_size and increases channels, making convolutions more efficient
        # Example: (B, 3, F, 512, 512) -> (B, 48, F, 128, 128) with patch_size=4
        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        sample = self.conv_in(sample)
        for down_block in self.down_blocks:
            sample = down_block(sample)
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)
        if self.latent_log_var == LogVarianceType.UNIFORM:
            # Uniform Variance: model outputs N means and 1 shared log-variance channel.
            # We need to expand the single logvar to match the number of means channels
            # to create a format compatible with PER_CHANNEL (means + logvar, each with N channels).
            # Sample shape: (B, N+1, ...) where N = latent_channels (e.g., 128 means + 1 logvar = 129)
            # Target shape: (B, 2*N, ...) where first N are means, last N are logvar
            if sample.shape[1] < 2:
                raise ValueError(f"Invalid channel count for UNIFORM mode: expected at least 2 channels "
                                 f"(N means + 1 logvar), got {sample.shape[1]}")
            # Extract means (first N channels) and logvar (last 1 channel)
            means = sample[:, :-1, ...]  # (B, N, ...)
            logvar = sample[:, -1:, ...]  # (B, 1, ...)
            # Repeat logvar N times to match means channels
            # Use expand/repeat pattern that works for both 4D and 5D tensors
            num_channels = means.shape[1]
            repeat_shape = [1, num_channels] + [1] * (sample.ndim - 2)
            repeated_logvar = logvar.repeat(*repeat_shape)  # (B, N, ...)
            # Concatenate to create (B, 2*N, ...) format: [means, repeated_logvar]
            sample = torch.cat([means, repeated_logvar], dim=1)
        elif self.latent_log_var == LogVarianceType.CONSTANT:
            # Drop the unused extra channel and append a constant logvar of ~ln(0).
            sample = sample[:, :-1, ...]
            approx_ln_0 = -30  # this is the minimal clamp value in DiagonalGaussianDistribution objects
            sample = torch.cat(
                [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
                dim=1,
            )
        # Split into means and logvar, then normalize means
        means, _ = torch.chunk(sample, 2, dim=1)
        return self.per_channel_statistics.normalize(means)

    def tiled_encode_video(
        self,
        video: torch.Tensor,
        tile_size: int = 512,
        tile_overlap: int = 128,
    ) -> torch.Tensor:
        """Encode video using spatial tiling for memory efficiency.
        Splits the video into overlapping spatial tiles, encodes each tile separately,
        and blends the results using linear feathering in the overlap regions.
        Args:
            video: Input tensor of shape [B, C, F, H, W]
            tile_size: Tile size in pixels (must be divisible by 32)
            tile_overlap: Overlap between tiles in pixels (must be divisible by 32)
        Returns:
            Encoded latent tensor [B, C_latent, F_latent, H_latent, W_latent]
        """
        batch, _channels, frames, height, width = video.shape
        device = video.device
        dtype = video.dtype
        # Validate tile parameters
        if tile_size % VAE_SPATIAL_FACTOR != 0:
            raise ValueError(f"tile_size must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_size}")
        if tile_overlap % VAE_SPATIAL_FACTOR != 0:
            raise ValueError(f"tile_overlap must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_overlap}")
        if tile_overlap >= tile_size:
            raise ValueError(f"tile_overlap ({tile_overlap}) must be less than tile_size ({tile_size})")
        # If video fits in a single tile, use regular encoding
        if height <= tile_size and width <= tile_size:
            return self.forward(video)
        # Calculate output dimensions
        # VAE compresses: H -> H/32, W -> W/32, F -> 1 + (F-1)/8
        output_height = height // VAE_SPATIAL_FACTOR
        output_width = width // VAE_SPATIAL_FACTOR
        output_frames = 1 + (frames - 1) // VAE_TEMPORAL_FACTOR
        # Latent channel count comes from the encoder configuration
        # (previously hard-coded to 128; use the attribute set in __init__).
        latent_channels = self.latent_channels
        # Initialize output and weight tensors
        output = torch.zeros(
            (batch, latent_channels, output_frames, output_height, output_width),
            device=device,
            dtype=dtype,
        )
        weights = torch.zeros(
            (batch, 1, output_frames, output_height, output_width),
            device=device,
            dtype=dtype,
        )
        # Calculate tile positions with overlap
        # Step size is tile_size - tile_overlap
        step_h = tile_size - tile_overlap
        step_w = tile_size - tile_overlap
        h_positions = list(range(0, max(1, height - tile_overlap), step_h))
        w_positions = list(range(0, max(1, width - tile_overlap), step_w))
        # Ensure last tile covers the edge
        if h_positions[-1] + tile_size < height:
            h_positions.append(height - tile_size)
        if w_positions[-1] + tile_size < width:
            w_positions.append(width - tile_size)
        # Remove duplicates and sort
        h_positions = sorted(set(h_positions))
        w_positions = sorted(set(w_positions))
        # Overlap in latent space
        overlap_out_h = tile_overlap // VAE_SPATIAL_FACTOR
        overlap_out_w = tile_overlap // VAE_SPATIAL_FACTOR
        # Process each tile
        for h_pos in h_positions:
            for w_pos in w_positions:
                # Calculate tile boundaries in input space
                h_start = max(0, h_pos)
                w_start = max(0, w_pos)
                h_end = min(h_start + tile_size, height)
                w_end = min(w_start + tile_size, width)
                # Ensure tile dimensions are divisible by VAE_SPATIAL_FACTOR
                tile_h = ((h_end - h_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR
                tile_w = ((w_end - w_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR
                if tile_h < VAE_SPATIAL_FACTOR or tile_w < VAE_SPATIAL_FACTOR:
                    continue
                # Adjust end positions
                h_end = h_start + tile_h
                w_end = w_start + tile_w
                # Extract tile
                tile = video[:, :, :, h_start:h_end, w_start:w_end]
                # Encode tile
                encoded_tile = self.forward(tile)
                # Get actual encoded dimensions
                _, _, tile_out_frames, tile_out_height, tile_out_width = encoded_tile.shape
                # Calculate output positions
                out_h_start = h_start // VAE_SPATIAL_FACTOR
                out_w_start = w_start // VAE_SPATIAL_FACTOR
                out_h_end = min(out_h_start + tile_out_height, output_height)
                out_w_end = min(out_w_start + tile_out_width, output_width)
                # Trim encoded tile if necessary
                actual_tile_h = out_h_end - out_h_start
                actual_tile_w = out_w_end - out_w_start
                encoded_tile = encoded_tile[:, :, :, :actual_tile_h, :actual_tile_w]
                # Create blending mask with linear feathering at edges
                mask = torch.ones(
                    (1, 1, tile_out_frames, actual_tile_h, actual_tile_w),
                    device=device,
                    dtype=dtype,
                )
                # Apply feathering at edges (linear blend in overlap regions)
                # Left edge
                if h_pos > 0 and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
                    fade_in = torch.linspace(0.0, 1.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
                    mask[:, :, :, :overlap_out_h, :] *= fade_in.view(1, 1, 1, -1, 1)
                # Right edge (bottom in height dimension)
                if h_end < height and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
                    fade_out = torch.linspace(1.0, 0.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
                    mask[:, :, :, -overlap_out_h:, :] *= fade_out.view(1, 1, 1, -1, 1)
                # Top edge (left in width dimension)
                if w_pos > 0 and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
                    fade_in = torch.linspace(0.0, 1.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
                    mask[:, :, :, :, :overlap_out_w] *= fade_in.view(1, 1, 1, 1, -1)
                # Bottom edge (right in width dimension)
                if w_end < width and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
                    fade_out = torch.linspace(1.0, 0.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
                    mask[:, :, :, :, -overlap_out_w:] *= fade_out.view(1, 1, 1, 1, -1)
                # Accumulate weighted results
                output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += encoded_tile * mask
                weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += mask
        # Normalize by weights (avoid division by zero)
        output = output / (weights + 1e-8)
        return output

    def encode(
        self,
        video: torch.Tensor,
        tiled=False,
        tile_size_in_pixels: Optional[int] = 512,
        tile_overlap_in_pixels: Optional[int] = 128,
        **kwargs,
    ) -> torch.Tensor:
        """Encode a video, optionally with spatial tiling.
        Args:
            video: Input video [B, C, F, H, W] (or [C, F, H, W], auto-batched).
            tiled: When True, use tiled_encode_video for memory efficiency.
            tile_size_in_pixels: Spatial tile size for tiled encoding.
            tile_overlap_in_pixels: Spatial tile overlap for tiled encoding.
        Returns:
            Latent tensor [B, C_latent, F', H', W'].
        """
        device = next(self.parameters()).device
        vae_dtype = next(self.parameters()).dtype
        if video.ndim == 4:
            video = video.unsqueeze(0)  # [C, F, H, W] -> [B, C, F, H, W]
        video = video.to(device=device, dtype=vae_dtype)
        # Choose encoding method based on tiling flag
        if tiled:
            latents = self.tiled_encode_video(
                video=video,
                tile_size=tile_size_in_pixels,
                tile_overlap=tile_overlap_in_pixels,
            )
        else:
            # Encode video - VAE expects [B, C, F, H, W], returns [B, C, F', H', W']
            latents = self.forward(video)
        return latents
def _make_decoder_block(
    block_name: str,
    block_config: dict[str, Any],
    in_channels: int,
    convolution_dimensions: int,
    norm_layer: NormLayerType,
    timestep_conditioning: bool,
    norm_num_groups: int,
    spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
    """Instantiate a single decoder block by name.

    Returns the block module together with its output channel count; channel-reducing
    block types shrink the channel count by the configured ``multiplier``.
    """
    out_channels = in_channels
    if block_name in ("res_x", "attn_res_x"):
        # Both variants share a UNetMidBlock3D; "res_x" pins the resnet epsilon,
        # "attn_res_x" adds attention with the configured head dimension.
        mid_block_kwargs = dict(
            dims=convolution_dimensions,
            in_channels=in_channels,
            num_layers=block_config["num_layers"],
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
            inject_noise=block_config.get("inject_noise", False),
            timestep_conditioning=timestep_conditioning,
            spatial_padding_mode=spatial_padding_mode,
        )
        if block_name == "res_x":
            mid_block_kwargs["resnet_eps"] = 1e-6
        else:
            mid_block_kwargs["attention_head_dim"] = block_config["attention_head_dim"]
        block = UNetMidBlock3D(**mid_block_kwargs)
    elif block_name == "res_x_y":
        out_channels = in_channels // block_config.get("multiplier", 2)
        block = ResnetBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            eps=1e-6,
            groups=norm_num_groups,
            norm_layer=norm_layer,
            inject_noise=block_config.get("inject_noise", False),
            timestep_conditioning=False,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name in ("compress_time", "compress_space"):
        upsample_stride = (2, 1, 1) if block_name == "compress_time" else (1, 2, 2)
        block = DepthToSpaceUpsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            stride=upsample_stride,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_all":
        out_channels = in_channels // block_config.get("multiplier", 1)
        block = DepthToSpaceUpsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            stride=(2, 2, 2),
            residual=block_config.get("residual", False),
            out_channels_reduction_factor=block_config.get("multiplier", 1),
            spatial_padding_mode=spatial_padding_mode,
        )
    else:
        raise ValueError(f"unknown layer: {block_name}")
    return block, out_channels
class LTX2VideoDecoder(nn.Module):
_DEFAULT_NORM_NUM_GROUPS = 32
"""
Variational Autoencoder Decoder. Decodes latent representation into video frames.
The decoder upsamples latents through a series of upsampling operations (inverse of encoder).
Output dimensions: F = 8x(F'-1) + 1, H = 32xH', W = 32xW' for standard LTX Video configuration.
Upsampling blocks expand dimensions by 2x in specified dimensions:
- "compress_time": temporal only
- "compress_space": spatial only (H and W)
- "compress_all": all dimensions (F, H, W)
- "res_x" / "res_x_y" / "attn_res_x": no upsampling
Causal Mode:
causal=False (standard): Symmetric padding, allows future frame dependencies.
causal=True: Causal padding, each frame depends only on past/current frames.
First frame removed after temporal upsampling in both modes. Output shape unchanged.
Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512) for both modes.
Args:
convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D).
in_channels: The number of input channels (latent channels). Default is 128.
out_channels: The number of output channels. For RGB images, this is 3.
decoder_blocks: The list of blocks to construct the decoder. Each block is a tuple of (block_name, params)
where params is either an int (num_layers) or a dict with configuration.
patch_size: Final spatial expansion factor. For standard LTX Video, use 4 for 4x spatial expansion:
H -> Hx4, W -> Wx4. Should be a power of 2.
norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
causal: Whether to use causal convolutions. For standard LTX Video, use False for symmetric padding.
When True, uses causal padding (past/current frames only).
timestep_conditioning: Whether to condition the decoder on timestep for denoising.
"""
def __init__(
self,
convolution_dimensions: int = 3,
in_channels: int = 128,
out_channels: int = 3,
decoder_blocks: List[Tuple[str, int | dict]] = [], # noqa: B006
patch_size: int = 4,
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
causal: bool = False,
timestep_conditioning: bool = False,
decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
):
super().__init__()
# Spatiotemporal downscaling between decoded video space and VAE latents.
# According to the LTXV paper, the standard configuration downsamples
# video inputs by a factor of 8 in the temporal dimension and 32 in
# each spatial dimension (height and width). This parameter determines how
# many video frames and pixels correspond to a single latent cell.
decoder_blocks = [['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}], ['compress_all', {
'residual': True,
'multiplier': 2
}], ['res_x', {
'num_layers': 5,
'inject_noise': False
}]]
self.video_downscale_factors = SpatioTemporalScaleFactors(
time=8,
width=32,
height=32,
)
self.patch_size = patch_size
out_channels = out_channels * patch_size**2
self.causal = causal
self.timestep_conditioning = timestep_conditioning
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
# Per-channel statistics for denormalizing latents
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
# Noise and timestep parameters for decoder conditioning
self.decode_noise_scale = 0.025
self.decode_timestep = 0.05
# Compute initial feature_channels by going through blocks in reverse
# This determines the channel width at the start of the decoder
feature_channels = in_channels
for block_name, block_params in list(reversed(decoder_blocks)):
block_config = block_params if isinstance(block_params, dict) else {}
if block_name == "res_x_y":
feature_channels = feature_channels * block_config.get("multiplier", 2)
if block_name == "compress_all":
feature_channels = feature_channels * block_config.get("multiplier", 1)
self.conv_in = make_conv_nd(
dims=convolution_dimensions,
in_channels=in_channels,
out_channels=feature_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
spatial_padding_mode=decoder_spatial_padding_mode,
)
self.up_blocks = nn.ModuleList([])
for block_name, block_params in list(reversed(decoder_blocks)):
# Convert int to dict format for uniform handling
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block, feature_channels = _make_decoder_block(
block_name=block_name,
block_config=block_config,
in_channels=feature_channels,
convolution_dimensions=convolution_dimensions,
norm_layer=norm_layer,
timestep_conditioning=timestep_conditioning,
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=decoder_spatial_padding_mode,
)
self.up_blocks.append(block)
if norm_layer == NormLayerType.GROUP_NORM:
self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6)
elif norm_layer == NormLayerType.PIXEL_NORM:
self.conv_norm_out = PixelNorm()
self.conv_act = nn.SiLU()
self.conv_out = make_conv_nd(
dims=convolution_dimensions,
in_channels=feature_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
causal=True,
spatial_padding_mode=decoder_spatial_padding_mode,
)
if timestep_conditioning:
self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0))
self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim=feature_channels * 2,
size_emb_dim=0)
self.last_scale_shift_table = nn.Parameter(torch.empty(2, feature_channels))
    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor | None = None,
        generator: torch.Generator | None = None,
    ) -> torch.Tensor:
        r"""
        Decode latent representation into video frames.
        Args:
            sample: Latent tensor (B, 128, F', H', W').
            timestep: Timestep for conditioning (if timestep_conditioning=True). Uses default 0.05 if None.
            generator: Random generator for deterministic noise injection (if inject_noise=True in blocks).
        Returns:
            Decoded video (B, 3, F, H, W) where F = 8x(F'-1) + 1, H = 32xH', W = 32xW'.
            Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512).
            Note: First frame is removed after temporal upsampling regardless of causal mode.
            When causal=False, allows future frame dependencies in convolutions but maintains same output shape.
        """
        batch_size = sample.shape[0]
        # Add noise if timestep conditioning is enabled:
        # lightly re-noise the latents (linear blend controlled by decode_noise_scale)
        # before decoding.
        if self.timestep_conditioning:
            noise = (torch.randn(
                sample.size(),
                generator=generator,
                dtype=sample.dtype,
                device=sample.device,
            ) * self.decode_noise_scale)
            sample = noise + (1.0 - self.decode_noise_scale) * sample
        # Denormalize latents back to the decoder's native latent statistics
        sample = self.per_channel_statistics.un_normalize(sample)
        # Use default decode_timestep if timestep not provided
        if timestep is None and self.timestep_conditioning:
            timestep = torch.full((batch_size,), self.decode_timestep, device=sample.device, dtype=sample.dtype)
        sample = self.conv_in(sample, causal=self.causal)
        scaled_timestep = None
        if self.timestep_conditioning:
            if timestep is None:
                raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True")
            # Scale timesteps into the range expected by the embedder (multiplier is a learned parameter).
            scaled_timestep = timestep * self.timestep_scale_multiplier.to(sample)
        # Upsampling stack; only some block types accept timestep/generator kwargs.
        for up_block in self.up_blocks:
            if isinstance(up_block, UNetMidBlock3D):
                block_kwargs = {
                    "causal": self.causal,
                    "timestep": scaled_timestep if self.timestep_conditioning else None,
                    "generator": generator,
                }
                sample = up_block(sample, **block_kwargs)
            elif isinstance(up_block, ResnetBlock3D):
                sample = up_block(sample, causal=self.causal, generator=generator)
            else:
                sample = up_block(sample, causal=self.causal)
        sample = self.conv_norm_out(sample)
        if self.timestep_conditioning:
            # AdaLN-style final modulation: the learned table provides base
            # (shift, scale) values, the timestep embedding adds a per-sample offset.
            embedded_timestep = self.last_time_embedder(
                timestep=scaled_timestep.flatten(),
                hidden_dtype=sample.dtype,
            )
            embedded_timestep = embedded_timestep.view(batch_size, embedded_timestep.shape[-1], 1, 1, 1)
            ada_values = self.last_scale_shift_table[None, ..., None, None, None].to(
                device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape(
                    batch_size,
                    2,
                    -1,
                    embedded_timestep.shape[-3],
                    embedded_timestep.shape[-2],
                    embedded_timestep.shape[-1],
                )
            shift, scale = ada_values.unbind(dim=1)
            sample = sample * (1 + scale) + shift
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, causal=self.causal)
        # Final spatial expansion: reverse the initial patchify from encoder
        # Moves pixels from channels back to spatial dimensions
        # Example: (B, 48, F, 128, 128) -> (B, 3, F, 512, 512) with patch_size=4
        sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        return sample
def _prepare_tiles(
self,
latent: torch.Tensor,
tiling_config: TilingConfig | None = None,
) -> List[Tile]:
splitters = [DEFAULT_SPLIT_OPERATION] * len(latent.shape)
mappers = [DEFAULT_MAPPING_OPERATION] * len(latent.shape)
if tiling_config is not None and tiling_config.spatial_config is not None:
cfg = tiling_config.spatial_config
long_side = max(latent.shape[3], latent.shape[4])
def enable_on_axis(axis_idx: int, factor: int) -> None:
size = cfg.tile_size_in_pixels // factor
overlap = cfg.tile_overlap_in_pixels // factor
axis_length = latent.shape[axis_idx]
lower_threshold = max(2, overlap + 1)
tile_size = max(lower_threshold, round(size * axis_length / long_side))
splitters[axis_idx] = split_in_spatial(tile_size, overlap)
mappers[axis_idx] = to_mapping_operation(map_spatial_slice, factor)
enable_on_axis(3, self.video_downscale_factors.height)
enable_on_axis(4, self.video_downscale_factors.width)
if tiling_config is not None and tiling_config.temporal_config is not None:
cfg = tiling_config.temporal_config
tile_size = cfg.tile_size_in_frames // self.video_downscale_factors.time
overlap = cfg.tile_overlap_in_frames // self.video_downscale_factors.time
splitters[2] = split_in_temporal(tile_size, overlap)
mappers[2] = to_mapping_operation(map_temporal_slice, self.video_downscale_factors.time)
return create_tiles(latent.shape, splitters, mappers)
def tiled_decode(
    self,
    latent: torch.Tensor,
    tiling_config: TilingConfig | None = None,
    timestep: torch.Tensor | None = None,
    generator: torch.Generator | None = None,
) -> Iterator[torch.Tensor]:
    """
    Decode a latent tensor into video frames using tiled processing.
    Splits the latent tensor into tiles, decodes each tile individually,
    and yields video chunks as they become available. Tiles that share a
    temporal output slice are first blended into one buffer; consecutive
    temporal buffers are then blended across their overlap, and only the
    finished (non-overlapping) prefix of each buffer is yielded.
    Args:
        latent: Input latent tensor (B, C, F', H', W').
        tiling_config: Tiling configuration for the latent tensor.
        timestep: Optional timestep for decoder conditioning.
        generator: Optional random generator for deterministic decoding.
    Yields:
        Video chunks (B, C, T, H, W) by temporal slices;
    """
    # Calculate full video shape from latent shape to get spatial dimensions
    full_video_shape = VideoLatentShape.from_torch_shape(latent.shape).upscale(self.video_downscale_factors)
    tiles = self._prepare_tiles(latent, tiling_config)
    # Tiles arrive grouped by temporal slice; each group fills one buffer.
    temporal_groups = self._group_tiles_by_temporal_slice(tiles)
    # State for temporal overlap handling (None until the first group is done)
    previous_chunk = None
    previous_weights = None
    previous_temporal_slice = None
    for temporal_group_tiles in temporal_groups:
        curr_temporal_slice = temporal_group_tiles[0].out_coords[2]
        # Calculate the shape of the temporal buffer for this group of tiles.
        # The temporal length depends on whether this is the first tile (starts at 0) or not.
        # - First tile: (frames - 1) * scale + 1
        # - Subsequent tiles: frames * scale
        # This logic is handled by TemporalAxisMapping and reflected in out_coords.
        temporal_tile_buffer_shape = full_video_shape._replace(frames=curr_temporal_slice.stop -
                                                               curr_temporal_slice.start,)
        buffer = torch.zeros(
            temporal_tile_buffer_shape.to_torch_shape(),
            device=latent.device,
            dtype=latent.dtype,
        )
        # Decode every tile of this temporal group into `buffer`; returns the
        # per-position accumulated mask weights for later normalization.
        curr_weights = self._accumulate_temporal_group_into_buffer(
            group_tiles=temporal_group_tiles,
            buffer=buffer,
            latent=latent,
            timestep=timestep,
            generator=generator,
        )
        # Blend with previous temporal chunk if it exists
        if previous_chunk is not None:
            # Check if current temporal slice overlaps with previous temporal slice
            if previous_temporal_slice.stop > curr_temporal_slice.start:
                overlap_len = previous_temporal_slice.stop - curr_temporal_slice.start
                temporal_overlap_slice = slice(curr_temporal_slice.start - previous_temporal_slice.start, None)
                # The overlap is already masked before it reaches this step. Each tile is accumulated into buffer
                # with its trapezoidal mask, and curr_weights accumulates the same mask. In the overlap blend we add
                # the masked values (buffer[...]) and the corresponding weights (curr_weights[...]) into the
                # previous buffers, then later normalize by weights.
                previous_chunk[:, :, temporal_overlap_slice, :, :] += buffer[:, :, slice(0, overlap_len), :, :]
                previous_weights[:, :, temporal_overlap_slice, :, :] += curr_weights[:, :,
                                                                                     slice(0, overlap_len), :, :]
                # Copy the blended overlap back into the current buffers so
                # the next iteration (and the final yield) sees the sum.
                buffer[:, :, slice(0, overlap_len), :, :] = previous_chunk[:, :, temporal_overlap_slice, :, :]
                curr_weights[:, :, slice(0, overlap_len), :, :] = previous_weights[:, :,
                                                                                   temporal_overlap_slice, :, :]
            # Yield the non-overlapping part of the previous chunk
            previous_weights = previous_weights.clamp(min=1e-8)  # avoid divide-by-zero in uncovered positions
            yield_len = curr_temporal_slice.start - previous_temporal_slice.start
            yield (previous_chunk / previous_weights)[:, :, :yield_len, :, :]
        # Update state for next iteration
        previous_chunk = buffer
        previous_weights = curr_weights
        previous_temporal_slice = curr_temporal_slice
    # Yield any remaining chunk
    if previous_chunk is not None:
        previous_weights = previous_weights.clamp(min=1e-8)
        yield previous_chunk / previous_weights
def _group_tiles_by_temporal_slice(self, tiles: List[Tile]) -> List[List[Tile]]:
"""Group tiles by their temporal output slice."""
if not tiles:
return []
groups = []
current_slice = tiles[0].out_coords[2]
current_group = []
for tile in tiles:
tile_slice = tile.out_coords[2]
if tile_slice == current_slice:
current_group.append(tile)
else:
groups.append(current_group)
current_slice = tile_slice
current_group = [tile]
# Add the final group
if current_group:
groups.append(current_group)
return groups
def _accumulate_temporal_group_into_buffer(
    self,
    group_tiles: List[Tile],
    buffer: torch.Tensor,
    latent: torch.Tensor,
    timestep: torch.Tensor | None,
    generator: torch.Generator | None,
) -> torch.Tensor:
    """
    Decode and accumulate all tiles of a temporal group into a local buffer.
    The buffer is local to the group and always starts at time 0; temporal coordinates
    are rebased by subtracting temporal_slice.start.

    Each decoded tile is multiplied by its trapezoidal blend mask before
    being added into `buffer` (in place); the same mask values are summed
    into the returned `weights` tensor so the caller can normalize later.

    Returns:
        Tensor of accumulated mask weights, same shape as `buffer`.
    """
    temporal_slice = group_tiles[0].out_coords[2]
    weights = torch.zeros_like(buffer)
    for tile in group_tiles:
        # Decode just this tile's latent sub-volume.
        decoded_tile = self.forward(latent[tile.in_coords], timestep, generator)
        mask = tile.blend_mask.to(device=buffer.device, dtype=buffer.dtype)
        # Rebase this tile's temporal start onto the group-local buffer.
        temporal_offset = tile.out_coords[2].start - temporal_slice.start
        # Use the tile's output coordinate length, not the decoded tile's length,
        # as the decoder may produce a different number of frames than expected
        expected_temporal_len = tile.out_coords[2].stop - tile.out_coords[2].start
        decoded_temporal_len = decoded_tile.shape[2]
        # Ensure we don't exceed the buffer or decoded tile bounds
        actual_temporal_len = min(expected_temporal_len, decoded_temporal_len, buffer.shape[2] - temporal_offset)
        chunk_coords = (
            slice(None),  # batch
            slice(None),  # channels
            slice(temporal_offset, temporal_offset + actual_temporal_len),
            tile.out_coords[3],  # height
            tile.out_coords[4],  # width
        )
        # Slice decoded_tile and mask to match the actual length we're writing
        decoded_slice = decoded_tile[:, :, :actual_temporal_len, :, :]
        # A mask with a singleton temporal dimension broadcasts over time and
        # needs no slicing.
        mask_slice = mask[:, :, :actual_temporal_len, :, :] if mask.shape[2] > 1 else mask
        buffer[chunk_coords] += decoded_slice * mask_slice
        weights[chunk_coords] += mask_slice
    return weights
def decode(
    self,
    latent: torch.Tensor,
    tiled=False,
    tile_size_in_pixels: Optional[int] = 512,
    tile_overlap_in_pixels: Optional[int] = 128,
    tile_size_in_frames: Optional[int] = 128,
    tile_overlap_in_frames: Optional[int] = 24,
) -> torch.Tensor:
    """Decode latents into pixels, optionally with spatio-temporal tiling.

    When `tiled` is False the whole latent is decoded in one forward pass;
    otherwise the latent is decoded chunk-by-chunk via `tiled_decode` and the
    chunks are concatenated along the temporal axis.
    """
    if not tiled:
        return self.forward(latent)

    config = TilingConfig(
        spatial_config=SpatialTilingConfig(
            tile_size_in_pixels=tile_size_in_pixels,
            tile_overlap_in_pixels=tile_overlap_in_pixels,
        ),
        temporal_config=TemporalTilingConfig(
            tile_size_in_frames=tile_size_in_frames,
            tile_overlap_in_frames=tile_overlap_in_frames,
        ),
    )
    # tiled_decode yields temporal chunks; stitch them back along dim 2 (frames).
    chunks = list(self.tiled_decode(latent, config))
    return torch.cat(chunks, dim=2)
def decode_video(
    latent: torch.Tensor,
    video_decoder: LTX2VideoDecoder,
    tiling_config: TilingConfig | None = None,
    generator: torch.Generator | None = None,
) -> Iterator[torch.Tensor]:
    """
    Decode a video latent tensor with the given decoder.
    Args:
        latent: Tensor [c, f, h, w]
        video_decoder: Decoder module.
        tiling_config: Optional tiling settings.
        generator: Optional random generator for deterministic decoding.
    Yields:
        Decoded chunk [f, h, w, c], uint8 in [0, 255].
    """
    def convert_to_uint8(frames: torch.Tensor) -> torch.Tensor:
        # Map [-1, 1] -> [0, 255] and drop the batch dimension.
        frames = (((frames + 1.0) / 2.0).clamp(0.0, 1.0) * 255.0).to(torch.uint8)
        # "c f h w -> f h w c" (pure axis permutation)
        frames = frames[0].permute(1, 2, 3, 0)
        return frames
    if tiling_config is not None:
        # BUG FIX: the original `return`ed the first chunk, discarding the
        # rest; yield every chunk so this matches the declared Iterator
        # contract and the docstring.
        for frames in video_decoder.tiled_decode(latent, tiling_config, generator=generator):
            yield convert_to_uint8(frames)
    else:
        decoded_video = video_decoder(latent, generator=generator)
        yield convert_to_uint8(decoded_video)
def get_video_chunks_number(num_frames: int, tiling_config: TilingConfig | None = None) -> int:
    """
    Get the number of video chunks for a given number of frames and tiling configuration.
    Args:
        num_frames: Number of frames in the video.
        tiling_config: Tiling configuration.
    Returns:
        Number of video chunks.
    """
    if not tiling_config or not tiling_config.temporal_config:
        # Without temporal tiling the whole video is one chunk.
        return 1
    temporal = tiling_config.temporal_config
    stride = temporal.tile_size_in_frames - temporal.tile_overlap_in_frames
    # ceil((num_frames - 1) / stride) via negated floor division.
    return -((num_frames - 1) // -stride)
def split_in_spatial(size: int, overlap: int) -> SplitOperation:
    """Build a splitter cutting a spatial axis into overlapping intervals.

    Each interval is `size` long with `overlap` shared between neighbors;
    the last interval is clamped to the axis end. Left/right ramps are 0 at
    the outer borders and `overlap` at every interior junction.
    """
    def split(dimension_size: int) -> DimensionIntervals:
        if dimension_size <= size:
            # The axis fits in a single tile: use the trivial split.
            return DEFAULT_SPLIT_OPERATION(dimension_size)
        stride = size - overlap
        # Number of tiles needed: ceil((dimension_size - overlap) / stride).
        count = -((dimension_size - overlap) // -stride)
        begins = [idx * stride for idx in range(count)]
        finishes = [begin + size for begin in begins]
        finishes[-1] = dimension_size  # clamp the final tile to the axis end
        ramps_left = [0] + [overlap] * (count - 1)
        ramps_right = [overlap] * (count - 1) + [0]
        return DimensionIntervals(starts=begins, ends=finishes, left_ramps=ramps_left, right_ramps=ramps_right)
    return split
def split_in_temporal(size: int, overlap: int) -> SplitOperation:
    """Build a temporal splitter derived from the spatial one.

    Identical to `split_in_spatial` except that every interval after the
    first starts one latent frame earlier and its left ramp is widened by
    one, matching the temporal axis' expansion of the first frame.
    """
    base_split = split_in_spatial(size, overlap)
    def split(dimension_size: int) -> DimensionIntervals:
        if dimension_size <= size:
            return DEFAULT_SPLIT_OPERATION(dimension_size)
        intervals = base_split(dimension_size)
        shifted_starts = intervals.starts[:1] + [s - 1 for s in intervals.starts[1:]]
        widened_ramps = intervals.left_ramps[:1] + [r + 1 for r in intervals.left_ramps[1:]]
        return replace(intervals, starts=shifted_starts, left_ramps=widened_ramps)
    return split
def to_mapping_operation(
    map_func: Callable[[int, int, int, int, int], Tuple[slice, torch.Tensor]],
    scale: int,
) -> MappingOperation:
    """Lift a per-interval mapping function into a whole-axis operation.

    The returned operation applies `map_func(start, end, left_ramp,
    right_ramp, scale)` to each interval and collects the resulting output
    slices and 1-D blend masks into two parallel lists.
    """
    def map_op(intervals: DimensionIntervals) -> tuple[list[slice], list[torch.Tensor | None]]:
        output_slices: list[slice] = []
        masks_1d: list[torch.Tensor | None] = []
        interval_fields = zip(
            intervals.starts,
            intervals.ends,
            intervals.left_ramps,
            intervals.right_ramps,
        )
        for begin, finish, ramp_left, ramp_right in interval_fields:
            out_slice, mask_1d = map_func(begin, finish, ramp_left, ramp_right, scale)
            output_slices.append(out_slice)
            masks_1d.append(mask_1d)
        return output_slices, masks_1d
    return map_op
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, torch.Tensor]:
    """Map a temporal latent interval [begin, end) to pixel-frame coordinates.

    On the temporal axis the first latent frame expands to a single pixel
    frame and each subsequent one to `scale` frames, hence the `1 + (n - 1)
    * scale` form for the stop position and the left ramp.
    """
    frame_start = begin * scale
    frame_stop = 1 + (end - 1) * scale
    ramp_in = 1 + (left_ramp - 1) * scale
    ramp_out = right_ramp * scale
    mask = compute_trapezoidal_mask_1d(frame_stop - frame_start, ramp_in, ramp_out, True)
    return slice(frame_start, frame_stop), mask
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, torch.Tensor]:
    """Map a spatial latent interval [begin, end) to pixel coordinates.

    Every latent position expands to exactly `scale` pixels, so the interval
    bounds and both blend ramps are simply multiplied by `scale`.
    """
    pixel_span = slice(begin * scale, end * scale)
    span_length = pixel_span.stop - pixel_span.start
    mask = compute_trapezoidal_mask_1d(span_length, left_ramp * scale, right_ramp * scale, False)
    return pixel_span, mask