mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support ltx2 one-stage pipeline
This commit is contained in:
@@ -4,13 +4,14 @@ from typing_extensions import Literal
|
||||
|
||||
class FlowMatchScheduler():
|
||||
|
||||
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"):
|
||||
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2"] = "FLUX.1"):
|
||||
self.set_timesteps_fn = {
|
||||
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
|
||||
"Wan": FlowMatchScheduler.set_timesteps_wan,
|
||||
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
|
||||
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
|
||||
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
||||
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
|
||||
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
||||
self.num_train_timesteps = 1000
|
||||
|
||||
@@ -121,7 +122,30 @@ class FlowMatchScheduler():
|
||||
timestep_id = torch.argmin((timesteps - timestep).abs())
|
||||
timesteps[timestep_id] = timestep
|
||||
return sigmas, timesteps
|
||||
|
||||
|
||||
@staticmethod
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, stretch=True, terminal=0.1):
    """Build the LTX-2 flow-matching sigma schedule and matching timesteps.

    A linear sigma ramp is warped by a sequence-length-dependent shift
    (same helper as the Qwen-Image schedule) and then stretched so the
    final sigma lands on `terminal` instead of 0.

    Args:
        num_inference_steps: Number of denoising steps to generate.
        denoising_strength: Fraction of the schedule to run (1.0 = full noise).
        dynamic_shift_len: Token count driving the shift; falsy values fall back to 4096.
        stretch: Unused here; kept for signature compatibility with sibling schedulers.
        terminal: Sigma value the last step is stretched onto.

    Returns:
        Tuple `(sigmas, timesteps)` of 1-D tensors of length `num_inference_steps`.
    """
    seq_len = dynamic_shift_len or 4096  # falsy (None/0) -> default resolution budget
    shift = FlowMatchScheduler._calculate_shift_qwen_image(
        image_seq_len=seq_len,
        base_seq_len=1024,
        max_seq_len=4096,
        base_shift=0.95,
        max_shift=2.05,
    )
    train_steps = 1000
    # Linear ramp from `denoising_strength` down toward 0; the trailing 0 is dropped
    # so the exponential warp below never divides by zero.
    sigmas = torch.linspace(denoising_strength, 0.0, num_inference_steps + 1)[:-1]
    exp_shift = math.exp(shift)
    sigmas = exp_shift / (exp_shift + (1 / sigmas - 1))
    # Stretch the schedule so the last sigma equals `terminal` exactly.
    remainder = 1.0 - sigmas
    sigmas = 1.0 - remainder / (remainder[-1] / (1 - terminal))
    timesteps = sigmas * train_steps
    return sigmas, timesteps
|
||||
|
||||
def set_training_weight(self):
|
||||
steps = 1000
|
||||
x = self.timesteps
|
||||
|
||||
@@ -337,3 +337,35 @@ class Patchifier(Protocol):
|
||||
Returns:
|
||||
Tensor containing patch coordinate metadata such as spatial or temporal intervals.
|
||||
"""
|
||||
|
||||
|
||||
def get_pixel_coords(
    latent_coords: torch.Tensor,
    scale_factors: SpatioTemporalScaleFactors,
    causal_fix: bool = False,
) -> torch.Tensor:
    """
    Map latent-space `[start, end)` coordinates to pixel space by multiplying each
    axis (frame/time, height, width) with its VAE downsampling factor.

    Args:
        latent_coords: Latent bounds shaped `(batch, 3, num_patches, 2)`.
        scale_factors: `(temporal, height, width)` integer scale factors.
        causal_fix: When True, adjusts the temporal axis of the first frame for
            causal VAEs that encode frame zero with unit temporal stride, clamping
            the resulting timestamps to be non-negative.
    """
    # Shape the factors as (1, 3, 1, 1) so they broadcast over (batch, axis, patch, bound).
    view_shape = [1] * latent_coords.ndim
    view_shape[1] = -1
    factors = torch.tensor(scale_factors, device=latent_coords.device)
    pixel_coords = latent_coords * factors.view(*view_shape)

    if causal_fix:
        # The first frame's effective temporal stride is 1 rather than scale_factors[0];
        # shift and clamp so its timestamps stay causal and non-negative.
        temporal_axis = pixel_coords[:, 0, ...]
        pixel_coords[:, 0, ...] = torch.clamp(temporal_axis + 1 - scale_factors[0], min=0)

    return pixel_coords
|
||||
|
||||
@@ -514,7 +514,7 @@ class Attention(torch.nn.Module):
|
||||
out_pattern="b s n d",
|
||||
attn_mask=mask
|
||||
)
|
||||
|
||||
|
||||
# Reshape back to original format
|
||||
out = out.flatten(2, 3)
|
||||
return self.to_out(out)
|
||||
@@ -1398,7 +1398,7 @@ class LTXModel(torch.nn.Module):
|
||||
x = proj_out(x)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
def _forward(
|
||||
self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -1440,3 +1440,9 @@ class LTXModel(torch.nn.Module):
|
||||
else None
|
||||
)
|
||||
return vx, ax
|
||||
|
||||
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps):
    """Flat-argument entry point: pack the video and audio tensors into
    `Modality` containers and run the joint denoiser via `_forward`.

    Returns the `(video_prediction, audio_prediction)` tensor pair.
    NOTE(review): `perturbations` is always passed as None here despite the
    `BatchedPerturbationConfig` annotation on `_forward` — confirm it accepts None.
    """
    packed_video = Modality(video_latents, video_timesteps, video_positions, video_context)
    packed_audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context)
    return self._forward(video=packed_video, audio=packed_audio, perturbations=None)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import itertools
|
||||
import math
|
||||
import einops
|
||||
from dataclasses import replace, dataclass
|
||||
from typing import Any, Callable, Iterator, List, NamedTuple, Tuple, Union, Optional
|
||||
import torch
|
||||
@@ -7,9 +8,138 @@ from einops import rearrange
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from enum import Enum
|
||||
from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape
|
||||
from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape, Patchifier, AudioLatentShape
|
||||
from .ltx2_dit import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
|
||||
VAE_SPATIAL_FACTOR = 32
|
||||
VAE_TEMPORAL_FACTOR = 8
|
||||
|
||||
|
||||
class VideoLatentPatchifier(Patchifier):
    """Spatial patchifier for video latents.

    Groups `patch_size x patch_size` spatial blocks (temporal size fixed to 1)
    into tokens, and provides the inverse operation plus per-patch coordinate
    bounds in the latent grid.
    """

    def __init__(self, patch_size: int):
        # Patch sizes for video latents.
        self._patch_size = (
            1,  # temporal dimension
            patch_size,  # height dimension
            patch_size,  # width dimension
        )

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        # (temporal, height, width) patch extents.
        return self._patch_size

    def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
        # Tokens = (F * H * W) / patch volume; indices [2:] skip batch and channel dims.
        return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size)

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        # [B, C, F*p1, H*p2, W*p3] -> [B, num_patches, C*p1*p2*p3]:
        # each token concatenates the channels of one spatio-temporal patch.
        latents = einops.rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )

        return latents

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: VideoLatentShape,
    ) -> torch.Tensor:
        # Inverse of `patchify`; the rearrange pattern below assumes the
        # temporal patch extent is 1 (no `p1` factor on the frame axis).
        assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier"

        patch_grid_frames = output_shape.frames // self._patch_size[0]
        patch_grid_height = output_shape.height // self._patch_size[1]
        patch_grid_width = output_shape.width // self._patch_size[2]

        latents = einops.rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=patch_grid_frames,
            h=patch_grid_height,
            w=patch_grid_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )

        return latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the per-dimension bounds [inclusive start, exclusive end) for every
        patch produced by `patchify`. The bounds are expressed in the original
        video grid coordinates: frame/time, height, and width.
        The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where:
        - axis 1 (size 3) enumerates (frame/time, height, width) dimensions
        - axis 3 (size 2) stores `[start, end)` indices within each dimension
        Args:
            output_shape: Video grid description containing frames, height, and width.
            device: Device of the latent tensor.
        Raises:
            ValueError: If `output_shape` is not a `VideoLatentShape`.
        """
        if not isinstance(output_shape, VideoLatentShape):
            raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")

        frames = output_shape.frames
        height = output_shape.height
        width = output_shape.width
        batch_size = output_shape.batch

        # Validate inputs to ensure positive dimensions
        assert frames > 0, f"frames must be positive, got {frames}"
        assert height > 0, f"height must be positive, got {height}"
        assert width > 0, f"width must be positive, got {width}"
        assert batch_size > 0, f"batch_size must be positive, got {batch_size}"

        # Generate grid coordinates for each dimension (frame, height, width)
        # We use torch.arange to create the starting coordinates for each patch.
        # indexing='ij' ensures the dimensions are in the order (frame, height, width).
        grid_coords = torch.meshgrid(
            torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),
            torch.arange(start=0, end=height, step=self._patch_size[1], device=device),
            torch.arange(start=0, end=width, step=self._patch_size[2], device=device),
            indexing="ij",
        )

        # Stack the grid coordinates to create the start coordinates tensor.
        # Shape becomes (3, grid_f, grid_h, grid_w)
        patch_starts = torch.stack(grid_coords, dim=0)

        # Create a tensor containing the size of a single patch:
        # (frame_patch_size, height_patch_size, width_patch_size).
        # Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates.
        patch_size_delta = torch.tensor(
            self._patch_size,
            device=patch_starts.device,
            dtype=patch_starts.dtype,
        ).view(3, 1, 1, 1)

        # Calculate end coordinates: start + patch_size
        # Shape becomes (3, grid_f, grid_h, grid_w)
        patch_ends = patch_starts + patch_size_delta

        # Stack start and end coordinates together along the last dimension
        # Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end]
        latent_coords = torch.stack((patch_starts, patch_ends), dim=-1)

        # Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence.
        # Final Shape: (batch_size, 3, num_patches, 2)
        latent_coords = einops.repeat(
            latent_coords,
            "c f h w bounds -> b c (f h w) bounds",
            b=batch_size,
            bounds=2,
        )

        return latent_coords
|
||||
|
||||
|
||||
class NormLayerType(Enum):
|
||||
GROUP_NORM = "group_norm"
|
||||
@@ -1339,6 +1469,185 @@ class LTX2VideoEncoder(nn.Module):
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
|
||||
def tiled_encode_video(
    self,
    video: torch.Tensor,
    tile_size: int = 512,
    tile_overlap: int = 128,
) -> torch.Tensor:
    """Encode video using spatial tiling for memory efficiency.

    Splits the video into overlapping spatial tiles, encodes each tile separately
    with `self.forward`, and blends the results with linear feathering in the
    overlap regions (weighted-average accumulation).

    Args:
        video: Input tensor of shape [B, C, F, H, W]
        tile_size: Tile size in pixels (must be divisible by 32)
        tile_overlap: Overlap between tiles in pixels (must be divisible by 32)

    Returns:
        Encoded latent tensor [B, C_latent, F_latent, H_latent, W_latent]

    Raises:
        ValueError: If tile_size/tile_overlap are not multiples of the VAE
            spatial factor, or if the overlap is not smaller than the tile.
    """
    batch, _channels, frames, height, width = video.shape
    device = video.device
    dtype = video.dtype

    # Validate tile parameters
    if tile_size % VAE_SPATIAL_FACTOR != 0:
        raise ValueError(f"tile_size must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_size}")
    if tile_overlap % VAE_SPATIAL_FACTOR != 0:
        raise ValueError(f"tile_overlap must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_overlap}")
    if tile_overlap >= tile_size:
        raise ValueError(f"tile_overlap ({tile_overlap}) must be less than tile_size ({tile_size})")

    # If video fits in a single tile, use regular encoding
    if height <= tile_size and width <= tile_size:
        return self.forward(video)

    # Calculate output dimensions
    # VAE compresses: H -> H/32, W -> W/32, F -> 1 + (F-1)/8
    output_height = height // VAE_SPATIAL_FACTOR
    output_width = width // VAE_SPATIAL_FACTOR
    output_frames = 1 + (frames - 1) // VAE_TEMPORAL_FACTOR

    # Latent channel count is hard-coded to 128 here (assumed LTX-2 VAE width)
    # — TODO confirm against the encoder config rather than assuming.
    latent_channels = 128

    # Accumulators: weighted latent sum and total blend weight per output voxel.
    output = torch.zeros(
        (batch, latent_channels, output_frames, output_height, output_width),
        device=device,
        dtype=dtype,
    )
    weights = torch.zeros(
        (batch, 1, output_frames, output_height, output_width),
        device=device,
        dtype=dtype,
    )

    # Calculate tile positions with overlap; stride is tile_size - tile_overlap.
    step_h = tile_size - tile_overlap
    step_w = tile_size - tile_overlap

    h_positions = list(range(0, max(1, height - tile_overlap), step_h))
    w_positions = list(range(0, max(1, width - tile_overlap), step_w))

    # Ensure the last tile reaches the bottom/right edge of the frame.
    if h_positions[-1] + tile_size < height:
        h_positions.append(height - tile_size)
    if w_positions[-1] + tile_size < width:
        w_positions.append(width - tile_size)

    # Remove duplicates and sort
    h_positions = sorted(set(h_positions))
    w_positions = sorted(set(w_positions))

    # Overlap expressed in latent-space rows/columns.
    overlap_out_h = tile_overlap // VAE_SPATIAL_FACTOR
    overlap_out_w = tile_overlap // VAE_SPATIAL_FACTOR

    # Process each tile
    for h_pos in h_positions:
        for w_pos in w_positions:
            # Tile boundaries in input (pixel) space.
            h_start = max(0, h_pos)
            w_start = max(0, w_pos)
            h_end = min(h_start + tile_size, height)
            w_end = min(w_start + tile_size, width)

            # Round the tile extent down to a multiple of the VAE spatial factor.
            tile_h = ((h_end - h_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR
            tile_w = ((w_end - w_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR

            # Skip degenerate slivers that cannot be encoded.
            if tile_h < VAE_SPATIAL_FACTOR or tile_w < VAE_SPATIAL_FACTOR:
                continue

            # Adjust end positions to the rounded extents.
            h_end = h_start + tile_h
            w_end = w_start + tile_w

            # Extract and encode the tile (full temporal extent).
            tile = video[:, :, :, h_start:h_end, w_start:w_end]
            encoded_tile = self.forward(tile)

            # Actual encoded dimensions (may differ from the nominal ratio).
            _, _, tile_out_frames, tile_out_height, tile_out_width = encoded_tile.shape

            # Where this tile lands in the latent-space output grid.
            out_h_start = h_start // VAE_SPATIAL_FACTOR
            out_w_start = w_start // VAE_SPATIAL_FACTOR
            out_h_end = min(out_h_start + tile_out_height, output_height)
            out_w_end = min(out_w_start + tile_out_width, output_width)

            # Trim the encoded tile if it would spill past the output bounds.
            actual_tile_h = out_h_end - out_h_start
            actual_tile_w = out_w_end - out_w_start
            encoded_tile = encoded_tile[:, :, :, :actual_tile_h, :actual_tile_w]

            # Blending mask, feathered linearly toward any interior tile edge.
            mask = torch.ones(
                (1, 1, tile_out_frames, actual_tile_h, actual_tile_w),
                device=device,
                dtype=dtype,
            )

            # Top edge (start of the height axis): fade in when a tile above exists.
            if h_pos > 0 and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
                fade_in = torch.linspace(0.0, 1.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
                mask[:, :, :, :overlap_out_h, :] *= fade_in.view(1, 1, 1, -1, 1)

            # Bottom edge (end of the height axis): fade out when a tile below exists.
            if h_end < height and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
                fade_out = torch.linspace(1.0, 0.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
                mask[:, :, :, -overlap_out_h:, :] *= fade_out.view(1, 1, 1, -1, 1)

            # Left edge (start of the width axis): fade in when a tile to the left exists.
            if w_pos > 0 and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
                fade_in = torch.linspace(0.0, 1.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
                mask[:, :, :, :, :overlap_out_w] *= fade_in.view(1, 1, 1, 1, -1)

            # Right edge (end of the width axis): fade out when a tile to the right exists.
            if w_end < width and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
                fade_out = torch.linspace(1.0, 0.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
                mask[:, :, :, :, -overlap_out_w:] *= fade_out.view(1, 1, 1, 1, -1)

            # Accumulate weighted results
            output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += encoded_tile * mask
            weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += mask

    # Normalize by accumulated weights (epsilon avoids division by zero in
    # voxels no tile touched).
    output = output / (weights + 1e-8)

    return output
|
||||
|
||||
def encode(
    self,
    video: torch.Tensor,
    tiled=False,
    tile_size_in_pixels: Optional[int] = 512,
    tile_overlap_in_pixels: Optional[int] = 128,
    **kwargs,
) -> torch.Tensor:
    """Encode a video into VAE latents.

    Moves the input onto the encoder's device/dtype, promotes an unbatched
    [C, F, H, W] tensor to a singleton batch, and dispatches to either the
    plain forward pass or the spatially tiled encoder.

    Args:
        video: [B, C, F, H, W] or [C, F, H, W] pixel tensor.
        tiled: Use spatial tiling to bound peak memory.
        tile_size_in_pixels: Spatial tile size (tiled mode only).
        tile_overlap_in_pixels: Spatial tile overlap (tiled mode only).
        **kwargs: Ignored; accepted for interface compatibility.

    Returns:
        Latent tensor [B, C_latent, F', H', W'].
    """
    # Infer target placement from any encoder parameter.
    reference = next(self.parameters())
    if video.ndim == 4:
        video = video.unsqueeze(0)  # [C, F, H, W] -> [1, C, F, H, W]
    video = video.to(device=reference.device, dtype=reference.dtype)

    if not tiled:
        # VAE expects [B, C, F, H, W] and returns [B, C, F', H', W'].
        return self.forward(video)
    return self.tiled_encode_video(
        video=video,
        tile_size=tile_size_in_pixels,
        tile_overlap=tile_overlap_in_pixels,
    )
|
||||
|
||||
|
||||
def _make_decoder_block(
|
||||
block_name: str,
|
||||
block_config: dict[str, Any],
|
||||
@@ -1850,6 +2159,30 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
return weights
|
||||
|
||||
def decode(
    self,
    latent: torch.Tensor,
    tiled=False,
    tile_size_in_pixels: Optional[int] = 512,
    tile_overlap_in_pixels: Optional[int] = 128,
    tile_size_in_frames: Optional[int] = 128,
    tile_overlap_in_frames: Optional[int] = 24,
) -> torch.Tensor:
    """Decode latents back to pixel space, optionally with spatio-temporal tiling.

    Args:
        latent: Latent tensor to decode.
        tiled: When True, decode in spatial/temporal tiles to bound memory.
        tile_size_in_pixels / tile_overlap_in_pixels: Spatial tiling parameters.
        tile_size_in_frames / tile_overlap_in_frames: Temporal tiling parameters.

    Returns:
        Decoded video tensor.
    """
    if not tiled:
        return self.forward(latent)

    spatial = SpatialTilingConfig(
        tile_size_in_pixels=tile_size_in_pixels,
        tile_overlap_in_pixels=tile_overlap_in_pixels,
    )
    temporal = TemporalTilingConfig(
        tile_size_in_frames=tile_size_in_frames,
        tile_overlap_in_frames=tile_overlap_in_frames,
    )
    config = TilingConfig(spatial_config=spatial, temporal_config=temporal)
    # tiled_decode yields temporal chunks; stitch them along the frame axis.
    chunks = self.tiled_decode(latent, config)
    return torch.cat(list(chunks), dim=2)
|
||||
|
||||
def decode_video(
|
||||
latent: torch.Tensor,
|
||||
@@ -1875,10 +2208,10 @@ def decode_video(
|
||||
|
||||
if tiling_config is not None:
|
||||
for frames in video_decoder.tiled_decode(latent, tiling_config, generator=generator):
|
||||
yield convert_to_uint8(frames)
|
||||
return convert_to_uint8(frames)
|
||||
else:
|
||||
decoded_video = video_decoder(latent, generator=generator)
|
||||
yield convert_to_uint8(decoded_video)
|
||||
return convert_to_uint8(decoded_video)
|
||||
|
||||
|
||||
def get_video_chunks_number(num_frames: int, tiling_config: TilingConfig | None = None) -> int:
|
||||
|
||||
451
diffsynth/pipelines/ltx2_audio_video.py
Normal file
451
diffsynth/pipelines/ltx2_audio_video.py
Normal file
@@ -0,0 +1,451 @@
|
||||
import torch, types
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from einops import repeat
|
||||
from typing import Optional, Union
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
from typing_extensions import Literal
|
||||
from transformers import AutoImageProcessor, Gemma3Processor
|
||||
import einops
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
|
||||
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
|
||||
from ..models.ltx2_dit import LTXModel
|
||||
from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier
|
||||
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier
|
||||
from ..models.ltx2_upsampler import LTX2LatentUpsampler
|
||||
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
|
||||
|
||||
|
||||
class LTX2AudioVideoPipeline(BasePipeline):
    """Joint audio+video generation pipeline for LTX-2.

    Orchestrates the text encoder, DiT denoiser, and the video/audio VAEs:
    input preparation runs through `self.units`, denoising iterates the
    flow-matching scheduler with classifier-free guidance, post-processing
    runs through `self.post_units`, and finally the latents are decoded.
    """

    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        """Set up scheduler, empty model slots, patchifiers, and processing units.

        Note: models are populated later via `from_pretrained`.
        """
        super().__init__(
            device=device,
            torch_dtype=torch_dtype,
            height_division_factor=32,
            width_division_factor=32,
            time_division_factor=8,
            time_division_remainder=1,
        )
        self.scheduler = FlowMatchScheduler("LTX-2")
        # Model slots; filled by `from_pretrained`.
        self.text_encoder: LTX2TextEncoder = None
        self.tokenizer: LTXVGemmaTokenizer = None
        self.processor: Gemma3Processor = None
        self.text_encoder_post_modules: LTX2TextEncoderPostModules = None
        self.dit: LTXModel = None
        self.video_vae_encoder: LTX2VideoEncoder = None
        self.video_vae_decoder: LTX2VideoDecoder = None
        self.audio_vae_encoder: LTX2AudioEncoder = None
        self.audio_vae_decoder: LTX2AudioDecoder = None
        self.audio_vocoder: LTX2Vocoder = None
        self.upsampler: LTX2LatentUpsampler = None

        self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)
        self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)

        # Only the DiT stays on-device during the denoising loop.
        self.in_iteration_models = ("dit",)
        self.units = [
            LTX2AudioVideoUnit_PipelineChecker(),
            LTX2AudioVideoUnit_ShapeChecker(),
            LTX2AudioVideoUnit_PromptEmbedder(),
            LTX2AudioVideoUnit_NoiseInitializer(),
            LTX2AudioVideoUnit_InputVideoEmbedder(),
        ]
        self.post_units = [
            LTX2AudioVideoPostUnit_UnPatchifier(),
        ]
        self.model_fn = model_fn_ltx2

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: Optional[list[ModelConfig]] = None,
        tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
        vram_limit: float = None,
    ):
        """Build a pipeline and load all required models.

        Args:
            torch_dtype: Computation dtype for the loaded models.
            device: Target device.
            model_configs: Model sources to download/load (defaults to empty).
            tokenizer_config: Where to fetch the Gemma tokenizer/processor from.
            vram_limit: Optional VRAM budget enabling offload management.

        Returns:
            A fully wired `LTX2AudioVideoPipeline`.
        """
        # Avoid the shared-mutable-default pitfall: never use `[]` as a default.
        model_configs = [] if model_configs is None else model_configs

        # Initialize pipeline
        pipe = LTX2AudioVideoPipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("ltx2_text_encoder")
        tokenizer_config.download_if_necessary()
        pipe.tokenizer = LTXVGemmaTokenizer(tokenizer_path=tokenizer_config.path)
        image_processor = AutoImageProcessor.from_pretrained(tokenizer_config.path, local_files_only=True)
        pipe.processor = Gemma3Processor(image_processor=image_processor, tokenizer=pipe.tokenizer.tokenizer)

        pipe.text_encoder_post_modules = model_pool.fetch_model("ltx2_text_encoder_post_modules")
        pipe.dit = model_pool.fetch_model("ltx2_dit")
        pipe.video_vae_encoder = model_pool.fetch_model("ltx2_video_vae_encoder")
        pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder")
        pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
        pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
        pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
        # Optional (only needed for audio-conditioned workflows)
        # pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")

        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: Optional[str] = "",
        # Image-to-video
        input_image: Optional[Image.Image] = None,
        denoising_strength: float = 1.0,
        # Randomness
        seed: Optional[int] = None,
        rand_device: Optional[str] = "cpu",
        # Shape
        height: Optional[int] = 512,
        width: Optional[int] = 768,
        num_frames=121,
        # Classifier-free guidance
        cfg_scale: Optional[float] = 3.0,
        cfg_merge: Optional[bool] = False,
        # Scheduler
        num_inference_steps: Optional[int] = 40,
        # VAE tiling
        tiled: Optional[bool] = True,
        tile_size_in_pixels: Optional[int] = 512,
        tile_overlap_in_pixels: Optional[int] = 128,
        tile_size_in_frames: Optional[int] = 128,
        tile_overlap_in_frames: Optional[int] = 24,
        # Two-Stage Pipeline
        use_two_stage: Optional[bool] = True,
        # progress_bar
        progress_bar_cmd=tqdm,
    ):
        """Generate a video and its soundtrack from a text prompt.

        Returns:
            Tuple `(video, audio)`: the decoded video (via `vae_output_to_video`)
            and the vocoder waveform as a float tensor.
        """
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)

        # Inputs
        inputs_posi = {
            "prompt": prompt,
        }
        inputs_nega = {
            "negative_prompt": negative_prompt,
        }
        inputs_shared = {
            "input_image": input_image,
            "seed": seed, "rand_device": rand_device,
            "height": height, "width": width, "num_frames": num_frames,
            "cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
            "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
            "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
            # Fix: honor the caller's `use_two_stage` argument; it was previously
            # hard-coded to True, which made the one-stage mode unreachable.
            "use_two_stage": use_two_stage,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            # Advance the video and audio latents with their respective predictions.
            inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared)
            inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)

        # post-denoising, pre-decoding processing logic
        for unit in self.post_units:
            inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Decode
        self.load_models_to_device(['video_vae_decoder'])
        video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels,
                                              tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
        video = self.vae_output_to_video(video)
        self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
        decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
        decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
        return video, decoded_audio

    def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
        """Evaluate `model_fn` with classifier-free guidance.

        Runs the positive branch (optionally with a positive-only LoRA loaded),
        and when `cfg_scale != 1.0` also the negative branch, combining them as
        `nega + cfg_scale * (posi - nega)` — element-wise per output when the
        model returns a tuple (video, audio).
        """
        if inputs_shared.get("positive_only_lora", None) is not None:
            self.clear_lora(verbose=0)
            self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
        noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
        if cfg_scale != 1.0:
            # The positive-only LoRA must not influence the negative branch.
            if inputs_shared.get("positive_only_lora", None) is not None:
                self.clear_lora(verbose=0)
            noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
            if isinstance(noise_pred_posi, tuple):
                noise_pred = tuple(
                    n_nega + cfg_scale * (n_posi - n_nega)
                    for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
                )
            else:
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
        else:
            noise_pred = noise_pred_posi
        return noise_pred
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
    """Placeholder validation unit: currently passes all inputs through unchanged.

    Registered first in the unit chain so pipeline-level sanity checks can be
    added here later without reordering the other units.
    """

    def __init__(self):
        super().__init__(take_over=True)

    def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
        # No checks implemented yet (removed a redundant `pass` before the return).
        return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
    """Snap the requested output shape to values the pipeline can handle.

    # TODO: Adjust with two stage pipeline
    For two-stage pipelines, the resolution must be divisible by 64.
    For one-stage pipelines, the resolution must be divisible by 32.
    """

    def __init__(self):
        shape_keys = ("height", "width", "num_frames")
        super().__init__(
            input_params=shape_keys,
            output_params=shape_keys,
        )

    def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames):
        # Delegate rounding/validation to the base pipeline helper and repack
        # the (height, width, num_frames) triple into the expected dict.
        checked = pipe.check_resize_height_width(height, width, num_frames)
        return dict(zip(("height", "width", "num_frames"), checked))
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
|
||||
|
||||
def __init__(self):
    # Runs separately for the positive and negative CFG branches; reads the
    # (negative) prompt, writes per-modality text contexts, and temporarily
    # onloads the text encoder plus its connector modules.
    super().__init__(
        seperate_cfg=True,
        input_params_posi={"prompt": "prompt"},
        input_params_nega={"prompt": "negative_prompt"},
        output_params=("video_context", "audio_context"),
        onload_model_names=("text_encoder", "text_encoder_post_modules"),
    )
|
||||
|
||||
def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
return (attention_mask - 1).to(dtype).reshape(
|
||||
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
|
||||
|
||||
def _run_connectors(self, pipe, encoded_input: torch.Tensor,
                    attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Project raw text-encoder states through the video and audio connector
    # modules, returning (video_context, audio_context, token_mask).
    connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype)

    encoded, encoded_connector_attention_mask = pipe.text_encoder_post_modules.embeddings_connector(
        encoded_input,
        connector_attention_mask,
    )

    # restore the mask values to int64
    # NOTE(review): `< 0.000001` is satisfied by both 0 (kept) and the large
    # negative fill value (masked), so this marks every position as 1 unless
    # the connector rewrites the mask to positive values for masked tokens —
    # confirm against the connector's output convention.
    attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64)
    attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
    # Zero out embeddings at masked positions.
    encoded = encoded * attention_mask

    # Audio context uses the same raw states through its own connector.
    encoded_for_audio, _ = pipe.text_encoder_post_modules.audio_embeddings_connector(
        encoded_input, connector_attention_mask)

    return encoded, encoded_for_audio, attention_mask.squeeze(-1)
|
||||
|
||||
def _norm_and_concat_padded_batch(
|
||||
self,
|
||||
encoded_text: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
padding_side: str = "right",
|
||||
) -> torch.Tensor:
|
||||
"""Normalize and flatten multi-layer hidden states, respecting padding.
|
||||
Performs per-batch, per-layer normalization using masked mean and range,
|
||||
then concatenates across the layer dimension.
|
||||
Args:
|
||||
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
|
||||
sequence_lengths: Number of valid (non-padded) tokens per batch item.
|
||||
padding_side: Whether padding is on "left" or "right".
|
||||
Returns:
|
||||
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
|
||||
with padded positions zeroed out.
|
||||
"""
|
||||
b, t, d, l = encoded_text.shape # noqa: E741
|
||||
device = encoded_text.device
|
||||
# Build mask: [B, T, 1, 1]
|
||||
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [B, T]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = t - sequence_lengths[:, None] # [B, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = rearrange(mask, "b t -> b t 1 1")
|
||||
eps = 1e-6
|
||||
# Compute masked mean: [B, 1, 1, L]
|
||||
masked = encoded_text.masked_fill(~mask, 0.0)
|
||||
denom = (sequence_lengths * d).view(b, 1, 1, 1)
|
||||
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
|
||||
# Compute masked min/max: [B, 1, 1, L]
|
||||
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
range_ = x_max - x_min
|
||||
# Normalize only the valid tokens
|
||||
normed = 8 * (encoded_text - mean) / (range_ + eps)
|
||||
# concat to be [Batch, T, D * L] - this preserves the original structure
|
||||
normed = normed.reshape(b, t, -1) # [B, T, D * L]
|
||||
# Apply mask to preserve original padding (set padded positions to 0)
|
||||
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
|
||||
normed = normed.masked_fill(~mask_flattened, 0.0)
|
||||
|
||||
return normed
|
||||
|
||||
def _run_feature_extractor(self,
|
||||
pipe,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_side: str = "right") -> torch.Tensor:
|
||||
encoded_text_features = torch.stack(hidden_states, dim=-1)
|
||||
encoded_text_features_dtype = encoded_text_features.dtype
|
||||
sequence_lengths = attention_mask.sum(dim=-1)
|
||||
normed_concated_encoded_text_features = self._norm_and_concat_padded_batch(encoded_text_features,
|
||||
sequence_lengths,
|
||||
padding_side=padding_side)
|
||||
|
||||
return pipe.text_encoder_post_modules.feature_extractor_linear(
|
||||
normed_concated_encoded_text_features.to(encoded_text_features_dtype))
|
||||
|
||||
def _preprocess_text(self,
|
||||
pipe,
|
||||
text: str,
|
||||
padding_side: str = "left") -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Encode a given string into feature tensors suitable for downstream tasks.
|
||||
Args:
|
||||
text (str): Input string to encode.
|
||||
Returns:
|
||||
tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask.
|
||||
"""
|
||||
token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
|
||||
input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.text_encoder.device)
|
||||
attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.text_encoder.device)
|
||||
outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
||||
projected = self._run_feature_extractor(pipe,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
padding_side=padding_side)
|
||||
return projected, attention_mask
|
||||
|
||||
def encode_prompt(self, pipe, text, padding_side="left"):
|
||||
encoded_inputs, attention_mask = self._preprocess_text(pipe, text, padding_side)
|
||||
video_encoding, audio_encoding, attention_mask = self._run_connectors(pipe, encoded_inputs, attention_mask)
|
||||
return video_encoding, audio_encoding, attention_mask
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
video_context, audio_context, _ = self.encode_prompt(pipe, prompt)
|
||||
return {"video_context": video_context, "audio_context": audio_context}
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
    """Initialize patchified noise, position grids, and latent shapes for both modalities.

    Fix: ``output_params`` previously declared only ``video_noise`` and
    ``audio_noise`` even though ``process`` also returns the position grids and
    latent shapes that downstream units consume (the un-patchifier needs the
    latent shapes; the DiT model function needs the positions). All returned
    keys are now declared so the framework propagates them.
    """

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames", "seed", "rand_device",),
            output_params=(
                "video_noise", "audio_noise",
                "video_positions", "audio_positions",
                "video_latent_shape", "audio_latent_shape",
            ),
        )

    def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
        """Sample and patchify latent noise plus positional grids.

        Args:
            frame_rate: Video frame rate used to convert temporal patch indices
                to seconds; defaults to 24 fps (not wired through ``input_params``).
        """
        video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
        video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
        video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
        video_noise = pipe.video_patchifier.patchify(video_noise)

        # Per-patch pixel-space coordinates; the temporal axis (index 0) is
        # converted from frame units to seconds.
        latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
        video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
        video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
        video_positions = video_positions.to(pipe.torch_dtype)

        # Audio latents are sized from the same pixel-space video shape so the
        # two modalities stay temporally aligned.
        audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape)
        audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
        audio_noise = pipe.audio_patchifier.patchify(audio_noise)
        audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
        return {
            "video_noise": video_noise,
            "audio_noise": audio_noise,
            "video_positions": video_positions,
            "audio_positions": audio_positions,
            "video_latent_shape": video_latent_shape,
            "audio_latent_shape": audio_latent_shape
        }
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
    """Provide the starting latents: pure noise for text-to-video.

    Fix: ``onload_model_names`` was ``("video_vae_encoder")`` — parenthesized
    string, not a tuple — so any code iterating the model names would iterate
    its characters. The trailing comma makes it a 1-tuple.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_video", "video_noise", "audio_noise", "tiled", "tile_size", "tile_stride"),
            output_params=("video_latents", "audio_latents"),
            onload_model_names=("video_vae_encoder",),
        )

    def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, audio_noise, tiled, tile_size, tile_stride):
        """Return the initial latents; video-to-video is not implemented yet.

        Raises:
            NotImplementedError: If ``input_video`` is provided.
        """
        if input_video is None:
            return {"video_latents": video_noise, "audio_latents": audio_noise}
        else:
            # TODO: implement video-to-video
            raise NotImplementedError("Video-to-video not implemented yet.")
|
||||
|
||||
|
||||
class LTX2AudioVideoPostUnit_UnPatchifier(PipelineUnit):
    """Post-processing unit that restores patchified latents to their latent shapes."""

    def __init__(self):
        super().__init__(
            input_params=("video_latent_shape", "audio_latent_shape", "video_latents", "audio_latents"),
            output_params=("video_latents", "audio_latents"),
        )

    def process(self, pipe: LTX2AudioVideoPipeline, video_latent_shape, audio_latent_shape, video_latents, audio_latents):
        # Undo the patchify step for each modality using its recorded latent shape.
        unpatched_video = pipe.video_patchifier.unpatchify(video_latents, output_shape=video_latent_shape)
        unpatched_audio = pipe.audio_patchifier.unpatchify(audio_latents, output_shape=audio_latent_shape)
        return {"video_latents": unpatched_video, "audio_latents": unpatched_audio}
|
||||
|
||||
|
||||
def model_fn_ltx2(
    dit: LTXModel,
    video_latents=None,
    video_context=None,
    video_positions=None,
    audio_latents=None,
    audio_context=None,
    audio_positions=None,
    timestep=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    """Run one LTX-2 DiT forward pass over the joint video/audio latents.

    The scalar timestep (in [0, 1000]) is normalized to [0, 1] and broadcast
    along each modality's token dimension before being handed to the model.

    Returns:
        Tuple of (video prediction, audio prediction) from the DiT.
    """
    # TODO: support gradient checkpointing
    t = timestep.float() / 1000.
    return dit(
        video_latents=video_latents,
        video_positions=video_positions,
        video_context=video_context,
        video_timesteps=t.repeat(1, video_latents.shape[1], 1),
        audio_latents=audio_latents,
        audio_positions=audio_positions,
        audio_context=audio_context,
        audio_timesteps=t.repeat(1, audio_latents.shape[1], 1),
    )
|
||||
106
diffsynth/utils/data/media_io.py
Normal file
106
diffsynth/utils/data/media_io.py
Normal file
@@ -0,0 +1,106 @@
|
||||
|
||||
from __future__ import annotations

from fractions import Fraction

import av
import torch
from tqdm import tqdm
|
||||
|
||||
|
||||
def _resample_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
    """Resample *frame_in* to the encoder's format/layout/rate, encode, and mux it.

    Also flushes the audio encoder at the end, so this is intended to be called
    once with the complete audio frame rather than per-chunk.
    """
    cc = audio_stream.codec_context

    # Use the encoder's format/layout/rate as the *target*
    target_format = cc.format or "fltp"  # AAC → usually fltp
    target_layout = cc.layout or "stereo"
    target_rate = cc.sample_rate or frame_in.sample_rate

    audio_resampler = av.audio.resampler.AudioResampler(
        format=target_format,
        layout=target_layout,
        rate=target_rate,
    )

    # Assign monotonically increasing PTS to resampled frames that lack one.
    audio_next_pts = 0
    for rframe in audio_resampler.resample(frame_in):
        if rframe.pts is None:
            rframe.pts = audio_next_pts
            audio_next_pts += rframe.samples
        # NOTE(review): this overwrites the resampled frame's rate with the input
        # rate — harmless when target_rate == frame_in.sample_rate, but it would
        # mislabel frames if they ever differ; confirm intended.
        rframe.sample_rate = frame_in.sample_rate
        container.mux(audio_stream.encode(rframe))

    # flush audio encoder
    for packet in audio_stream.encode():
        container.mux(packet)
|
||||
|
||||
|
||||
def _write_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
    """Encode a waveform tensor and mux it into *container* on *audio_stream*.

    Args:
        container: Open output container to mux packets into.
        audio_stream: Audio stream previously added to the container.
        samples: Waveform of shape [num_samples, 2], [2, num_samples], or
            [num_samples] (mono); mono is duplicated to stereo. Float input is
            assumed to be in [-1, 1] and is converted to int16.
        audio_sample_rate: Sample rate of *samples* in Hz.

    Raises:
        ValueError: If the input cannot be coerced to 2 channels.
    """
    if samples.ndim == 1:
        samples = samples[:, None]

    # Accept channel-first layout [2, N] by transposing to [N, 2].
    if samples.shape[1] != 2 and samples.shape[0] == 2:
        samples = samples.T

    # Fix: mono input previously fell through to the ValueError below because
    # the [N] -> [N, 1] expansion never produced 2 channels; duplicate the
    # single channel to produce stereo instead.
    if samples.shape[1] == 1:
        samples = samples.repeat(1, 2)

    if samples.shape[1] != 2:
        raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")

    # Convert to int16 packed for ingestion; resampler converts to encoder fmt.
    if samples.dtype != torch.int16:
        samples = torch.clip(samples, -1.0, 1.0)
        samples = (samples * 32767.0).to(torch.int16)

    # [N, 2] row-major flattens to interleaved L R L R ..., which is what the
    # packed "s16" stereo frame expects.
    frame_in = av.AudioFrame.from_ndarray(
        samples.contiguous().reshape(1, -1).cpu().numpy(),
        format="s16",
        layout="stereo",
    )
    frame_in.sample_rate = audio_sample_rate

    _resample_audio(container, audio_stream, frame_in)
|
||||
|
||||
|
||||
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
    """Add a stereo AAC audio stream to *container* and configure its codec context."""
    stream = container.add_stream("aac", rate=audio_sample_rate)
    cc = stream.codec_context
    cc.sample_rate = audio_sample_rate
    cc.layout = "stereo"
    cc.time_base = Fraction(1, audio_sample_rate)
    return stream
|
||||
|
||||
def write_video_audio_ltx2(
    video: list[Image.Image],
    audio: torch.Tensor | None,
    output_path: str,
    fps: int = 24,
    audio_sample_rate: int | None = 24000,
) -> None:
    """Write frames (and an optional audio track) to *output_path* as H.264/AAC.

    Fix: the container is now closed in a ``finally`` block so a failure while
    encoding no longer leaks the open (and unfinalized) output file.

    Args:
        video: Frames as PIL images; all are assumed to match the first frame's size.
        audio: Optional waveform tensor (see ``_write_audio`` for accepted shapes).
        output_path: Destination file path; container format is inferred by PyAV.
        fps: Video frame rate.
        audio_sample_rate: Sample rate of *audio*; required when *audio* is given.

    Raises:
        ValueError: If *audio* is provided without *audio_sample_rate*.
    """
    width, height = video[0].size
    container = av.open(output_path, mode="w")
    try:
        stream = container.add_stream("libx264", rate=int(fps))
        stream.width = width
        stream.height = height
        stream.pix_fmt = "yuv420p"

        if audio is not None:
            if audio_sample_rate is None:
                raise ValueError("audio_sample_rate is required when audio is provided")
            audio_stream = _prepare_audio_stream(container, audio_sample_rate)

        for frame in tqdm(video, total=len(video)):
            frame = av.VideoFrame.from_image(frame)
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush the video encoder.
        for packet in stream.encode():
            container.mux(packet)

        if audio is not None:
            _write_audio(container, audio_stream, audio, audio_sample_rate)
    finally:
        # Finalize/close the container even if encoding failed part-way.
        container.close()
|
||||
46
examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py
Normal file
46
examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io import write_video_audio_ltx2

# Keep every stage in bfloat16: offload idle modules to CPU, compute on GPU.
vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
# One-stage LTX-2 text-to-audio-video pipeline: Gemma-3 text encoder + LTX-2 19B model.
pipe = LTX2AudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
        ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
)
# Scene description including dialogue, which drives the generated audio track.
prompt = """
INT. OVEN – DAY. Static camera from inside the oven, looking outward through the slightly fogged glass door. Warm golden light glows around freshly baked cookies. The baker’s face fills the frame, eyes wide with focus, his breath fogging the glass as he leans in. Subtle reflections move across the glass as steam rises.
Baker (whispering dramatically): “Today… I achieve perfection.”
He leans even closer, nose nearly touching the glass.
"""
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
# 121 frames at 24 fps ≈ 5 seconds of 768x512 video.
height, width, num_frames = 512, 768, 121
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    seed=43,
    height=height,
    width=width,
    num_frames=num_frames,
    tiled=False,
)
# Mux the frames and the generated 24 kHz audio track into an MP4 file.
write_video_audio_ltx2(
    video=video,
    audio=audio,
    output_path='ltx2_onestage.mp4',
    fps=24,
    audio_sample_rate=24000,
)
|
||||
Reference in New Issue
Block a user