mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 08:08:13 +00:00
support ltx2 one-stage pipeline
This commit is contained in:
@@ -4,13 +4,14 @@ from typing_extensions import Literal
|
||||
|
||||
class FlowMatchScheduler():
|
||||
|
||||
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"):
|
||||
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2"] = "FLUX.1"):
|
||||
self.set_timesteps_fn = {
|
||||
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
|
||||
"Wan": FlowMatchScheduler.set_timesteps_wan,
|
||||
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
|
||||
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
|
||||
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
||||
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
|
||||
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
||||
self.num_train_timesteps = 1000
|
||||
|
||||
@@ -121,7 +122,30 @@ class FlowMatchScheduler():
|
||||
timestep_id = torch.argmin((timesteps - timestep).abs())
|
||||
timesteps[timestep_id] = timestep
|
||||
return sigmas, timesteps
|
||||
|
||||
|
||||
@staticmethod
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, stretch=True, terminal=0.1):
    """Build the sigma / timestep schedule for LTX-2 flow matching.

    Args:
        num_inference_steps: Number of denoising steps to schedule.
        denoising_strength: Fraction of the noise range to traverse
            (1.0 = start from pure noise).
        dynamic_shift_len: Token count used to resolve the
            resolution-dependent sigma shift; defaults to 4096 when None.
        stretch: When True, rescale the schedule so the final sigma lands
            exactly on ``terminal``.
        terminal: Target value of the last sigma when ``stretch`` is on.

    Returns:
        Tuple ``(sigmas, timesteps)`` of 1-D tensors with
        ``num_inference_steps`` entries each.
    """
    dynamic_shift_len = dynamic_shift_len or 4096
    # Same shift parameterization as the Qwen-Image schedule, with LTX-2 constants.
    sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image(
        image_seq_len=dynamic_shift_len,
        base_seq_len=1024,
        max_seq_len=4096,
        base_shift=0.95,
        max_shift=2.05,
    )
    num_train_timesteps = 1000
    sigma_min = 0.0
    sigma_max = 1.0
    sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
    # Linear ramp from sigma_start down to (but excluding) sigma_min, so no
    # division by zero occurs in the shift below.
    sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
    # Apply the exponential time shift to concentrate steps where needed.
    sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1))
    if stretch:
        # Shift terminal: rescale so the final sigma equals `terminal`.
        # Fix: this step previously ran unconditionally and `stretch` was ignored.
        one_minus_z = 1.0 - sigmas
        scale_factor = one_minus_z[-1] / (1 - terminal)
        sigmas = 1.0 - (one_minus_z / scale_factor)
    timesteps = sigmas * num_train_timesteps
    return sigmas, timesteps
|
||||
|
||||
def set_training_weight(self):
|
||||
steps = 1000
|
||||
x = self.timesteps
|
||||
|
||||
@@ -337,3 +337,35 @@ class Patchifier(Protocol):
|
||||
Returns:
|
||||
Tensor containing patch coordinate metadata such as spatial or temporal intervals.
|
||||
"""
|
||||
|
||||
|
||||
def get_pixel_coords(
    latent_coords: torch.Tensor,
    scale_factors: SpatioTemporalScaleFactors,
    causal_fix: bool = False,
) -> torch.Tensor:
    """
    Convert latent-space `[start, end)` bounds into pixel-space bounds.

    Each axis (frame/time, height, width) is multiplied by the matching VAE
    downsampling factor. When `causal_fix` is set, the temporal axis is
    adjusted so a causal VAE that treats frame zero with unit stride still
    produces non-negative timestamps.

    Args:
        latent_coords: Latent bounds shaped `(batch, 3, num_patches, 2)`.
        scale_factors: `(temporal, height, width)` integer scale factors.
        causal_fix: Apply the first-frame temporal correction when True.
    """
    # Lay the three per-axis factors along dim 1 of the
    # (batch, axis, patch, bound) layout so they broadcast cleanly.
    view_shape = [1] * latent_coords.ndim
    view_shape[1] = -1
    factors = torch.tensor(scale_factors, device=latent_coords.device).view(*view_shape)

    pixel_coords = latent_coords * factors

    if causal_fix:
        # The causal VAE advances the first frame with stride 1, not
        # `scale_factors[0]`; shift and clamp to keep timestamps >= 0.
        shifted_time = pixel_coords[:, 0, ...] + 1 - scale_factors[0]
        pixel_coords[:, 0, ...] = shifted_time.clamp(min=0)

    return pixel_coords
|
||||
|
||||
@@ -514,7 +514,7 @@ class Attention(torch.nn.Module):
|
||||
out_pattern="b s n d",
|
||||
attn_mask=mask
|
||||
)
|
||||
|
||||
|
||||
# Reshape back to original format
|
||||
out = out.flatten(2, 3)
|
||||
return self.to_out(out)
|
||||
@@ -1398,7 +1398,7 @@ class LTXModel(torch.nn.Module):
|
||||
x = proj_out(x)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
def _forward(
|
||||
self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -1440,3 +1440,9 @@ class LTXModel(torch.nn.Module):
|
||||
else None
|
||||
)
|
||||
return vx, ax
|
||||
|
||||
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps):
    """Thin adapter: bundle the flat per-modality tensors into `Modality`
    records and delegate to the internal two-stream implementation."""
    video_modality = Modality(video_latents, video_timesteps, video_positions, video_context)
    audio_modality = Modality(audio_latents, audio_timesteps, audio_positions, audio_context)
    # NOTE(review): `_forward` annotates `perturbations` as
    # BatchedPerturbationConfig but receives None here — confirm it tolerates None.
    return self._forward(video=video_modality, audio=audio_modality, perturbations=None)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import itertools
|
||||
import math
|
||||
import einops
|
||||
from dataclasses import replace, dataclass
|
||||
from typing import Any, Callable, Iterator, List, NamedTuple, Tuple, Union, Optional
|
||||
import torch
|
||||
@@ -7,9 +8,138 @@ from einops import rearrange
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from enum import Enum
|
||||
from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape
|
||||
from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape, Patchifier, AudioLatentShape
|
||||
from .ltx2_dit import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
|
||||
VAE_SPATIAL_FACTOR = 32
|
||||
VAE_TEMPORAL_FACTOR = 8
|
||||
|
||||
|
||||
class VideoLatentPatchifier(Patchifier):
    """Patchifier that flattens video latents into token sequences.

    The temporal patch extent is always 1; the spatial patch is square with
    side `patch_size`.
    """

    def __init__(self, patch_size: int):
        # (temporal, height, width) patch extents; time is never patched.
        self._patch_size = (1, patch_size, patch_size)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        return self._patch_size

    def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
        # Tokens = (frames * height * width) / patch volume; the leading
        # batch/channel dims of the torch shape are excluded.
        grid_numel = math.prod(tgt_shape.to_torch_shape()[2:])
        return grid_numel // math.prod(self._patch_size)

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        # Fold each (p1, p2, p3) patch into the channel axis and flatten the
        # (f, h, w) grid into one token sequence.
        p1, p2, p3 = self._patch_size
        return einops.rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=p1,
            p2=p2,
            p3=p3,
        )

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: VideoLatentShape,
    ) -> torch.Tensor:
        assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier"

        # Grid extents, measured in patches, along each axis.
        grid_frames = output_shape.frames // self._patch_size[0]
        grid_height = output_shape.height // self._patch_size[1]
        grid_width = output_shape.width // self._patch_size[2]

        # Inverse of `patchify` (temporal extent fixed at 1, so only the two
        # spatial patch factors appear in the pattern).
        return einops.rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=grid_frames,
            h=grid_height,
            w=grid_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Return the `[start, end)` bounds of every patch `patchify` emits.

        Bounds are expressed in the original latent grid coordinates. The
        result is shaped `(batch_size, 3, num_patches, 2)` where axis 1
        enumerates (frame/time, height, width) and the last axis holds the
        inclusive start / exclusive end indices.

        Args:
            output_shape: Video grid description with frames, height, width, batch.
            device: Device on which to allocate the coordinate tensor.
        """
        if not isinstance(output_shape, VideoLatentShape):
            raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")

        frames = output_shape.frames
        height = output_shape.height
        width = output_shape.width
        batch_size = output_shape.batch

        # Guard against degenerate grids before building coordinates.
        assert frames > 0, f"frames must be positive, got {frames}"
        assert height > 0, f"height must be positive, got {height}"
        assert width > 0, f"width must be positive, got {width}"
        assert batch_size > 0, f"batch_size must be positive, got {batch_size}"

        # Patch start offsets per axis; indexing='ij' preserves the
        # (frame, height, width) ordering.
        axis_grids = torch.meshgrid(
            torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),
            torch.arange(start=0, end=height, step=self._patch_size[1], device=device),
            torch.arange(start=0, end=width, step=self._patch_size[2], device=device),
            indexing="ij",
        )
        # (3, grid_f, grid_h, grid_w) tensor of start coordinates.
        patch_starts = torch.stack(axis_grids, dim=0)

        # Per-axis patch extent, shaped (3, 1, 1, 1) so it broadcasts over
        # the grid when computing end coordinates.
        patch_extent = torch.tensor(
            self._patch_size,
            device=patch_starts.device,
            dtype=patch_starts.dtype,
        ).view(3, 1, 1, 1)
        patch_ends = patch_starts + patch_extent

        # (3, grid_f, grid_h, grid_w, 2) with [start, end) on the last axis.
        latent_coords = torch.stack((patch_starts, patch_ends), dim=-1)

        # Broadcast over the batch and flatten the grid into one patch
        # sequence: final shape (batch_size, 3, num_patches, 2).
        return einops.repeat(
            latent_coords,
            "c f h w bounds -> b c (f h w) bounds",
            b=batch_size,
            bounds=2,
        )
|
||||
|
||||
|
||||
class NormLayerType(Enum):
|
||||
GROUP_NORM = "group_norm"
|
||||
@@ -1339,6 +1469,185 @@ class LTX2VideoEncoder(nn.Module):
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
|
||||
def tiled_encode_video(
    self,
    video: torch.Tensor,
    tile_size: int = 512,
    tile_overlap: int = 128,
) -> torch.Tensor:
    """Encode video using spatial tiling for memory efficiency.

    Splits the video into overlapping spatial tiles, encodes each tile
    separately with `self.forward`, and blends the results using linear
    feathering in the overlap regions. Accumulated tiles are normalized by
    the summed blend weights at the end.

    Args:
        video: Input tensor of shape [B, C, F, H, W].
        tile_size: Tile size in pixels (must be divisible by 32).
        tile_overlap: Overlap between tiles in pixels (must be divisible by 32,
            and strictly less than `tile_size`).

    Returns:
        Encoded latent tensor [B, C_latent, F_latent, H_latent, W_latent].
    """
    batch, _channels, frames, height, width = video.shape
    device = video.device
    dtype = video.dtype

    # Validate tile parameters
    if tile_size % VAE_SPATIAL_FACTOR != 0:
        raise ValueError(f"tile_size must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_size}")
    if tile_overlap % VAE_SPATIAL_FACTOR != 0:
        raise ValueError(f"tile_overlap must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_overlap}")
    if tile_overlap >= tile_size:
        raise ValueError(f"tile_overlap ({tile_overlap}) must be less than tile_size ({tile_size})")

    # If the whole frame fits in a single tile, skip tiling entirely.
    if height <= tile_size and width <= tile_size:
        return self.forward(video)

    # Calculate output dimensions.
    # VAE compression: H -> H/32, W -> W/32, F -> 1 + (F-1)/8.
    output_height = height // VAE_SPATIAL_FACTOR
    output_width = width // VAE_SPATIAL_FACTOR
    output_frames = 1 + (frames - 1) // VAE_TEMPORAL_FACTOR

    # NOTE(review): latent channel count is hard-coded rather than derived
    # from the encoder — confirm self.forward always emits 128 channels.
    latent_channels = 128

    # Accumulators: weighted latent sum and the matching blend-weight sum.
    output = torch.zeros(
        (batch, latent_channels, output_frames, output_height, output_width),
        device=device,
        dtype=dtype,
    )
    weights = torch.zeros(
        (batch, 1, output_frames, output_height, output_width),
        device=device,
        dtype=dtype,
    )

    # Tile origins advance by tile_size - tile_overlap along each axis.
    step_h = tile_size - tile_overlap
    step_w = tile_size - tile_overlap

    h_positions = list(range(0, max(1, height - tile_overlap), step_h))
    w_positions = list(range(0, max(1, width - tile_overlap), step_w))

    # Ensure the last tile reaches the frame edge.
    if h_positions[-1] + tile_size < height:
        h_positions.append(height - tile_size)
    if w_positions[-1] + tile_size < width:
        w_positions.append(width - tile_size)

    # Remove duplicates introduced by the edge tiles and restore order.
    h_positions = sorted(set(h_positions))
    w_positions = sorted(set(w_positions))

    # Overlap width expressed in latent-space pixels.
    overlap_out_h = tile_overlap // VAE_SPATIAL_FACTOR
    overlap_out_w = tile_overlap // VAE_SPATIAL_FACTOR

    # Process each tile
    for h_pos in h_positions:
        for w_pos in w_positions:
            # Tile boundaries in input (pixel) space, clipped to the frame.
            h_start = max(0, h_pos)
            w_start = max(0, w_pos)
            h_end = min(h_start + tile_size, height)
            w_end = min(w_start + tile_size, width)

            # Round the tile extent down to a multiple of the VAE stride.
            tile_h = ((h_end - h_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR
            tile_w = ((w_end - w_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR

            # Skip slivers narrower than one VAE stride.
            if tile_h < VAE_SPATIAL_FACTOR or tile_w < VAE_SPATIAL_FACTOR:
                continue

            # Adjust end positions to the rounded extent.
            h_end = h_start + tile_h
            w_end = w_start + tile_w

            # Extract tile
            tile = video[:, :, :, h_start:h_end, w_start:w_end]

            # Encode tile
            encoded_tile = self.forward(tile)

            # Get actual encoded dimensions
            _, _, tile_out_frames, tile_out_height, tile_out_width = encoded_tile.shape

            # Destination window in latent space.
            out_h_start = h_start // VAE_SPATIAL_FACTOR
            out_w_start = w_start // VAE_SPATIAL_FACTOR
            out_h_end = min(out_h_start + tile_out_height, output_height)
            out_w_end = min(out_w_start + tile_out_width, output_width)

            # Trim the encoded tile if the window was clipped at the edge.
            actual_tile_h = out_h_end - out_h_start
            actual_tile_w = out_w_end - out_w_start
            encoded_tile = encoded_tile[:, :, :, :actual_tile_h, :actual_tile_w]

            # Blend mask, feathered linearly wherever this tile overlaps a neighbor.
            mask = torch.ones(
                (1, 1, tile_out_frames, actual_tile_h, actual_tile_w),
                device=device,
                dtype=dtype,
            )

            # Top edge (height axis): fade in over the overlap with the tile above.
            # linspace endpoints are dropped so weights stay strictly in (0, 1).
            if h_pos > 0 and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
                fade_in = torch.linspace(0.0, 1.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
                mask[:, :, :, :overlap_out_h, :] *= fade_in.view(1, 1, 1, -1, 1)

            # Bottom edge (height axis): fade out toward the tile below.
            if h_end < height and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
                fade_out = torch.linspace(1.0, 0.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
                mask[:, :, :, -overlap_out_h:, :] *= fade_out.view(1, 1, 1, -1, 1)

            # Left edge (width axis): fade in over the overlap with the tile to the left.
            if w_pos > 0 and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
                fade_in = torch.linspace(0.0, 1.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
                mask[:, :, :, :, :overlap_out_w] *= fade_in.view(1, 1, 1, 1, -1)

            # Right edge (width axis): fade out toward the tile to the right.
            if w_end < width and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
                fade_out = torch.linspace(1.0, 0.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
                mask[:, :, :, :, -overlap_out_w:] *= fade_out.view(1, 1, 1, 1, -1)

            # Accumulate weighted results
            output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += encoded_tile * mask
            weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += mask

    # Normalize by weights (epsilon avoids division by zero in never-covered cells).
    output = output / (weights + 1e-8)

    return output
|
||||
|
||||
def encode(
    self,
    video: torch.Tensor,
    tiled=False,
    tile_size_in_pixels: Optional[int] = 512,
    tile_overlap_in_pixels: Optional[int] = 128,
    **kwargs,
) -> torch.Tensor:
    """Encode pixel video into VAE latents, optionally with spatial tiling.

    Accepts [C, F, H, W] or [B, C, F, H, W] input; unbatched input is
    promoted to a batch of one. The video is moved to the module's device
    and dtype before encoding.
    """
    reference_param = next(self.parameters())
    if video.ndim == 4:
        # Promote [C, F, H, W] to [B, C, F, H, W].
        video = video.unsqueeze(0)
    video = video.to(device=reference_param.device, dtype=reference_param.dtype)
    if tiled:
        # Memory-friendly path: encode overlapping spatial tiles and blend.
        return self.tiled_encode_video(
            video=video,
            tile_size=tile_size_in_pixels,
            tile_overlap=tile_overlap_in_pixels,
        )
    # Full-frame path: VAE maps [B, C, F, H, W] -> [B, C, F', H', W'].
    return self.forward(video)
|
||||
|
||||
|
||||
def _make_decoder_block(
|
||||
block_name: str,
|
||||
block_config: dict[str, Any],
|
||||
@@ -1850,6 +2159,30 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
return weights
|
||||
|
||||
def decode(
    self,
    latent: torch.Tensor,
    tiled=False,
    tile_size_in_pixels: Optional[int] = 512,
    tile_overlap_in_pixels: Optional[int] = 128,
    tile_size_in_frames: Optional[int] = 128,
    tile_overlap_in_frames: Optional[int] = 24,
) -> torch.Tensor:
    """Decode latents to pixel video, optionally via spatio-temporal tiling."""
    if not tiled:
        return self.forward(latent)
    spatial_config = SpatialTilingConfig(
        tile_size_in_pixels=tile_size_in_pixels,
        tile_overlap_in_pixels=tile_overlap_in_pixels,
    )
    temporal_config = TemporalTilingConfig(
        tile_size_in_frames=tile_size_in_frames,
        tile_overlap_in_frames=tile_overlap_in_frames,
    )
    tiling_config = TilingConfig(
        spatial_config=spatial_config,
        temporal_config=temporal_config,
    )
    # tiled_decode yields chunks along the frame axis; stitch them back together.
    frame_chunks = self.tiled_decode(latent, tiling_config)
    return torch.cat(list(frame_chunks), dim=2)
|
||||
|
||||
def decode_video(
|
||||
latent: torch.Tensor,
|
||||
@@ -1875,10 +2208,10 @@ def decode_video(
|
||||
|
||||
if tiling_config is not None:
|
||||
for frames in video_decoder.tiled_decode(latent, tiling_config, generator=generator):
|
||||
yield convert_to_uint8(frames)
|
||||
return convert_to_uint8(frames)
|
||||
else:
|
||||
decoded_video = video_decoder(latent, generator=generator)
|
||||
yield convert_to_uint8(decoded_video)
|
||||
return convert_to_uint8(decoded_video)
|
||||
|
||||
|
||||
def get_video_chunks_number(num_frames: int, tiling_config: TilingConfig | None = None) -> int:
|
||||
|
||||
451
diffsynth/pipelines/ltx2_audio_video.py
Normal file
451
diffsynth/pipelines/ltx2_audio_video.py
Normal file
@@ -0,0 +1,451 @@
|
||||
import torch, types
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from einops import repeat
|
||||
from typing import Optional, Union
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
from typing_extensions import Literal
|
||||
from transformers import AutoImageProcessor, Gemma3Processor
|
||||
import einops
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
|
||||
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
|
||||
from ..models.ltx2_dit import LTXModel
|
||||
from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier
|
||||
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier
|
||||
from ..models.ltx2_upsampler import LTX2LatentUpsampler
|
||||
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
|
||||
|
||||
|
||||
class LTX2AudioVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
    """Set up the LTX-2 audio+video pipeline skeleton.

    Models are left as None placeholders; `from_pretrained` populates them.
    NOTE(review): `device=get_device_type()` is evaluated once at import
    time, not per call — confirm this is intended.
    """
    # Resolution must divide by 32; frame count must be 8k+1.
    super().__init__(
        device=device,
        torch_dtype=torch_dtype,
        height_division_factor=32,
        width_division_factor=32,
        time_division_factor=8,
        time_division_remainder=1,
    )
    self.scheduler = FlowMatchScheduler("LTX-2")
    # Model slots, filled by from_pretrained / model pool fetches.
    self.text_encoder: LTX2TextEncoder = None
    self.tokenizer: LTXVGemmaTokenizer = None
    self.processor: Gemma3Processor = None
    self.text_encoder_post_modules: LTX2TextEncoderPostModules = None
    self.dit: LTXModel = None
    self.video_vae_encoder: LTX2VideoEncoder = None
    self.video_vae_decoder: LTX2VideoDecoder = None
    self.audio_vae_encoder: LTX2AudioEncoder = None
    self.audio_vae_decoder: LTX2AudioDecoder = None
    self.audio_vocoder: LTX2Vocoder = None
    self.upsampler: LTX2LatentUpsampler = None

    # Patchifiers flatten latents into token sequences for the DiT.
    self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)
    self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)

    # Only the DiT stays on-device during the denoising loop.
    self.in_iteration_models = ("dit",)
    # Pre-denoising units, run in order by __call__.
    self.units = [
        LTX2AudioVideoUnit_PipelineChecker(),
        LTX2AudioVideoUnit_ShapeChecker(),
        LTX2AudioVideoUnit_PromptEmbedder(),
        LTX2AudioVideoUnit_NoiseInitializer(),
        LTX2AudioVideoUnit_InputVideoEmbedder(),
    ]
    # Post-denoising, pre-decoding units.
    self.post_units = [
        LTX2AudioVideoPostUnit_UnPatchifier(),
    ]
    self.model_fn = model_fn_ltx2
|
||||
|
||||
@staticmethod
def from_pretrained(
    torch_dtype: torch.dtype = torch.bfloat16,
    device: Union[str, torch.device] = get_device_type(),
    # NOTE(review): mutable default ([]) and a ModelConfig constructed at
    # import time — safe only if neither is mutated downstream; verify.
    model_configs: list[ModelConfig] = [],
    tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
    vram_limit: float = None,
):
    """Build an LTX2AudioVideoPipeline and load its models from configs.

    Args:
        torch_dtype: Compute dtype for all loaded models.
        device: Target device for the pipeline.
        model_configs: Model descriptors passed to the model pool loader.
        tokenizer_config: Where to fetch the Gemma tokenizer/processor from.
        vram_limit: Optional VRAM budget for offload management.

    Returns:
        The populated pipeline instance.
    """
    # Initialize pipeline
    pipe = LTX2AudioVideoPipeline(device=device, torch_dtype=torch_dtype)
    model_pool = pipe.download_and_load_models(model_configs, vram_limit)

    # Fetch models
    pipe.text_encoder = model_pool.fetch_model("ltx2_text_encoder")
    tokenizer_config.download_if_necessary()
    pipe.tokenizer = LTXVGemmaTokenizer(tokenizer_path=tokenizer_config.path)
    image_processor = AutoImageProcessor.from_pretrained(tokenizer_config.path, local_files_only=True)
    pipe.processor = Gemma3Processor(image_processor=image_processor, tokenizer=pipe.tokenizer.tokenizer)

    pipe.text_encoder_post_modules = model_pool.fetch_model("ltx2_text_encoder_post_modules")
    pipe.dit = model_pool.fetch_model("ltx2_dit")
    pipe.video_vae_encoder = model_pool.fetch_model("ltx2_video_vae_encoder")
    pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder")
    pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
    pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
    pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
    # Optional: only needed for audio-conditioned workflows.
    # pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")

    # VRAM Management
    pipe.vram_management_enabled = pipe.check_vram_management_state()
    return pipe
|
||||
|
||||
|
||||
@torch.no_grad()
def __call__(
    self,
    # Prompt
    prompt: str,
    negative_prompt: Optional[str] = "",
    # Image-to-video
    input_image: Optional[Image.Image] = None,
    denoising_strength: float = 1.0,
    # Randomness
    seed: Optional[int] = None,
    rand_device: Optional[str] = "cpu",
    # Shape
    height: Optional[int] = 512,
    width: Optional[int] = 768,
    num_frames=121,
    # Classifier-free guidance
    cfg_scale: Optional[float] = 3.0,
    cfg_merge: Optional[bool] = False,
    # Scheduler
    num_inference_steps: Optional[int] = 40,
    # VAE tiling
    tiled: Optional[bool] = True,
    tile_size_in_pixels: Optional[int] = 512,
    tile_overlap_in_pixels: Optional[int] = 128,
    tile_size_in_frames: Optional[int] = 128,
    tile_overlap_in_frames: Optional[int] = 24,
    # Two-Stage Pipeline
    use_two_stage: Optional[bool] = True,
    # progress_bar
    progress_bar_cmd=tqdm,
):
    """Generate a video and its audio track from a text prompt.

    Runs the pre-processing units, denoises video and audio latents jointly
    with classifier-free guidance, then decodes both through their VAEs
    (video optionally tiled) and the vocoder.

    Returns:
        Tuple of (video frames from `vae_output_to_video`, float audio
        waveform tensor with the batch dim squeezed).
    """
    # Scheduler
    self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)

    # Per-branch and shared inputs for the pipeline units and model_fn.
    inputs_posi = {
        "prompt": prompt,
    }
    inputs_nega = {
        "negative_prompt": negative_prompt,
    }
    inputs_shared = {
        "input_image": input_image,
        "seed": seed, "rand_device": rand_device,
        "height": height, "width": width, "num_frames": num_frames,
        "cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
        "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
        "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
        # Fix: honor the caller's flag — this was previously hard-coded to True,
        # silently ignoring the `use_two_stage` parameter.
        "use_two_stage": use_two_stage,
    }
    for unit in self.units:
        inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

    # Denoise: one CFG-guided step per scheduler timestep, updating video
    # and audio latents with the shared scheduler.
    self.load_models_to_device(self.in_iteration_models)
    models = {name: getattr(self, name) for name in self.in_iteration_models}
    for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
        timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
        noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
            self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
            **models, timestep=timestep, progress_id=progress_id
        )
        inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared)
        inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)

    # Post-denoising, pre-decoding processing (e.g. un-patchifying latents).
    for unit in self.post_units:
        inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

    # Decode video (optionally tiled for memory), then audio + vocoder.
    self.load_models_to_device(['video_vae_decoder'])
    video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels,
                                          tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
    video = self.vae_output_to_video(video)
    self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
    decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
    decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
    return video, decoded_audio
|
||||
|
||||
|
||||
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
    """Evaluate model_fn with classifier-free guidance.

    Runs the positive branch (optionally under a positive-only LoRA), and
    when `cfg_scale != 1.0` also the negative branch, combining them as
    `nega + cfg_scale * (posi - nega)` — element-wise per modality when the
    model returns a tuple.
    """
    positive_only_lora = inputs_shared.get("positive_only_lora", None)
    if positive_only_lora is not None:
        # This LoRA must influence only the positive branch.
        self.clear_lora(verbose=0)
        self.load_lora(self.dit, state_dict=positive_only_lora, verbose=0)
    noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
    if cfg_scale == 1.0:
        # Guidance disabled: positive prediction is the answer.
        return noise_pred_posi
    if positive_only_lora is not None:
        # Drop the positive-only LoRA before the negative pass.
        self.clear_lora(verbose=0)
    noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
    if isinstance(noise_pred_posi, tuple):
        # Guide each modality (e.g. video/audio) independently.
        return tuple(
            nega + cfg_scale * (posi - nega)
            for posi, nega in zip(noise_pred_posi, noise_pred_nega)
        )
    return noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
    """Placeholder validation unit; currently forwards all inputs untouched."""

    def __init__(self):
        super().__init__(take_over=True)

    def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
        # No checks implemented yet — pass everything through unchanged.
        return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
    """Snap the requested resolution and frame count to pipeline-legal values.

    # TODO: Adjust with two stage pipeline
    For two-stage pipelines, the resolution must be divisible by 64.
    For one-stage pipelines, the resolution must be divisible by 32.
    """

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames"),
            output_params=("height", "width", "num_frames"),
        )

    def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames):
        # Delegate rounding/validation to the pipeline's shared helper.
        checked = pipe.check_resize_height_width(height, width, num_frames)
        height, width, num_frames = checked
        return {"height": height, "width": width, "num_frames": num_frames}
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
|
||||
|
||||
def __init__(self):
    # Prompt embedding runs separately for the positive and negative CFG
    # branches; each branch maps its own prompt key onto "prompt".
    # ("seperate_cfg" [sic] is the project-wide kwarg spelling.)
    super().__init__(
        seperate_cfg=True,
        input_params_posi={"prompt": "prompt"},
        input_params_nega={"prompt": "negative_prompt"},
        output_params=("video_context", "audio_context"),
        # Modules temporarily loaded on-device while this unit runs.
        onload_model_names=("text_encoder", "text_encoder_post_modules"),
    )
|
||||
|
||||
def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Turn a {0, 1} attention mask into an additive bias.

    Valid tokens (1) map to 0; padded tokens (0) map to -finfo(dtype).max,
    reshaped to (batch, 1, -1, seq_len) so it broadcasts over heads.
    """
    shifted = (attention_mask - 1).to(dtype)
    target_shape = (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
    return shifted.reshape(target_shape) * torch.finfo(dtype).max
|
||||
|
||||
def _run_connectors(self, pipe, encoded_input: torch.Tensor,
                    attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Project text-encoder states through the video and audio connectors.

    Returns the video-connector embeddings (zeroed at masked positions),
    the audio-connector embeddings, and the recovered {0, 1} attention mask.
    """
    # Convert the {0, 1} mask into an additive bias for the connectors.
    connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype)

    encoded, encoded_connector_attention_mask = pipe.text_encoder_post_modules.embeddings_connector(
        encoded_input,
        connector_attention_mask,
    )

    # Restore the mask values to int64.
    # NOTE(review): with the additive convention above, valid positions are 0
    # and padded positions are -finfo.max — both satisfy `< 0.000001`, which
    # would make this mask all-ones. Confirm the connector returns positive
    # bias values for padded positions; otherwise the comparison direction
    # looks inverted.
    attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64)
    attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
    # Zero out embeddings at masked positions.
    encoded = encoded * attention_mask

    # Audio branch shares the input but uses its own connector; its mask is discarded.
    encoded_for_audio, _ = pipe.text_encoder_post_modules.audio_embeddings_connector(
        encoded_input, connector_attention_mask)

    return encoded, encoded_for_audio, attention_mask.squeeze(-1)
|
||||
|
||||
def _norm_and_concat_padded_batch(
|
||||
self,
|
||||
encoded_text: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
padding_side: str = "right",
|
||||
) -> torch.Tensor:
|
||||
"""Normalize and flatten multi-layer hidden states, respecting padding.
|
||||
Performs per-batch, per-layer normalization using masked mean and range,
|
||||
then concatenates across the layer dimension.
|
||||
Args:
|
||||
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
|
||||
sequence_lengths: Number of valid (non-padded) tokens per batch item.
|
||||
padding_side: Whether padding is on "left" or "right".
|
||||
Returns:
|
||||
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
|
||||
with padded positions zeroed out.
|
||||
"""
|
||||
b, t, d, l = encoded_text.shape # noqa: E741
|
||||
device = encoded_text.device
|
||||
# Build mask: [B, T, 1, 1]
|
||||
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [B, T]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = t - sequence_lengths[:, None] # [B, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = rearrange(mask, "b t -> b t 1 1")
|
||||
eps = 1e-6
|
||||
# Compute masked mean: [B, 1, 1, L]
|
||||
masked = encoded_text.masked_fill(~mask, 0.0)
|
||||
denom = (sequence_lengths * d).view(b, 1, 1, 1)
|
||||
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
|
||||
# Compute masked min/max: [B, 1, 1, L]
|
||||
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
range_ = x_max - x_min
|
||||
# Normalize only the valid tokens
|
||||
normed = 8 * (encoded_text - mean) / (range_ + eps)
|
||||
# concat to be [Batch, T, D * L] - this preserves the original structure
|
||||
normed = normed.reshape(b, t, -1) # [B, T, D * L]
|
||||
# Apply mask to preserve original padding (set padded positions to 0)
|
||||
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
|
||||
normed = normed.masked_fill(~mask_flattened, 0.0)
|
||||
|
||||
return normed
|
||||
|
||||
def _run_feature_extractor(self,
|
||||
pipe,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_side: str = "right") -> torch.Tensor:
|
||||
encoded_text_features = torch.stack(hidden_states, dim=-1)
|
||||
encoded_text_features_dtype = encoded_text_features.dtype
|
||||
sequence_lengths = attention_mask.sum(dim=-1)
|
||||
normed_concated_encoded_text_features = self._norm_and_concat_padded_batch(encoded_text_features,
|
||||
sequence_lengths,
|
||||
padding_side=padding_side)
|
||||
|
||||
return pipe.text_encoder_post_modules.feature_extractor_linear(
|
||||
normed_concated_encoded_text_features.to(encoded_text_features_dtype))
|
||||
|
||||
def _preprocess_text(self,
|
||||
pipe,
|
||||
text: str,
|
||||
padding_side: str = "left") -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Encode a given string into feature tensors suitable for downstream tasks.
|
||||
Args:
|
||||
text (str): Input string to encode.
|
||||
Returns:
|
||||
tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask.
|
||||
"""
|
||||
token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
|
||||
input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.text_encoder.device)
|
||||
attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.text_encoder.device)
|
||||
outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
||||
projected = self._run_feature_extractor(pipe,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
padding_side=padding_side)
|
||||
return projected, attention_mask
|
||||
|
||||
def encode_prompt(self, pipe, text, padding_side="left"):
|
||||
encoded_inputs, attention_mask = self._preprocess_text(pipe, text, padding_side)
|
||||
video_encoding, audio_encoding, attention_mask = self._run_connectors(pipe, encoded_inputs, attention_mask)
|
||||
return video_encoding, audio_encoding, attention_mask
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
video_context, audio_context, _ = self.encode_prompt(pipe, prompt)
|
||||
return {"video_context": video_context, "audio_context": audio_context}
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
    """Pipeline unit that draws the initial video/audio latent noise.

    Builds the video and audio latent shapes from the requested pixel-space
    resolution, samples noise for both modalities, patchifies the noise into
    token sequences, and computes the positional grids consumed by the DiT.
    """

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames", "seed", "rand_device",),
            # Declare every key `process` returns so downstream units
            # (denoising step, un-patchifier) can consume the positions and
            # latent shapes as inputs; previously only the noise keys were
            # declared.
            output_params=("video_noise", "audio_noise", "video_positions", "audio_positions",
                           "video_latent_shape", "audio_latent_shape",),
        )

    def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
        """Sample patchified noise and positional grids for both modalities.

        Args:
            pipe: The owning pipeline (provides VAE, patchifiers, RNG helpers).
            height, width, num_frames: Target pixel-space video geometry.
            seed: RNG seed for reproducible noise.
            rand_device: Device on which the noise is generated.
            frame_rate: Frames per second; also used to convert the temporal
                position axis from frame index to seconds.
        """
        video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
        video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
        video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
        video_noise = pipe.video_patchifier.patchify(video_noise)

        # Pixel-space coordinates for each latent patch; the temporal axis is
        # converted from frame index to seconds by dividing by the frame rate.
        latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
        video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
        video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
        video_positions = video_positions.to(pipe.torch_dtype)

        # Audio latents are sized from the same pixel-space video shape so the
        # two streams stay temporally aligned.
        audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape)
        audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
        audio_noise = pipe.audio_patchifier.patchify(audio_noise)
        audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
        return {
            "video_noise": video_noise,
            "audio_noise": audio_noise,
            "video_positions": video_positions,
            "audio_positions": audio_positions,
            "video_latent_shape": video_latent_shape,
            "audio_latent_shape": audio_latent_shape
        }
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
    """Pipeline unit that provides the initial latents for denoising.

    For text-to-video (no input video) the latents are simply the sampled
    noise. Video-to-video (encoding an input video through the VAE) is not
    implemented yet.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_video", "video_noise", "audio_noise", "tiled", "tile_size", "tile_stride"),
            output_params=("video_latents", "audio_latents"),
            # The trailing comma is required: without it this was the plain
            # string "video_vae_encoder" rather than a 1-tuple of model names,
            # so model on/offloading would iterate over characters.
            onload_model_names=("video_vae_encoder",),
        )

    def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, audio_noise, tiled, tile_size, tile_stride):
        """Return the starting latents; pure noise for text-to-video generation."""
        if input_video is None:
            return {"video_latents": video_noise, "audio_latents": audio_noise}
        else:
            # TODO: implement video-to-video
            raise NotImplementedError("Video-to-video not implemented yet.")
|
||||
|
||||
|
||||
class LTX2AudioVideoPostUnit_UnPatchifier(PipelineUnit):
    """Post-processing unit that folds token sequences back into latent grids."""

    def __init__(self):
        super().__init__(
            input_params=("video_latent_shape", "audio_latent_shape", "video_latents", "audio_latents"),
            output_params=("video_latents", "audio_latents"),
        )

    def process(self, pipe: LTX2AudioVideoPipeline, video_latent_shape, audio_latent_shape, video_latents, audio_latents):
        """Un-patchify both modalities back to their declared latent shapes."""
        unpacked_video = pipe.video_patchifier.unpatchify(video_latents, output_shape=video_latent_shape)
        unpacked_audio = pipe.audio_patchifier.unpatchify(audio_latents, output_shape=audio_latent_shape)
        return {
            "video_latents": unpacked_video,
            "audio_latents": unpacked_audio,
        }
|
||||
|
||||
|
||||
def model_fn_ltx2(
    dit: LTXModel,
    video_latents=None,
    video_context=None,
    video_positions=None,
    audio_latents=None,
    audio_context=None,
    audio_positions=None,
    timestep=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    """Run the LTX-2 audio/video DiT for one denoising step.

    The scheduler supplies `timestep` on a [0, 1000] scale; the transformer
    expects a per-token timestep in [0, 1], so the value is rescaled and then
    broadcast across the token dimension of each modality.

    Returns:
        Tuple of (video_prediction, audio_prediction) from the transformer.
    """
    # TODO: support gradient checkpointing
    normalized_t = timestep.float() / 1000.0
    per_video_token_t = normalized_t.repeat(1, video_latents.shape[1], 1)
    per_audio_token_t = normalized_t.repeat(1, audio_latents.shape[1], 1)
    video_pred, audio_pred = dit(
        video_latents=video_latents,
        video_positions=video_positions,
        video_context=video_context,
        video_timesteps=per_video_token_t,
        audio_latents=audio_latents,
        audio_positions=audio_positions,
        audio_context=audio_context,
        audio_timesteps=per_audio_token_t,
    )
    return video_pred, audio_pred
|
||||
106
diffsynth/utils/data/media_io.py
Normal file
106
diffsynth/utils/data/media_io.py
Normal file
@@ -0,0 +1,106 @@
|
||||
|
||||
from fractions import Fraction

import av
import torch
from PIL import Image
from tqdm import tqdm
|
||||
|
||||
|
||||
def _resample_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
    """Resample *frame_in* to the encoder's target format and mux the result.

    The encoder's own format/layout/rate are used as the resampling target,
    falling back to sensible defaults where the codec context leaves them
    unset (AAC usually wants planar float, stereo).
    """
    codec_ctx = audio_stream.codec_context

    resampler = av.audio.resampler.AudioResampler(
        format=codec_ctx.format or "fltp",
        layout=codec_ctx.layout or "stereo",
        rate=codec_ctx.sample_rate or frame_in.sample_rate,
    )

    next_pts = 0
    for resampled in resampler.resample(frame_in):
        # Assign monotonically increasing timestamps when the resampler
        # does not set them itself.
        if resampled.pts is None:
            resampled.pts = next_pts
        next_pts += resampled.samples
        resampled.sample_rate = frame_in.sample_rate
        container.mux(audio_stream.encode(resampled))

    # Drain any frames still buffered inside the encoder.
    for packet in audio_stream.encode():
        container.mux(packet)
|
||||
|
||||
|
||||
def _write_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
    """Encode a stereo waveform tensor and mux it into *container*.

    Accepts mono (N,), channel-first (2, N), or channel-last (N, 2) input and
    normalizes it to channel-last stereo before encoding.

    Raises:
        ValueError: If the tensor cannot be interpreted as 2-channel audio.
    """
    if samples.ndim == 1:
        samples = samples[:, None]

    if samples.shape[1] != 2 and samples.shape[0] == 2:
        samples = samples.T

    if samples.shape[1] != 2:
        raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")

    # Ingest as packed int16; the resampler converts to the encoder's format.
    if samples.dtype != torch.int16:
        clipped = torch.clip(samples, -1.0, 1.0)
        samples = (clipped * 32767.0).to(torch.int16)

    # (N, 2) row-major flattens into the interleaved layout "s16" expects.
    frame = av.AudioFrame.from_ndarray(
        samples.contiguous().reshape(1, -1).cpu().numpy(),
        format="s16",
        layout="stereo",
    )
    frame.sample_rate = audio_sample_rate

    _resample_audio(container, audio_stream, frame)
|
||||
|
||||
|
||||
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
    """Add and configure an AAC stereo audio stream at *audio_sample_rate*."""
    stream = container.add_stream("aac", rate=audio_sample_rate)
    codec_ctx = stream.codec_context
    codec_ctx.sample_rate = audio_sample_rate
    codec_ctx.layout = "stereo"
    codec_ctx.time_base = Fraction(1, audio_sample_rate)
    return stream
|
||||
|
||||
def write_video_audio_ltx2(
    video: list[Image.Image],
    audio: torch.Tensor | None,
    output_path: str,
    fps: int = 24,
    audio_sample_rate: int | None = 24000,
) -> None:
    """Mux a sequence of PIL frames (and optional audio) into an H.264 file.

    Args:
        video: Non-empty list of equally-sized PIL frames.
        audio: Optional waveform tensor (see `_write_audio` for accepted
            layouts); pass None for a silent video.
        output_path: Destination media file path.
        fps: Video frame rate.
        audio_sample_rate: Sample rate of *audio*; required when audio is given.

    Raises:
        ValueError: If *audio* is provided without *audio_sample_rate*.
    """
    width, height = video[0].size

    container = av.open(output_path, mode="w")
    # try/finally guarantees the container is closed (and the file finalized
    # as far as possible) even if encoding raises part-way through.
    try:
        stream = container.add_stream("libx264", rate=int(fps))
        stream.width = width
        stream.height = height
        stream.pix_fmt = "yuv420p"

        audio_stream = None
        if audio is not None:
            if audio_sample_rate is None:
                raise ValueError("audio_sample_rate is required when audio is provided")
            audio_stream = _prepare_audio_stream(container, audio_sample_rate)

        for frame in tqdm(video, total=len(video)):
            frame = av.VideoFrame.from_image(frame)
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush encoder
        for packet in stream.encode():
            container.mux(packet)

        if audio is not None:
            _write_audio(container, audio_stream, audio, audio_sample_rate)
    finally:
        container.close()
|
||||
Reference in New Issue
Block a user