support ltx2 one-stage pipeline

This commit is contained in:
mi804
2026-01-29 16:30:15 +08:00
parent 8d303b47e9
commit b1a2782ad7
7 changed files with 1005 additions and 7 deletions

View File

@@ -4,13 +4,14 @@ from typing_extensions import Literal
class FlowMatchScheduler():
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"):
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2"] = "FLUX.1"):
self.set_timesteps_fn = {
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
"Wan": FlowMatchScheduler.set_timesteps_wan,
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
}.get(template, FlowMatchScheduler.set_timesteps_flux)
self.num_train_timesteps = 1000
@@ -121,7 +122,30 @@ class FlowMatchScheduler():
timestep_id = torch.argmin((timesteps - timestep).abs())
timesteps[timestep_id] = timestep
return sigmas, timesteps
@staticmethod
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, stretch=True, terminal=0.1):
dynamic_shift_len = dynamic_shift_len or 4096
sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image(
image_seq_len=dynamic_shift_len,
base_seq_len=1024,
max_seq_len=4096,
base_shift=0.95,
max_shift=2.05,
)
num_train_timesteps = 1000
sigma_min = 0.0
sigma_max = 1.0
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1))
# Shift terminal
one_minus_z = 1.0 - sigmas
scale_factor = one_minus_z[-1] / (1 - terminal)
sigmas = 1.0 - (one_minus_z / scale_factor)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
def set_training_weight(self):
steps = 1000
x = self.timesteps

View File

@@ -337,3 +337,35 @@ class Patchifier(Protocol):
Returns:
Tensor containing patch coordinate metadata such as spatial or temporal intervals.
"""
def get_pixel_coords(
latent_coords: torch.Tensor,
scale_factors: SpatioTemporalScaleFactors,
causal_fix: bool = False,
) -> torch.Tensor:
"""
Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
Args:
latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied
per axis.
causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
that treat frame zero differently still yield non-negative timestamps.
"""
# Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
broadcast_shape = [1] * latent_coords.ndim
broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width)
scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)
# Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
pixel_coords = latent_coords * scale_tensor
if causal_fix:
# VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
# Shift and clamp to keep the first-frame timestamps causal and non-negative.
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
return pixel_coords

View File

@@ -514,7 +514,7 @@ class Attention(torch.nn.Module):
out_pattern="b s n d",
attn_mask=mask
)
# Reshape back to original format
out = out.flatten(2, 3)
return self.to_out(out)
@@ -1398,7 +1398,7 @@ class LTXModel(torch.nn.Module):
x = proj_out(x)
return x
def forward(
def _forward(
self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig
) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -1440,3 +1440,9 @@ class LTXModel(torch.nn.Module):
else None
)
return vx, ax
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps):
video = Modality(video_latents, video_timesteps, video_positions, video_context)
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context)
vx, ax = self._forward(video=video, audio=audio, perturbations=None)
return vx, ax

View File

@@ -1,5 +1,6 @@
import itertools
import math
import einops
from dataclasses import replace, dataclass
from typing import Any, Callable, Iterator, List, NamedTuple, Tuple, Union, Optional
import torch
@@ -7,9 +8,138 @@ from einops import rearrange
from torch import nn
from torch.nn import functional as F
from enum import Enum
from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape
from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape, Patchifier, AudioLatentShape
from .ltx2_dit import PixArtAlphaCombinedTimestepSizeEmbeddings
VAE_SPATIAL_FACTOR = 32
VAE_TEMPORAL_FACTOR = 8
class VideoLatentPatchifier(Patchifier):
def __init__(self, patch_size: int):
# Patch sizes for video latents.
self._patch_size = (
1, # temporal dimension
patch_size, # height dimension
patch_size, # width dimension
)
@property
def patch_size(self) -> Tuple[int, int, int]:
return self._patch_size
def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size)
def patchify(
self,
latents: torch.Tensor,
) -> torch.Tensor:
latents = einops.rearrange(
latents,
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
p1=self._patch_size[0],
p2=self._patch_size[1],
p3=self._patch_size[2],
)
return latents
def unpatchify(
self,
latents: torch.Tensor,
output_shape: VideoLatentShape,
) -> torch.Tensor:
assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier"
patch_grid_frames = output_shape.frames // self._patch_size[0]
patch_grid_height = output_shape.height // self._patch_size[1]
patch_grid_width = output_shape.width // self._patch_size[2]
latents = einops.rearrange(
latents,
"b (f h w) (c p q) -> b c f (h p) (w q)",
f=patch_grid_frames,
h=patch_grid_height,
w=patch_grid_width,
p=self._patch_size[1],
q=self._patch_size[2],
)
return latents
def get_patch_grid_bounds(
self,
output_shape: AudioLatentShape | VideoLatentShape,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""
Return the per-dimension bounds [inclusive start, exclusive end) for every
patch produced by `patchify`. The bounds are expressed in the original
video grid coordinates: frame/time, height, and width.
The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where:
- axis 1 (size 3) enumerates (frame/time, height, width) dimensions
- axis 3 (size 2) stores `[start, end)` indices within each dimension
Args:
output_shape: Video grid description containing frames, height, and width.
device: Device of the latent tensor.
"""
if not isinstance(output_shape, VideoLatentShape):
raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")
frames = output_shape.frames
height = output_shape.height
width = output_shape.width
batch_size = output_shape.batch
# Validate inputs to ensure positive dimensions
assert frames > 0, f"frames must be positive, got {frames}"
assert height > 0, f"height must be positive, got {height}"
assert width > 0, f"width must be positive, got {width}"
assert batch_size > 0, f"batch_size must be positive, got {batch_size}"
# Generate grid coordinates for each dimension (frame, height, width)
# We use torch.arange to create the starting coordinates for each patch.
# indexing='ij' ensures the dimensions are in the order (frame, height, width).
grid_coords = torch.meshgrid(
torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),
torch.arange(start=0, end=height, step=self._patch_size[1], device=device),
torch.arange(start=0, end=width, step=self._patch_size[2], device=device),
indexing="ij",
)
# Stack the grid coordinates to create the start coordinates tensor.
# Shape becomes (3, grid_f, grid_h, grid_w)
patch_starts = torch.stack(grid_coords, dim=0)
# Create a tensor containing the size of a single patch:
# (frame_patch_size, height_patch_size, width_patch_size).
# Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates.
patch_size_delta = torch.tensor(
self._patch_size,
device=patch_starts.device,
dtype=patch_starts.dtype,
).view(3, 1, 1, 1)
# Calculate end coordinates: start + patch_size
# Shape becomes (3, grid_f, grid_h, grid_w)
patch_ends = patch_starts + patch_size_delta
# Stack start and end coordinates together along the last dimension
# Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end]
latent_coords = torch.stack((patch_starts, patch_ends), dim=-1)
# Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence.
# Final Shape: (batch_size, 3, num_patches, 2)
latent_coords = einops.repeat(
latent_coords,
"c f h w bounds -> b c (f h w) bounds",
b=batch_size,
bounds=2,
)
return latent_coords
class NormLayerType(Enum):
GROUP_NORM = "group_norm"
@@ -1339,6 +1469,185 @@ class LTX2VideoEncoder(nn.Module):
return self.per_channel_statistics.normalize(means)
def tiled_encode_video(
self,
video: torch.Tensor,
tile_size: int = 512,
tile_overlap: int = 128,
) -> torch.Tensor:
"""Encode video using spatial tiling for memory efficiency.
Splits the video into overlapping spatial tiles, encodes each tile separately,
and blends the results using linear feathering in the overlap regions.
Args:
video: Input tensor of shape [B, C, F, H, W]
tile_size: Tile size in pixels (must be divisible by 32)
tile_overlap: Overlap between tiles in pixels (must be divisible by 32)
Returns:
Encoded latent tensor [B, C_latent, F_latent, H_latent, W_latent]
"""
batch, _channels, frames, height, width = video.shape
device = video.device
dtype = video.dtype
# Validate tile parameters
if tile_size % VAE_SPATIAL_FACTOR != 0:
raise ValueError(f"tile_size must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_size}")
if tile_overlap % VAE_SPATIAL_FACTOR != 0:
raise ValueError(f"tile_overlap must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_overlap}")
if tile_overlap >= tile_size:
raise ValueError(f"tile_overlap ({tile_overlap}) must be less than tile_size ({tile_size})")
# If video fits in a single tile, use regular encoding
if height <= tile_size and width <= tile_size:
return self.forward(video)
# Calculate output dimensions
# VAE compresses: H -> H/32, W -> W/32, F -> 1 + (F-1)/8
output_height = height // VAE_SPATIAL_FACTOR
output_width = width // VAE_SPATIAL_FACTOR
output_frames = 1 + (frames - 1) // VAE_TEMPORAL_FACTOR
# Latent channels (128 for LTX-2)
# Get from a small test encode or assume 128
latent_channels = 128
# Initialize output and weight tensors
output = torch.zeros(
(batch, latent_channels, output_frames, output_height, output_width),
device=device,
dtype=dtype,
)
weights = torch.zeros(
(batch, 1, output_frames, output_height, output_width),
device=device,
dtype=dtype,
)
# Calculate tile positions with overlap
# Step size is tile_size - tile_overlap
step_h = tile_size - tile_overlap
step_w = tile_size - tile_overlap
h_positions = list(range(0, max(1, height - tile_overlap), step_h))
w_positions = list(range(0, max(1, width - tile_overlap), step_w))
# Ensure last tile covers the edge
if h_positions[-1] + tile_size < height:
h_positions.append(height - tile_size)
if w_positions[-1] + tile_size < width:
w_positions.append(width - tile_size)
# Remove duplicates and sort
h_positions = sorted(set(h_positions))
w_positions = sorted(set(w_positions))
# Overlap in latent space
overlap_out_h = tile_overlap // VAE_SPATIAL_FACTOR
overlap_out_w = tile_overlap // VAE_SPATIAL_FACTOR
# Process each tile
for h_pos in h_positions:
for w_pos in w_positions:
# Calculate tile boundaries in input space
h_start = max(0, h_pos)
w_start = max(0, w_pos)
h_end = min(h_start + tile_size, height)
w_end = min(w_start + tile_size, width)
# Ensure tile dimensions are divisible by VAE_SPATIAL_FACTOR
tile_h = ((h_end - h_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR
tile_w = ((w_end - w_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR
if tile_h < VAE_SPATIAL_FACTOR or tile_w < VAE_SPATIAL_FACTOR:
continue
# Adjust end positions
h_end = h_start + tile_h
w_end = w_start + tile_w
# Extract tile
tile = video[:, :, :, h_start:h_end, w_start:w_end]
# Encode tile
encoded_tile = self.forward(tile)
# Get actual encoded dimensions
_, _, tile_out_frames, tile_out_height, tile_out_width = encoded_tile.shape
# Calculate output positions
out_h_start = h_start // VAE_SPATIAL_FACTOR
out_w_start = w_start // VAE_SPATIAL_FACTOR
out_h_end = min(out_h_start + tile_out_height, output_height)
out_w_end = min(out_w_start + tile_out_width, output_width)
# Trim encoded tile if necessary
actual_tile_h = out_h_end - out_h_start
actual_tile_w = out_w_end - out_w_start
encoded_tile = encoded_tile[:, :, :, :actual_tile_h, :actual_tile_w]
# Create blending mask with linear feathering at edges
mask = torch.ones(
(1, 1, tile_out_frames, actual_tile_h, actual_tile_w),
device=device,
dtype=dtype,
)
# Apply feathering at edges (linear blend in overlap regions)
# Left edge
if h_pos > 0 and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
fade_in = torch.linspace(0.0, 1.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
mask[:, :, :, :overlap_out_h, :] *= fade_in.view(1, 1, 1, -1, 1)
# Right edge (bottom in height dimension)
if h_end < height and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
fade_out = torch.linspace(1.0, 0.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
mask[:, :, :, -overlap_out_h:, :] *= fade_out.view(1, 1, 1, -1, 1)
# Top edge (left in width dimension)
if w_pos > 0 and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
fade_in = torch.linspace(0.0, 1.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
mask[:, :, :, :, :overlap_out_w] *= fade_in.view(1, 1, 1, 1, -1)
# Bottom edge (right in width dimension)
if w_end < width and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
fade_out = torch.linspace(1.0, 0.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
mask[:, :, :, :, -overlap_out_w:] *= fade_out.view(1, 1, 1, 1, -1)
# Accumulate weighted results
output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += encoded_tile * mask
weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += mask
# Normalize by weights (avoid division by zero)
output = output / (weights + 1e-8)
return output
def encode(
self,
video: torch.Tensor,
tiled=False,
tile_size_in_pixels: Optional[int] = 512,
tile_overlap_in_pixels: Optional[int] = 128,
**kwargs,
) -> torch.Tensor:
device = next(self.parameters()).device
vae_dtype = next(self.parameters()).dtype
if video.ndim == 4:
video = video.unsqueeze(0) # [C, F, H, W] -> [B, C, F, H, W]
video = video.to(device=device, dtype=vae_dtype)
# Choose encoding method based on tiling flag
if tiled:
latents = self.tiled_encode_video(
video=video,
tile_size=tile_size_in_pixels,
tile_overlap=tile_overlap_in_pixels,
)
else:
# Encode video - VAE expects [B, C, F, H, W], returns [B, C, F', H', W']
latents = self.forward(video)
return latents
def _make_decoder_block(
block_name: str,
block_config: dict[str, Any],
@@ -1850,6 +2159,30 @@ class LTX2VideoDecoder(nn.Module):
return weights
def decode(
self,
latent: torch.Tensor,
tiled=False,
tile_size_in_pixels: Optional[int] = 512,
tile_overlap_in_pixels: Optional[int] = 128,
tile_size_in_frames: Optional[int] = 128,
tile_overlap_in_frames: Optional[int] = 24,
) -> torch.Tensor:
if tiled:
tiling_config = TilingConfig(
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=tile_size_in_pixels,
tile_overlap_in_pixels=tile_overlap_in_pixels,
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=tile_size_in_frames,
tile_overlap_in_frames=tile_overlap_in_frames,
),
)
tiles = self.tiled_decode(latent, tiling_config)
return torch.cat(list(tiles), dim=2)
else:
return self.forward(latent)
def decode_video(
latent: torch.Tensor,
@@ -1875,10 +2208,10 @@ def decode_video(
if tiling_config is not None:
for frames in video_decoder.tiled_decode(latent, tiling_config, generator=generator):
yield convert_to_uint8(frames)
return convert_to_uint8(frames)
else:
decoded_video = video_decoder(latent, generator=generator)
yield convert_to_uint8(decoded_video)
return convert_to_uint8(decoded_video)
def get_video_chunks_number(num_frames: int, tiling_config: TilingConfig | None = None) -> int:

View File

@@ -0,0 +1,451 @@
import torch, types
import numpy as np
from PIL import Image
from einops import repeat
from typing import Optional, Union
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional
from typing_extensions import Literal
from transformers import AutoImageProcessor, Gemma3Processor
import einops
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
from ..models.ltx2_dit import LTXModel
from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier
from ..models.ltx2_upsampler import LTX2LatentUpsampler
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
class LTX2AudioVideoPipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device,
torch_dtype=torch_dtype,
height_division_factor=32,
width_division_factor=32,
time_division_factor=8,
time_division_remainder=1,
)
self.scheduler = FlowMatchScheduler("LTX-2")
self.text_encoder: LTX2TextEncoder = None
self.tokenizer: LTXVGemmaTokenizer = None
self.processor: Gemma3Processor = None
self.text_encoder_post_modules: LTX2TextEncoderPostModules = None
self.dit: LTXModel = None
self.video_vae_encoder: LTX2VideoEncoder = None
self.video_vae_decoder: LTX2VideoDecoder = None
self.audio_vae_encoder: LTX2AudioEncoder = None
self.audio_vae_decoder: LTX2AudioDecoder = None
self.audio_vocoder: LTX2Vocoder = None
self.upsampler: LTX2LatentUpsampler = None
self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)
self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)
self.in_iteration_models = ("dit",)
self.units = [
LTX2AudioVideoUnit_PipelineChecker(),
LTX2AudioVideoUnit_ShapeChecker(),
LTX2AudioVideoUnit_PromptEmbedder(),
LTX2AudioVideoUnit_NoiseInitializer(),
LTX2AudioVideoUnit_InputVideoEmbedder(),
]
self.post_units = [
LTX2AudioVideoPostUnit_UnPatchifier(),
]
self.model_fn = model_fn_ltx2
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
vram_limit: float = None,
):
# Initialize pipeline
pipe = LTX2AudioVideoPipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("ltx2_text_encoder")
tokenizer_config.download_if_necessary()
pipe.tokenizer = LTXVGemmaTokenizer(tokenizer_path=tokenizer_config.path)
image_processor = AutoImageProcessor.from_pretrained(tokenizer_config.path, local_files_only=True)
pipe.processor = Gemma3Processor(image_processor=image_processor, tokenizer=pipe.tokenizer.tokenizer)
pipe.text_encoder_post_modules = model_pool.fetch_model("ltx2_text_encoder_post_modules")
pipe.dit = model_pool.fetch_model("ltx2_dit")
pipe.video_vae_encoder = model_pool.fetch_model("ltx2_video_vae_encoder")
pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder")
pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
# Optional
# pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: Optional[str] = "",
# Image-to-video
input_image: Optional[Image.Image] = None,
denoising_strength: float = 1.0,
# Randomness
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
# Shape
height: Optional[int] = 512,
width: Optional[int] = 768,
num_frames=121,
# Classifier-free guidance
cfg_scale: Optional[float] = 3.0,
cfg_merge: Optional[bool] = False,
# Scheduler
num_inference_steps: Optional[int] = 40,
# VAE tiling
tiled: Optional[bool] = True,
tile_size_in_pixels: Optional[int] = 512,
tile_overlap_in_pixels: Optional[int] = 128,
tile_size_in_frames: Optional[int] = 128,
tile_overlap_in_frames: Optional[int] = 24,
# Two-Stage Pipeline
use_two_stage: Optional[bool] = True,
# progress_bar
progress_bar_cmd=tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
# Inputs
inputs_posi = {
"prompt": prompt,
}
inputs_nega = {
"negative_prompt": negative_prompt,
}
inputs_shared = {
"input_image": input_image,
"seed": seed, "rand_device": rand_device,
"height": height, "width": width, "num_frames": num_frames,
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
"use_two_stage": True
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# inputs_posi.update(torch.load("/mnt/nas1/zhanghong/project26/extern_codes/LTX-2/text_encodings.pt"))
# inputs_nega.update(torch.load("/mnt/nas1/zhanghong/project26/extern_codes/LTX-2/negative_text_encodings.pt"))
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)
# post-denoising, pre-decoding processing logic
for unit in self.post_units:
inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Decode
self.load_models_to_device(['video_vae_decoder'])
video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels,
tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
video = self.vae_output_to_video(video)
self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
return video, decoded_audio
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
if cfg_scale != 1.0:
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
if isinstance(noise_pred_posi, tuple):
noise_pred = tuple(
n_nega + cfg_scale * (n_posi - n_nega)
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
)
else:
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
return noise_pred
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
def __init__(self):
super().__init__(take_over=True)
def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
pass
return inputs_shared, inputs_posi, inputs_nega
class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
"""
# TODO: Adjust with two stage pipeline
For two-stage pipelines, the resolution must be divisible by 64.
For one-stage pipelines, the resolution must be divisible by 32.
"""
def __init__(self):
super().__init__(
input_params=("height", "width", "num_frames"),
output_params=("height", "width", "num_frames"),
)
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames):
height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
return {"height": height, "width": width, "num_frames": num_frames}
class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("video_context", "audio_context"),
onload_model_names=("text_encoder", "text_encoder_post_modules"),
)
def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return (attention_mask - 1).to(dtype).reshape(
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
def _run_connectors(self, pipe, encoded_input: torch.Tensor,
attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype)
encoded, encoded_connector_attention_mask = pipe.text_encoder_post_modules.embeddings_connector(
encoded_input,
connector_attention_mask,
)
# restore the mask values to int64
attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64)
attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
encoded = encoded * attention_mask
encoded_for_audio, _ = pipe.text_encoder_post_modules.audio_embeddings_connector(
encoded_input, connector_attention_mask)
return encoded, encoded_for_audio, attention_mask.squeeze(-1)
def _norm_and_concat_padded_batch(
self,
encoded_text: torch.Tensor,
sequence_lengths: torch.Tensor,
padding_side: str = "right",
) -> torch.Tensor:
"""Normalize and flatten multi-layer hidden states, respecting padding.
Performs per-batch, per-layer normalization using masked mean and range,
then concatenates across the layer dimension.
Args:
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
sequence_lengths: Number of valid (non-padded) tokens per batch item.
padding_side: Whether padding is on "left" or "right".
Returns:
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
with padded positions zeroed out.
"""
b, t, d, l = encoded_text.shape # noqa: E741
device = encoded_text.device
# Build mask: [B, T, 1, 1]
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
if padding_side == "right":
# For right padding, valid tokens are from 0 to sequence_length-1
mask = token_indices < sequence_lengths[:, None] # [B, T]
elif padding_side == "left":
# For left padding, valid tokens are from (T - sequence_length) to T-1
start_indices = t - sequence_lengths[:, None] # [B, 1]
mask = token_indices >= start_indices # [B, T]
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
mask = rearrange(mask, "b t -> b t 1 1")
eps = 1e-6
# Compute masked mean: [B, 1, 1, L]
masked = encoded_text.masked_fill(~mask, 0.0)
denom = (sequence_lengths * d).view(b, 1, 1, 1)
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
# Compute masked min/max: [B, 1, 1, L]
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
range_ = x_max - x_min
# Normalize only the valid tokens
normed = 8 * (encoded_text - mean) / (range_ + eps)
# concat to be [Batch, T, D * L] - this preserves the original structure
normed = normed.reshape(b, t, -1) # [B, T, D * L]
# Apply mask to preserve original padding (set padded positions to 0)
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
normed = normed.masked_fill(~mask_flattened, 0.0)
return normed
def _run_feature_extractor(self,
pipe,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
padding_side: str = "right") -> torch.Tensor:
encoded_text_features = torch.stack(hidden_states, dim=-1)
encoded_text_features_dtype = encoded_text_features.dtype
sequence_lengths = attention_mask.sum(dim=-1)
normed_concated_encoded_text_features = self._norm_and_concat_padded_batch(encoded_text_features,
sequence_lengths,
padding_side=padding_side)
return pipe.text_encoder_post_modules.feature_extractor_linear(
normed_concated_encoded_text_features.to(encoded_text_features_dtype))
def _preprocess_text(self,
pipe,
text: str,
padding_side: str = "left") -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""
Encode a given string into feature tensors suitable for downstream tasks.
Args:
text (str): Input string to encode.
Returns:
tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask.
"""
token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.text_encoder.device)
attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.text_encoder.device)
outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
projected = self._run_feature_extractor(pipe,
hidden_states=outputs.hidden_states,
attention_mask=attention_mask,
padding_side=padding_side)
return projected, attention_mask
def encode_prompt(self, pipe, text, padding_side="left"):
encoded_inputs, attention_mask = self._preprocess_text(pipe, text, padding_side)
video_encoding, audio_encoding, attention_mask = self._run_connectors(pipe, encoded_inputs, attention_mask)
return video_encoding, audio_encoding, attention_mask
def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
pipe.load_models_to_device(self.onload_model_names)
video_context, audio_context, _ = self.encode_prompt(pipe, prompt)
return {"video_context": video_context, "audio_context": audio_context}
class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "num_frames", "seed", "rand_device",),
output_params=("video_noise", "audio_noise",),
)
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
video_noise = pipe.video_patchifier.patchify(video_noise)
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
video_positions = video_positions.to(pipe.torch_dtype)
audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape)
audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
audio_noise = pipe.audio_patchifier.patchify(audio_noise)
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
return {
"video_noise": video_noise,
"audio_noise": audio_noise,
"video_positions": video_positions,
"audio_positions": audio_positions,
"video_latent_shape": video_latent_shape,
"audio_latent_shape": audio_latent_shape
}
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_video", "video_noise", "audio_noise", "tiled", "tile_size", "tile_stride"),
output_params=("video_latents", "audio_latents"),
onload_model_names=("video_vae_encoder")
)
def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, audio_noise, tiled, tile_size, tile_stride):
if input_video is None:
return {"video_latents": video_noise, "audio_latents": audio_noise}
else:
# TODO: implement video-to-video
raise NotImplementedError("Video-to-video not implemented yet.")
class LTX2AudioVideoPostUnit_UnPatchifier(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("video_latent_shape", "audio_latent_shape", "video_latents", "audio_latents"),
output_params=("video_latents", "audio_latents"),
)
def process(self, pipe: LTX2AudioVideoPipeline, video_latent_shape, audio_latent_shape, video_latents, audio_latents):
video_latents = pipe.video_patchifier.unpatchify(video_latents, output_shape=video_latent_shape)
audio_latents = pipe.audio_patchifier.unpatchify(audio_latents, output_shape=audio_latent_shape)
return {"video_latents": video_latents, "audio_latents": audio_latents}
def model_fn_ltx2(
dit: LTXModel,
video_latents=None,
video_context=None,
video_positions=None,
audio_latents=None,
audio_context=None,
audio_positions=None,
timestep=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
#TODO: support gradient checkpointing
timestep = timestep.float() / 1000.
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
vx, ax = dit(
video_latents=video_latents,
video_positions=video_positions,
video_context=video_context,
video_timesteps=video_timesteps,
audio_latents=audio_latents,
audio_positions=audio_positions,
audio_context=audio_context,
audio_timesteps=audio_timesteps,
)
return vx, ax

View File

@@ -0,0 +1,106 @@
from fractions import Fraction
import torch
import av
from tqdm import tqdm
def _resample_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
cc = audio_stream.codec_context
# Use the encoder's format/layout/rate as the *target*
target_format = cc.format or "fltp" # AAC → usually fltp
target_layout = cc.layout or "stereo"
target_rate = cc.sample_rate or frame_in.sample_rate
audio_resampler = av.audio.resampler.AudioResampler(
format=target_format,
layout=target_layout,
rate=target_rate,
)
audio_next_pts = 0
for rframe in audio_resampler.resample(frame_in):
if rframe.pts is None:
rframe.pts = audio_next_pts
audio_next_pts += rframe.samples
rframe.sample_rate = frame_in.sample_rate
container.mux(audio_stream.encode(rframe))
# flush audio encoder
for packet in audio_stream.encode():
container.mux(packet)
def _write_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
if samples.ndim == 1:
samples = samples[:, None]
if samples.shape[1] != 2 and samples.shape[0] == 2:
samples = samples.T
if samples.shape[1] != 2:
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
if samples.dtype != torch.int16:
samples = torch.clip(samples, -1.0, 1.0)
samples = (samples * 32767.0).to(torch.int16)
frame_in = av.AudioFrame.from_ndarray(
samples.contiguous().reshape(1, -1).cpu().numpy(),
format="s16",
layout="stereo",
)
frame_in.sample_rate = audio_sample_rate
_resample_audio(container, audio_stream, frame_in)
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
"""
Prepare the audio stream for writing.
"""
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
audio_stream.codec_context.sample_rate = audio_sample_rate
audio_stream.codec_context.layout = "stereo"
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
return audio_stream
def write_video_audio_ltx2(
video: list[Image.Image],
audio: torch.Tensor | None,
output_path: str,
fps: int = 24,
audio_sample_rate: int | None = 24000,
) -> None:
width, height = video[0].size
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
if audio is not None:
if audio_sample_rate is None:
raise ValueError("audio_sample_rate is required when audio is provided")
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
for frame in tqdm(video, total=len(video)):
frame = av.VideoFrame.from_image(frame)
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():
container.mux(packet)
if audio is not None:
_write_audio(container, audio_stream, audio, audio_sample_rate)
container.close()

View File

@@ -0,0 +1,46 @@
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io import write_video_audio_ltx2
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = LTX2AudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
)
prompt = """
INT. OVEN DAY. Static camera from inside the oven, looking outward through the slightly fogged glass door. Warm golden light glows around freshly baked cookies. The bakers face fills the frame, eyes wide with focus, his breath fogging the glass as he leans in. Subtle reflections move across the glass as steam rises.
Baker (whispering dramatically): “Today… I achieve perfection.”
He leans even closer, nose nearly touching the glass.
"""
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
height, width, num_frames = 512, 768, 121
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=43,
height=height,
width=width,
num_frames=num_frames,
tiled=False,
)
write_video_audio_ltx2(
video=video,
audio=audio,
output_path='ltx2_onestage.mp4',
fps=24,
audio_sample_rate=24000,
)