add video_vae and dit for ltx-2

This commit is contained in:
mi804
2026-01-27 19:34:09 +08:00
parent ffb7a138f7
commit 00da4b6c4f
8 changed files with 3743 additions and 2 deletions

View File

@@ -0,0 +1,253 @@
from dataclasses import dataclass
from typing import NamedTuple
import torch
from torch import nn
class VideoPixelShape(NamedTuple):
"""
Shape of the tensor representing the video pixel array. Assumes BGR channel format.
"""
batch: int
frames: int
height: int
width: int
fps: float
class SpatioTemporalScaleFactors(NamedTuple):
"""
Describes the spatiotemporal downscaling between decoded video space and
the corresponding VAE latent grid.
"""
time: int
width: int
height: int
@classmethod
def default(cls) -> "SpatioTemporalScaleFactors":
return cls(time=8, width=32, height=32)
VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
class VideoLatentShape(NamedTuple):
"""
Shape of the tensor representing video in VAE latent space.
The latent representation is a 5D tensor with dimensions ordered as
(batch, channels, frames, height, width). Spatial and temporal dimensions
are downscaled relative to pixel space according to the VAE's scale factors.
"""
batch: int
channels: int
frames: int
height: int
width: int
def to_torch_shape(self) -> torch.Size:
return torch.Size([self.batch, self.channels, self.frames, self.height, self.width])
@staticmethod
def from_torch_shape(shape: torch.Size) -> "VideoLatentShape":
return VideoLatentShape(
batch=shape[0],
channels=shape[1],
frames=shape[2],
height=shape[3],
width=shape[4],
)
def mask_shape(self) -> "VideoLatentShape":
return self._replace(channels=1)
@staticmethod
def from_pixel_shape(
shape: VideoPixelShape,
latent_channels: int = 128,
scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS,
) -> "VideoLatentShape":
frames = (shape.frames - 1) // scale_factors[0] + 1
height = shape.height // scale_factors[1]
width = shape.width // scale_factors[2]
return VideoLatentShape(
batch=shape.batch,
channels=latent_channels,
frames=frames,
height=height,
width=width,
)
def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape":
return self._replace(
channels=3,
frames=(self.frames - 1) * scale_factors.time + 1,
height=self.height * scale_factors.height,
width=self.width * scale_factors.width,
)
class AudioLatentShape(NamedTuple):
"""
Shape of audio in VAE latent space: (batch, channels, frames, mel_bins).
mel_bins is the number of frequency bins from the mel-spectrogram encoding.
"""
batch: int
channels: int
frames: int
mel_bins: int
def to_torch_shape(self) -> torch.Size:
return torch.Size([self.batch, self.channels, self.frames, self.mel_bins])
def mask_shape(self) -> "AudioLatentShape":
return self._replace(channels=1, mel_bins=1)
@staticmethod
def from_torch_shape(shape: torch.Size) -> "AudioLatentShape":
return AudioLatentShape(
batch=shape[0],
channels=shape[1],
frames=shape[2],
mel_bins=shape[3],
)
@staticmethod
def from_duration(
batch: int,
duration: float,
channels: int = 8,
mel_bins: int = 16,
sample_rate: int = 16000,
hop_length: int = 160,
audio_latent_downsample_factor: int = 4,
) -> "AudioLatentShape":
latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor)
return AudioLatentShape(
batch=batch,
channels=channels,
frames=round(duration * latents_per_second),
mel_bins=mel_bins,
)
@staticmethod
def from_video_pixel_shape(
shape: VideoPixelShape,
channels: int = 8,
mel_bins: int = 16,
sample_rate: int = 16000,
hop_length: int = 160,
audio_latent_downsample_factor: int = 4,
) -> "AudioLatentShape":
return AudioLatentShape.from_duration(
batch=shape.batch,
duration=float(shape.frames) / float(shape.fps),
channels=channels,
mel_bins=mel_bins,
sample_rate=sample_rate,
hop_length=hop_length,
audio_latent_downsample_factor=audio_latent_downsample_factor,
)
@dataclass(frozen=True)
class LatentState:
"""
State of latents during the diffusion denoising process.
Attributes:
latent: The current noisy latent tensor being denoised.
denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising).
positions: Positional indices for each latent element, used for positional embeddings.
clean_latent: Initial state of the latent before denoising, may include conditioning latents.
"""
latent: torch.Tensor
denoise_mask: torch.Tensor
positions: torch.Tensor
clean_latent: torch.Tensor
def clone(self) -> "LatentState":
return LatentState(
latent=self.latent.clone(),
denoise_mask=self.denoise_mask.clone(),
positions=self.positions.clone(),
clean_latent=self.clean_latent.clone(),
)
class PixelNorm(nn.Module):
"""
Per-pixel (per-location) RMS normalization layer.
For each element along the chosen dimension, this layer normalizes the tensor
by the root-mean-square of its values across that dimension:
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
"""
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
"""
Args:
dim: Dimension along which to compute the RMS (typically channels).
eps: Small constant added for numerical stability.
"""
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply RMS normalization along the configured dimension.
"""
# Compute mean of squared values along `dim`, keep dimensions for broadcasting.
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
# Normalize by the root-mean-square (RMS).
rms = torch.sqrt(mean_sq + self.eps)
return x / rms
def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
"""Root-mean-square (RMS) normalize `x` over its last dimension.
Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
shape and forwards `weight` and `eps`.
"""
return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)
@dataclass(frozen=True)
class Modality:
"""
Input data for a single modality (video or audio) in the transformer.
Bundles the latent tokens, timestep embeddings, positional information,
and text conditioning context for processing by the diffusion transformer.
"""
latent: (
torch.Tensor
) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps
positions: (
torch.Tensor
) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens
context: torch.Tensor
enabled: bool = True
context_mask: torch.Tensor | None = None
def to_denoised(
sample: torch.Tensor,
velocity: torch.Tensor,
sigma: float | torch.Tensor,
calc_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Convert the sample and its denoising velocity to denoised sample.
Returns:
Denoised sample
"""
if isinstance(sigma, torch.Tensor):
sigma = sigma.to(calc_dtype)
return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)

1442
diffsynth/models/ltx2_dit.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff