support ltx2 one-stage pipeline

mi804
2026-01-29 16:30:15 +08:00
parent 8d303b47e9
commit b1a2782ad7
7 changed files with 1005 additions and 7 deletions

View File

@@ -337,3 +337,35 @@ class Patchifier(Protocol):
Returns:
Tensor containing patch coordinate metadata such as spatial or temporal intervals.
"""
def get_pixel_coords(
latent_coords: torch.Tensor,
scale_factors: SpatioTemporalScaleFactors,
causal_fix: bool = False,
) -> torch.Tensor:
"""
Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
Args:
latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied
per axis.
causal_fix: When True, shifts every temporal coordinate by `1 - temporal_scale`
and clamps at zero, so causal VAEs whose first latent frame decodes to a
single pixel frame still yield non-negative timestamps.
"""
# Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
broadcast_shape = [1] * latent_coords.ndim
broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width)
scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)
# Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
pixel_coords = latent_coords * scale_tensor
if causal_fix:
# VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
# Shift and clamp to keep the first-frame timestamps causal and non-negative.
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
return pixel_coords
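# Usage sketch (illustrative, not part of this commit). Assuming the LTX-2
# defaults of temporal stride 8 and spatial stride 32; a plain
# (temporal, height, width) tuple stands in for SpatioTemporalScaleFactors.
#
#   coords = torch.zeros(1, 3, 1, 2, dtype=torch.long)
#   coords[:, :, :, 1] = 1                    # one 1x1x1 latent patch at the origin
#   pixels = get_pixel_coords(coords, (8, 32, 32), causal_fix=True)
#   # temporal bounds become [0, 1) (scaled to [0, 8), shifted by 1 - 8, clamped),
#   # while height and width bounds become [0, 32).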

View File

@@ -514,7 +514,7 @@ class Attention(torch.nn.Module):
out_pattern="b s n d",
attn_mask=mask
)
# Reshape back to original format
out = out.flatten(2, 3)
return self.to_out(out)
@@ -1398,7 +1398,7 @@ class LTXModel(torch.nn.Module):
x = proj_out(x)
return x
-    def forward(
+    def _forward(
self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig
) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -1440,3 +1440,9 @@ class LTXModel(torch.nn.Module):
else None
)
return vx, ax
    def forward(
        self,
        video_latents, video_positions, video_context, video_timesteps,
        audio_latents, audio_positions, audio_context, audio_timesteps,
    ):
video = Modality(video_latents, video_timesteps, video_positions, video_context)
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context)
vx, ax = self._forward(video=video, audio=audio, perturbations=None)
return vx, ax
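# Usage sketch (illustrative, not part of this commit): the flattened signature
# lets tracing and export tools drive the model with plain tensors instead of
# Modality objects. Every name below is a placeholder.
#
#   vx, ax = model(
#       video_latents, video_positions, video_context, video_timesteps,
#       audio_latents, audio_positions, audio_context, audio_timesteps,
#   )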

View File

@@ -1,5 +1,6 @@
import itertools
import math
import einops
from dataclasses import replace, dataclass
from typing import Any, Callable, Iterator, List, NamedTuple, Tuple, Union, Optional
import torch
@@ -7,9 +8,138 @@ from einops import rearrange
from torch import nn
from torch.nn import functional as F
from enum import Enum
-from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape
+from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape, Patchifier, AudioLatentShape
from .ltx2_dit import PixArtAlphaCombinedTimestepSizeEmbeddings
VAE_SPATIAL_FACTOR = 32
VAE_TEMPORAL_FACTOR = 8
class VideoLatentPatchifier(Patchifier):
def __init__(self, patch_size: int):
# Patch sizes for video latents.
self._patch_size = (
1, # temporal dimension
patch_size, # height dimension
patch_size, # width dimension
)
@property
def patch_size(self) -> Tuple[int, int, int]:
return self._patch_size
def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size)
def patchify(
self,
latents: torch.Tensor,
) -> torch.Tensor:
latents = einops.rearrange(
latents,
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
p1=self._patch_size[0],
p2=self._patch_size[1],
p3=self._patch_size[2],
)
return latents
def unpatchify(
self,
latents: torch.Tensor,
output_shape: VideoLatentShape,
) -> torch.Tensor:
assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier"
patch_grid_frames = output_shape.frames // self._patch_size[0]
patch_grid_height = output_shape.height // self._patch_size[1]
patch_grid_width = output_shape.width // self._patch_size[2]
latents = einops.rearrange(
latents,
"b (f h w) (c p q) -> b c f (h p) (w q)",
f=patch_grid_frames,
h=patch_grid_height,
w=patch_grid_width,
p=self._patch_size[1],
q=self._patch_size[2],
)
return latents
def get_patch_grid_bounds(
self,
output_shape: AudioLatentShape | VideoLatentShape,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""
Return the half-open `[start, end)` bounds, per dimension, for every
patch produced by `patchify`. The bounds are expressed in the original
video grid coordinates: frame/time, height, and width.
The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where:
- axis 1 (size 3) enumerates (frame/time, height, width) dimensions
- axis 3 (size 2) stores `[start, end)` indices within each dimension
Args:
output_shape: Video grid description containing frames, height, and width.
device: Device on which to allocate the returned coordinate tensor.
"""
if not isinstance(output_shape, VideoLatentShape):
raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")
frames = output_shape.frames
height = output_shape.height
width = output_shape.width
batch_size = output_shape.batch
# Validate inputs to ensure positive dimensions
assert frames > 0, f"frames must be positive, got {frames}"
assert height > 0, f"height must be positive, got {height}"
assert width > 0, f"width must be positive, got {width}"
assert batch_size > 0, f"batch_size must be positive, got {batch_size}"
# Generate grid coordinates for each dimension (frame, height, width)
# We use torch.arange to create the starting coordinates for each patch.
# indexing='ij' ensures the dimensions are in the order (frame, height, width).
grid_coords = torch.meshgrid(
torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),
torch.arange(start=0, end=height, step=self._patch_size[1], device=device),
torch.arange(start=0, end=width, step=self._patch_size[2], device=device),
indexing="ij",
)
# Stack the grid coordinates to create the start coordinates tensor.
# Shape becomes (3, grid_f, grid_h, grid_w)
patch_starts = torch.stack(grid_coords, dim=0)
# Create a tensor containing the size of a single patch:
# (frame_patch_size, height_patch_size, width_patch_size).
# Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates.
patch_size_delta = torch.tensor(
self._patch_size,
device=patch_starts.device,
dtype=patch_starts.dtype,
).view(3, 1, 1, 1)
# Calculate end coordinates: start + patch_size
# Shape becomes (3, grid_f, grid_h, grid_w)
patch_ends = patch_starts + patch_size_delta
# Stack start and end coordinates together along the last dimension
# Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end]
latent_coords = torch.stack((patch_starts, patch_ends), dim=-1)
# Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence.
# Final Shape: (batch_size, 3, num_patches, 2)
latent_coords = einops.repeat(
latent_coords,
"c f h w bounds -> b c (f h w) bounds",
b=batch_size,
bounds=2,
)
return latent_coords
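# Usage sketch (illustrative, not part of this commit). The VideoLatentShape
# constructor keywords are assumed from the fields this file reads
# (batch, channels, frames, height, width).
#
#   patchifier = VideoLatentPatchifier(patch_size=2)
#   latents = torch.randn(1, 128, 2, 4, 4)            # (b, c, f, h, w)
#   tokens = patchifier.patchify(latents)             # -> (1, 8, 512)
#   shape = VideoLatentShape(batch=1, channels=128, frames=2, height=4, width=4)
#   bounds = patchifier.get_patch_grid_bounds(shape)  # -> (1, 3, 8, 2)
#   bounds[0, :, 0]                                   # first patch: [[0, 1], [0, 2], [0, 2]]
#   restored = patchifier.unpatchify(tokens, shape)   # -> (1, 128, 2, 4, 4)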
class NormLayerType(Enum):
GROUP_NORM = "group_norm"
@@ -1339,6 +1469,185 @@ class LTX2VideoEncoder(nn.Module):
return self.per_channel_statistics.normalize(means)
def tiled_encode_video(
self,
video: torch.Tensor,
tile_size: int = 512,
tile_overlap: int = 128,
) -> torch.Tensor:
"""Encode video using spatial tiling for memory efficiency.
Splits the video into overlapping spatial tiles, encodes each tile separately,
and blends the results using linear feathering in the overlap regions.
Args:
video: Input tensor of shape [B, C, F, H, W]
tile_size: Tile size in pixels (must be divisible by 32)
tile_overlap: Overlap between tiles in pixels (must be divisible by 32)
Returns:
Encoded latent tensor [B, C_latent, F_latent, H_latent, W_latent]
"""
batch, _channels, frames, height, width = video.shape
device = video.device
dtype = video.dtype
# Validate tile parameters
if tile_size % VAE_SPATIAL_FACTOR != 0:
raise ValueError(f"tile_size must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_size}")
if tile_overlap % VAE_SPATIAL_FACTOR != 0:
raise ValueError(f"tile_overlap must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_overlap}")
if tile_overlap >= tile_size:
raise ValueError(f"tile_overlap ({tile_overlap}) must be less than tile_size ({tile_size})")
# If video fits in a single tile, use regular encoding
if height <= tile_size and width <= tile_size:
return self.forward(video)
# Calculate output dimensions
# VAE compresses: H -> H/32, W -> W/32, F -> 1 + (F-1)/8
output_height = height // VAE_SPATIAL_FACTOR
output_width = width // VAE_SPATIAL_FACTOR
output_frames = 1 + (frames - 1) // VAE_TEMPORAL_FACTOR
# Latent channel count, hard-coded to LTX-2's 128-channel video VAE;
# a more general implementation would infer this from the model config.
latent_channels = 128
# Initialize output and weight tensors
output = torch.zeros(
(batch, latent_channels, output_frames, output_height, output_width),
device=device,
dtype=dtype,
)
weights = torch.zeros(
(batch, 1, output_frames, output_height, output_width),
device=device,
dtype=dtype,
)
# Calculate tile positions with overlap
# Step size is tile_size - tile_overlap
step_h = tile_size - tile_overlap
step_w = tile_size - tile_overlap
h_positions = list(range(0, max(1, height - tile_overlap), step_h))
w_positions = list(range(0, max(1, width - tile_overlap), step_w))
# Ensure last tile covers the edge
if h_positions[-1] + tile_size < height:
h_positions.append(height - tile_size)
if w_positions[-1] + tile_size < width:
w_positions.append(width - tile_size)
# Remove duplicates and sort
h_positions = sorted(set(h_positions))
w_positions = sorted(set(w_positions))
# Overlap in latent space
overlap_out_h = tile_overlap // VAE_SPATIAL_FACTOR
overlap_out_w = tile_overlap // VAE_SPATIAL_FACTOR
# Process each tile
for h_pos in h_positions:
for w_pos in w_positions:
# Calculate tile boundaries in input space
h_start = max(0, h_pos)
w_start = max(0, w_pos)
h_end = min(h_start + tile_size, height)
w_end = min(w_start + tile_size, width)
# Ensure tile dimensions are divisible by VAE_SPATIAL_FACTOR
tile_h = ((h_end - h_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR
tile_w = ((w_end - w_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR
if tile_h < VAE_SPATIAL_FACTOR or tile_w < VAE_SPATIAL_FACTOR:
continue
# Adjust end positions
h_end = h_start + tile_h
w_end = w_start + tile_w
# Extract tile
tile = video[:, :, :, h_start:h_end, w_start:w_end]
# Encode tile
encoded_tile = self.forward(tile)
# Get actual encoded dimensions
_, _, tile_out_frames, tile_out_height, tile_out_width = encoded_tile.shape
# Calculate output positions
out_h_start = h_start // VAE_SPATIAL_FACTOR
out_w_start = w_start // VAE_SPATIAL_FACTOR
out_h_end = min(out_h_start + tile_out_height, output_height)
out_w_end = min(out_w_start + tile_out_width, output_width)
# Trim encoded tile if necessary
actual_tile_h = out_h_end - out_h_start
actual_tile_w = out_w_end - out_w_start
encoded_tile = encoded_tile[:, :, :, :actual_tile_h, :actual_tile_w]
# Create blending mask with linear feathering at edges
mask = torch.ones(
(1, 1, tile_out_frames, actual_tile_h, actual_tile_w),
device=device,
dtype=dtype,
)
# Apply feathering at edges (linear blend in overlap regions)
# Top edge (height dimension)
if h_pos > 0 and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
fade_in = torch.linspace(0.0, 1.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
mask[:, :, :, :overlap_out_h, :] *= fade_in.view(1, 1, 1, -1, 1)
# Bottom edge (height dimension)
if h_end < height and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
fade_out = torch.linspace(1.0, 0.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
mask[:, :, :, -overlap_out_h:, :] *= fade_out.view(1, 1, 1, -1, 1)
# Left edge (width dimension)
if w_pos > 0 and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
fade_in = torch.linspace(0.0, 1.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
mask[:, :, :, :, :overlap_out_w] *= fade_in.view(1, 1, 1, 1, -1)
# Right edge (width dimension)
if w_end < width and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
fade_out = torch.linspace(1.0, 0.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
mask[:, :, :, :, -overlap_out_w:] *= fade_out.view(1, 1, 1, 1, -1)
# Accumulate weighted results
output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += encoded_tile * mask
weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += mask
# Normalize by weights (avoid division by zero)
output = output / (weights + 1e-8)
return output
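    # Worked example (illustrative): encoding a 1024x1024 input with the defaults
    # tile_size=512 and tile_overlap=128 gives a stride of 384, so
    # h_positions == w_positions == [0, 384, 768]. The tiles cover pixel rows
    # [0, 512), [384, 896), and [768, 1024) (the last tile is clipped to the image
    # edge), and neighbouring tiles are feathered across 128 / 32 = 4 latent
    # rows/columns before the weighted sum is normalized.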
def encode(
self,
video: torch.Tensor,
tiled: bool = False,
tile_size_in_pixels: Optional[int] = 512,
tile_overlap_in_pixels: Optional[int] = 128,
**kwargs,
) -> torch.Tensor:
device = next(self.parameters()).device
vae_dtype = next(self.parameters()).dtype
if video.ndim == 4:
video = video.unsqueeze(0) # [C, F, H, W] -> [B, C, F, H, W]
video = video.to(device=device, dtype=vae_dtype)
# Choose encoding method based on tiling flag
if tiled:
latents = self.tiled_encode_video(
video=video,
tile_size=tile_size_in_pixels,
tile_overlap=tile_overlap_in_pixels,
)
else:
# Encode video - VAE expects [B, C, F, H, W], returns [B, C, F', H', W']
latents = self.forward(video)
return latents
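# Usage sketch (illustrative; the encoder handle and shapes are placeholders):
#
#   video = torch.rand(3, 121, 1024, 1024)         # [C, F, H, W]; encode() adds the batch dim
#   latents = encoder.encode(video, tiled=True)    # -> [1, 128, 16, 32, 32]
#   # F: 1 + (121 - 1) // 8 = 16;  H and W: 1024 // 32 = 32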
def _make_decoder_block(
block_name: str,
block_config: dict[str, Any],
@@ -1850,6 +2159,30 @@ class LTX2VideoDecoder(nn.Module):
return weights
def decode(
self,
latent: torch.Tensor,
tiled: bool = False,
tile_size_in_pixels: Optional[int] = 512,
tile_overlap_in_pixels: Optional[int] = 128,
tile_size_in_frames: Optional[int] = 128,
tile_overlap_in_frames: Optional[int] = 24,
) -> torch.Tensor:
if tiled:
tiling_config = TilingConfig(
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=tile_size_in_pixels,
tile_overlap_in_pixels=tile_overlap_in_pixels,
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=tile_size_in_frames,
tile_overlap_in_frames=tile_overlap_in_frames,
),
)
tiles = self.tiled_decode(latent, tiling_config)
return torch.cat(list(tiles), dim=2)
else:
return self.forward(latent)
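# Usage sketch (illustrative; the decoder handle is a placeholder): tiled
# decoding mirrors the encoder's spatial knobs and additionally tiles over
# time, concatenating the temporal chunks along dim=2.
#
#   frames = decoder.decode(latents, tiled=True,
#                           tile_size_in_frames=128, tile_overlap_in_frames=24)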
def decode_video(
latent: torch.Tensor,
@@ -1875,10 +2208,10 @@ def decode_video(
if tiling_config is not None:
for frames in video_decoder.tiled_decode(latent, tiling_config, generator=generator):
-            yield convert_to_uint8(frames)
+            return convert_to_uint8(frames)
else:
decoded_video = video_decoder(latent, generator=generator)
-        yield convert_to_uint8(decoded_video)
+        return convert_to_uint8(decoded_video)
def get_video_chunks_number(num_frames: int, tiling_config: TilingConfig | None = None) -> int: