Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-18 22:08:13 +00:00)
support ltx2 one-stage pipeline
@@ -337,3 +337,35 @@ class Patchifier(Protocol):
        Returns:
            Tensor containing patch coordinate metadata such as spatial or temporal intervals.
        """


def get_pixel_coords(
    latent_coords: torch.Tensor,
    scale_factors: SpatioTemporalScaleFactors,
    causal_fix: bool = False,
) -> torch.Tensor:
    """
    Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
    each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
    Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
    Args:
        latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
        scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied
            per axis.
        causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
            that treat frame zero differently still yield non-negative timestamps.
    """
    # Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
    broadcast_shape = [1] * latent_coords.ndim
    broadcast_shape[1] = -1  # axis dimension corresponds to (frame/time, height, width)
    scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)

    # Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
    pixel_coords = latent_coords * scale_tensor

    if causal_fix:
        # VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
        # Shift and clamp to keep the first-frame timestamps causal and non-negative.
        pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)

    return pixel_coords
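As an illustrative sketch (not part of the commit), here is the mapping applied to a single patch bound; the plain `(8, 32, 32)` tuple stands in for a `SpatioTemporalScaleFactors` value with the temporal and spatial factors defined later in this commit:

import torch

# One batch item, one patch: [start, end) bounds along (frame/time, height, width).
latent_coords = torch.tensor([[[[0, 1]], [[2, 3]], [[4, 5]]]])  # shape (1, 3, 1, 2)

pixel_coords = get_pixel_coords(latent_coords, scale_factors=(8, 32, 32), causal_fix=True)
# Height and width bounds scale to [64, 96) and [128, 160). The temporal bound would scale
# to [0, 8), but the causal fix shifts frame-zero timestamps by 1 - 8 and clamps at 0 -> [0, 1).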
@@ -514,7 +514,7 @@ class Attention(torch.nn.Module):
            out_pattern="b s n d",
            attn_mask=mask
        )


        # Reshape back to original format
        out = out.flatten(2, 3)
        return self.to_out(out)
@@ -1398,7 +1398,7 @@ class LTXModel(torch.nn.Module):
        x = proj_out(x)
        return x

-    def forward(
+    def _forward(
        self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
@@ -1440,3 +1440,9 @@ class LTXModel(torch.nn.Module):
            else None
        )
        return vx, ax

    def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps):
        video = Modality(video_latents, video_timesteps, video_positions, video_context)
        audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context)
        vx, ax = self._forward(video=video, audio=audio, perturbations=None)
        return vx, ax
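A minimal calling sketch for the new positional wrapper (illustration only, not part of the commit; `model` stands for a constructed `LTXModel` and the tensors are whatever the one-stage pipeline prepares):

vx, ax = model(
    video_latents, video_positions, video_context, video_timesteps,
    audio_latents, audio_positions, audio_context, audio_timesteps,
)
# The wrapper packs each group into a Modality and delegates to _forward with perturbations=None.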
@@ -1,5 +1,6 @@
import itertools
import math
import einops
from dataclasses import replace, dataclass
from typing import Any, Callable, Iterator, List, NamedTuple, Tuple, Union, Optional
import torch
@@ -7,9 +8,138 @@ from einops import rearrange
from torch import nn
from torch.nn import functional as F
from enum import Enum
-from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape
+from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape, Patchifier, AudioLatentShape
from .ltx2_dit import PixArtAlphaCombinedTimestepSizeEmbeddings

VAE_SPATIAL_FACTOR = 32
VAE_TEMPORAL_FACTOR = 8


class VideoLatentPatchifier(Patchifier):
    def __init__(self, patch_size: int):
        # Patch sizes for video latents.
        self._patch_size = (
            1,  # temporal dimension
            patch_size,  # height dimension
            patch_size,  # width dimension
        )

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        return self._patch_size

    def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
        return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size)

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        latents = einops.rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )

        return latents

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: VideoLatentShape,
    ) -> torch.Tensor:
        assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier"

        patch_grid_frames = output_shape.frames // self._patch_size[0]
        patch_grid_height = output_shape.height // self._patch_size[1]
        patch_grid_width = output_shape.width // self._patch_size[2]

        latents = einops.rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=patch_grid_frames,
            h=patch_grid_height,
            w=patch_grid_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )

        return latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the per-dimension bounds `[inclusive start, exclusive end)` for every
        patch produced by `patchify`. The bounds are expressed in the original
        video grid coordinates: frame/time, height, and width.
        The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where:
        - axis 1 (size 3) enumerates the (frame/time, height, width) dimensions
        - axis 3 (size 2) stores the `[start, end)` indices within each dimension
        Args:
            output_shape: Video grid description containing frames, height, and width.
            device: Device of the latent tensor.
        """
        if not isinstance(output_shape, VideoLatentShape):
            raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")

        frames = output_shape.frames
        height = output_shape.height
        width = output_shape.width
        batch_size = output_shape.batch

        # Validate inputs to ensure positive dimensions
        assert frames > 0, f"frames must be positive, got {frames}"
        assert height > 0, f"height must be positive, got {height}"
        assert width > 0, f"width must be positive, got {width}"
        assert batch_size > 0, f"batch_size must be positive, got {batch_size}"

        # Generate grid coordinates for each dimension (frame, height, width).
        # We use torch.arange to create the starting coordinates for each patch.
        # indexing='ij' ensures the dimensions are in the order (frame, height, width).
        grid_coords = torch.meshgrid(
            torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),
            torch.arange(start=0, end=height, step=self._patch_size[1], device=device),
            torch.arange(start=0, end=width, step=self._patch_size[2], device=device),
            indexing="ij",
        )

        # Stack the grid coordinates to create the start coordinates tensor.
        # Shape becomes (3, grid_f, grid_h, grid_w).
        patch_starts = torch.stack(grid_coords, dim=0)

        # Create a tensor containing the size of a single patch:
        # (frame_patch_size, height_patch_size, width_patch_size).
        # Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates.
        patch_size_delta = torch.tensor(
            self._patch_size,
            device=patch_starts.device,
            dtype=patch_starts.dtype,
        ).view(3, 1, 1, 1)

        # Calculate end coordinates: start + patch_size.
        # Shape becomes (3, grid_f, grid_h, grid_w).
        patch_ends = patch_starts + patch_size_delta

        # Stack start and end coordinates together along the last dimension.
        # Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end].
        latent_coords = torch.stack((patch_starts, patch_ends), dim=-1)

        # Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence.
        # Final shape: (batch_size, 3, num_patches, 2).
        latent_coords = einops.repeat(
            latent_coords,
            "c f h w bounds -> b c (f h w) bounds",
            b=batch_size,
            bounds=2,
        )

        return latent_coords


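A round-trip sketch of the patchifier (illustration only, not part of the commit). The `VideoLatentShape(...)` constructor call below is hypothetical, since only its `batch`, `frames`, `height`, and `width` fields are read here, and a plain tuple stands in for `SpatioTemporalScaleFactors`:

import torch

patchifier = VideoLatentPatchifier(patch_size=2)

latents = torch.randn(1, 128, 2, 4, 4)   # (B, C, F, H, W) video latents
tokens = patchifier.patchify(latents)    # -> (1, 2 * 2 * 2, 128 * 1 * 2 * 2) = (1, 8, 512)

# Hypothetical construction: the field names of VideoLatentShape are assumed here.
shape = VideoLatentShape(batch=1, channels=128, frames=2, height=4, width=4)
bounds = patchifier.get_patch_grid_bounds(shape)   # (1, 3, 8, 2) latent-grid [start, end) bounds

# get_pixel_coords is the helper added to ltx2_common earlier in this commit.
pixel_bounds = get_pixel_coords(
    bounds, (VAE_TEMPORAL_FACTOR, VAE_SPATIAL_FACTOR, VAE_SPATIAL_FACTOR), causal_fix=True
)

restored = patchifier.unpatchify(tokens, output_shape=shape)   # back to (1, 128, 2, 4, 4)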
class NormLayerType(Enum):
    GROUP_NORM = "group_norm"
@@ -1339,6 +1469,185 @@ class LTX2VideoEncoder(nn.Module):
        return self.per_channel_statistics.normalize(means)


    def tiled_encode_video(
        self,
        video: torch.Tensor,
        tile_size: int = 512,
        tile_overlap: int = 128,
    ) -> torch.Tensor:
        """Encode video using spatial tiling for memory efficiency.
        Splits the video into overlapping spatial tiles, encodes each tile separately,
        and blends the results using linear feathering in the overlap regions.
        Args:
            video: Input tensor of shape [B, C, F, H, W]
            tile_size: Tile size in pixels (must be divisible by 32)
            tile_overlap: Overlap between tiles in pixels (must be divisible by 32)
        Returns:
            Encoded latent tensor [B, C_latent, F_latent, H_latent, W_latent]
        """
        batch, _channels, frames, height, width = video.shape
        device = video.device
        dtype = video.dtype

        # Validate tile parameters
        if tile_size % VAE_SPATIAL_FACTOR != 0:
            raise ValueError(f"tile_size must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_size}")
        if tile_overlap % VAE_SPATIAL_FACTOR != 0:
            raise ValueError(f"tile_overlap must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_overlap}")
        if tile_overlap >= tile_size:
            raise ValueError(f"tile_overlap ({tile_overlap}) must be less than tile_size ({tile_size})")

        # If the video fits in a single tile, use regular encoding
        if height <= tile_size and width <= tile_size:
            return self.forward(video)

        # Calculate output dimensions
        # VAE compresses: H -> H/32, W -> W/32, F -> 1 + (F-1)/8
        output_height = height // VAE_SPATIAL_FACTOR
        output_width = width // VAE_SPATIAL_FACTOR
        output_frames = 1 + (frames - 1) // VAE_TEMPORAL_FACTOR

        # Latent channels (128 for LTX-2).
        # Could be obtained from a small test encode; assumed to be 128 here.
        latent_channels = 128

        # Initialize output and weight tensors
        output = torch.zeros(
            (batch, latent_channels, output_frames, output_height, output_width),
            device=device,
            dtype=dtype,
        )
        weights = torch.zeros(
            (batch, 1, output_frames, output_height, output_width),
            device=device,
            dtype=dtype,
        )

        # Calculate tile positions with overlap.
        # Step size is tile_size - tile_overlap.
        step_h = tile_size - tile_overlap
        step_w = tile_size - tile_overlap

        h_positions = list(range(0, max(1, height - tile_overlap), step_h))
        w_positions = list(range(0, max(1, width - tile_overlap), step_w))

        # Ensure the last tile covers the edge
        if h_positions[-1] + tile_size < height:
            h_positions.append(height - tile_size)
        if w_positions[-1] + tile_size < width:
            w_positions.append(width - tile_size)

        # Remove duplicates and sort
        h_positions = sorted(set(h_positions))
        w_positions = sorted(set(w_positions))

        # Overlap in latent space
        overlap_out_h = tile_overlap // VAE_SPATIAL_FACTOR
        overlap_out_w = tile_overlap // VAE_SPATIAL_FACTOR

        # Process each tile
        for h_pos in h_positions:
            for w_pos in w_positions:
                # Calculate tile boundaries in input space
                h_start = max(0, h_pos)
                w_start = max(0, w_pos)
                h_end = min(h_start + tile_size, height)
                w_end = min(w_start + tile_size, width)

                # Ensure tile dimensions are divisible by VAE_SPATIAL_FACTOR
                tile_h = ((h_end - h_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR
                tile_w = ((w_end - w_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR

                if tile_h < VAE_SPATIAL_FACTOR or tile_w < VAE_SPATIAL_FACTOR:
                    continue

                # Adjust end positions
                h_end = h_start + tile_h
                w_end = w_start + tile_w

                # Extract tile
                tile = video[:, :, :, h_start:h_end, w_start:w_end]

                # Encode tile
                encoded_tile = self.forward(tile)

                # Get actual encoded dimensions
                _, _, tile_out_frames, tile_out_height, tile_out_width = encoded_tile.shape

                # Calculate output positions
                out_h_start = h_start // VAE_SPATIAL_FACTOR
                out_w_start = w_start // VAE_SPATIAL_FACTOR
                out_h_end = min(out_h_start + tile_out_height, output_height)
                out_w_end = min(out_w_start + tile_out_width, output_width)

                # Trim encoded tile if necessary
                actual_tile_h = out_h_end - out_h_start
                actual_tile_w = out_w_end - out_w_start
                encoded_tile = encoded_tile[:, :, :, :actual_tile_h, :actual_tile_w]

                # Create blending mask with linear feathering at edges
                mask = torch.ones(
                    (1, 1, tile_out_frames, actual_tile_h, actual_tile_w),
                    device=device,
                    dtype=dtype,
                )

                # Apply feathering at edges (linear blend in overlap regions)
                # Top edge of the tile (start of the height dimension)
                if h_pos > 0 and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
                    fade_in = torch.linspace(0.0, 1.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
                    mask[:, :, :, :overlap_out_h, :] *= fade_in.view(1, 1, 1, -1, 1)

                # Bottom edge of the tile (end of the height dimension)
                if h_end < height and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
                    fade_out = torch.linspace(1.0, 0.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
                    mask[:, :, :, -overlap_out_h:, :] *= fade_out.view(1, 1, 1, -1, 1)

                # Left edge of the tile (start of the width dimension)
                if w_pos > 0 and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
                    fade_in = torch.linspace(0.0, 1.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
                    mask[:, :, :, :, :overlap_out_w] *= fade_in.view(1, 1, 1, 1, -1)

                # Right edge of the tile (end of the width dimension)
                if w_end < width and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
                    fade_out = torch.linspace(1.0, 0.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
                    mask[:, :, :, :, -overlap_out_w:] *= fade_out.view(1, 1, 1, 1, -1)

                # Accumulate weighted results
                output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += encoded_tile * mask
                weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += mask

        # Normalize by weights (avoid division by zero)
        output = output / (weights + 1e-8)

        return output

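A quick arithmetic sketch (illustration only) of the tile layout this produces for a 1280-pixel-wide, 49-frame input at the default `tile_size=512`, `tile_overlap=128`:

tile_size, tile_overlap, width, frames = 512, 128, 1280, 49
step_w = tile_size - tile_overlap                                   # 384
w_positions = list(range(0, max(1, width - tile_overlap), step_w))  # [0, 384, 768]
# Tiles span [0, 512), [384, 896), [768, 1280): full width coverage with 128-px overlaps.
overlap_out_w = tile_overlap // VAE_SPATIAL_FACTOR                  # 4 latent columns feathered per seam
output_width = width // VAE_SPATIAL_FACTOR                          # 40 latent columns
output_frames = 1 + (frames - 1) // VAE_TEMPORAL_FACTOR             # 7 latent frames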
    def encode(
        self,
        video: torch.Tensor,
        tiled=False,
        tile_size_in_pixels: Optional[int] = 512,
        tile_overlap_in_pixels: Optional[int] = 128,
        **kwargs,
    ) -> torch.Tensor:
        device = next(self.parameters()).device
        vae_dtype = next(self.parameters()).dtype
        if video.ndim == 4:
            video = video.unsqueeze(0)  # [C, F, H, W] -> [B, C, F, H, W]
        video = video.to(device=device, dtype=vae_dtype)
        # Choose encoding method based on tiling flag
        if tiled:
            latents = self.tiled_encode_video(
                video=video,
                tile_size=tile_size_in_pixels,
                tile_overlap=tile_overlap_in_pixels,
            )
        else:
            # Encode video - the VAE expects [B, C, F, H, W] and returns [B, C, F', H', W']
            latents = self.forward(video)
        return latents
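A usage sketch of the new `encode` entry point (illustration only; `encoder` stands for a constructed `LTX2VideoEncoder` and the clip size is arbitrary):

video = torch.randn(3, 49, 704, 1216)   # [C, F, H, W]; encode() adds the batch dimension
latents = encoder.encode(video, tiled=True, tile_size_in_pixels=512, tile_overlap_in_pixels=128)
# Expected latent shape for this clip: [1, 128, 7, 22, 38] (704/32, 1216/32, 1 + 48/8).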
def _make_decoder_block(
    block_name: str,
    block_config: dict[str, Any],
@@ -1850,6 +2159,30 @@ class LTX2VideoDecoder(nn.Module):
        return weights

    def decode(
        self,
        latent: torch.Tensor,
        tiled=False,
        tile_size_in_pixels: Optional[int] = 512,
        tile_overlap_in_pixels: Optional[int] = 128,
        tile_size_in_frames: Optional[int] = 128,
        tile_overlap_in_frames: Optional[int] = 24,
    ) -> torch.Tensor:
        if tiled:
            tiling_config = TilingConfig(
                spatial_config=SpatialTilingConfig(
                    tile_size_in_pixels=tile_size_in_pixels,
                    tile_overlap_in_pixels=tile_overlap_in_pixels,
                ),
                temporal_config=TemporalTilingConfig(
                    tile_size_in_frames=tile_size_in_frames,
                    tile_overlap_in_frames=tile_overlap_in_frames,
                ),
            )
            tiles = self.tiled_decode(latent, tiling_config)
            return torch.cat(list(tiles), dim=2)
        else:
            return self.forward(latent)
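And the matching decoder call (illustration only; `decoder` stands for a constructed `LTX2VideoDecoder`):

frames = decoder.decode(
    latents,
    tiled=True,
    tile_size_in_pixels=512,
    tile_overlap_in_pixels=128,
    tile_size_in_frames=128,
    tile_overlap_in_frames=24,
)
# With tiled=False the call reduces to decoder.forward(latents).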
def decode_video(
    latent: torch.Tensor,
@@ -1875,10 +2208,10 @@ def decode_video(
    if tiling_config is not None:
        for frames in video_decoder.tiled_decode(latent, tiling_config, generator=generator):
-            yield convert_to_uint8(frames)
+            return convert_to_uint8(frames)
    else:
        decoded_video = video_decoder(latent, generator=generator)
-        yield convert_to_uint8(decoded_video)
+        return convert_to_uint8(decoded_video)


def get_video_chunks_number(num_frames: int, tiling_config: TilingConfig | None = None) -> int: