Mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-18 22:08:13 +00:00
add video_vae and dit for ltx-2
@@ -591,4 +591,28 @@ z_image_series = [
     },
 ]
 
-MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
+ltx2_series = [
+    {
+        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
+        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
+        "model_name": "ltx2_dit",
+        "model_class": "diffsynth.models.ltx2_dit.LTXModel",
+        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
+    },
+    {
+        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
+        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
+        "model_name": "ltx2_video_vae_encoder",
+        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
+        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
+    },
+    {
+        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
+        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
+        "model_name": "ltx2_video_vae_decoder",
+        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
+        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
+    },
+]
+
+MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series
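All three entries share one checkpoint file and one model_hash; the loader tells them apart by model_name and model_class. A minimal loading sketch, assuming the ModelConfig API shown in the Example comments (the file pattern is copied verbatim from them):

from diffsynth.core.loader import ModelConfig

# One checkpoint backs the DiT, the VAE encoder, and the VAE decoder.
ltx2_checkpoint = ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")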
@@ -52,7 +52,7 @@ def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="
     if k_pattern != required_in_pattern:
         k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
     if v_pattern != required_in_pattern:
-        v = rearrange(v, f"{q_pattern} -> {required_in_pattern}", **dims)
+        v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
     return q, k, v
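The one-line change above fixes a copy-paste slip: v was rearranged with q_pattern instead of its own v_pattern, which silently produced wrong layouts whenever the two patterns differed. A minimal sketch of the corrected call (the tensor shape and patterns here are hypothetical; einops is assumed):

import torch
from einops import rearrange

# Hypothetical layout: v stored as (batch, heads, tokens, head_dim),
# while the attention kernel expects (batch, tokens, heads, head_dim).
v = torch.randn(2, 8, 16, 64)
v = rearrange(v, "b h t d -> b t h d")  # uses v's own pattern, not q's
print(v.shape)  # torch.Size([2, 16, 8, 64])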
253  diffsynth/models/ltx2_common.py  Normal file
@@ -0,0 +1,253 @@
from dataclasses import dataclass
from typing import NamedTuple

import torch
from torch import nn


class VideoPixelShape(NamedTuple):
    """
    Shape of the tensor representing the video pixel array. Assumes BGR channel format.
    """

    batch: int
    frames: int
    height: int
    width: int
    fps: float


class SpatioTemporalScaleFactors(NamedTuple):
    """
    Describes the spatiotemporal downscaling between decoded video space and
    the corresponding VAE latent grid.
    """

    time: int
    width: int
    height: int

    @classmethod
    def default(cls) -> "SpatioTemporalScaleFactors":
        return cls(time=8, width=32, height=32)


VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
class VideoLatentShape(NamedTuple):
    """
    Shape of the tensor representing video in VAE latent space.

    The latent representation is a 5D tensor with dimensions ordered as
    (batch, channels, frames, height, width). Spatial and temporal dimensions
    are downscaled relative to pixel space according to the VAE's scale factors.
    """

    batch: int
    channels: int
    frames: int
    height: int
    width: int

    def to_torch_shape(self) -> torch.Size:
        return torch.Size([self.batch, self.channels, self.frames, self.height, self.width])

    @staticmethod
    def from_torch_shape(shape: torch.Size) -> "VideoLatentShape":
        return VideoLatentShape(
            batch=shape[0],
            channels=shape[1],
            frames=shape[2],
            height=shape[3],
            width=shape[4],
        )

    def mask_shape(self) -> "VideoLatentShape":
        return self._replace(channels=1)

    @staticmethod
    def from_pixel_shape(
        shape: VideoPixelShape,
        latent_channels: int = 128,
        scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS,
    ) -> "VideoLatentShape":
        # Temporal downscaling keeps the first frame: (F - 1) // time + 1.
        # Named fields are used here; the tuple is ordered (time, width, height),
        # so positional indexing would swap the spatial factors whenever they differ.
        frames = (shape.frames - 1) // scale_factors.time + 1
        height = shape.height // scale_factors.height
        width = shape.width // scale_factors.width

        return VideoLatentShape(
            batch=shape.batch,
            channels=latent_channels,
            frames=frames,
            height=height,
            width=width,
        )

    def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape":
        return self._replace(
            channels=3,
            frames=(self.frames - 1) * scale_factors.time + 1,
            height=self.height * scale_factors.height,
            width=self.width * scale_factors.width,
        )
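A worked example of the shape arithmetic above (the pixel dimensions are hypothetical, chosen to divide evenly by the default 8/32/32 factors):

pixel = VideoPixelShape(batch=1, frames=121, height=704, width=1216, fps=24.0)
latent = VideoLatentShape.from_pixel_shape(pixel)
# frames: (121 - 1) // 8 + 1 = 16, height: 704 // 32 = 22, width: 1216 // 32 = 38
assert latent.to_torch_shape() == torch.Size([1, 128, 16, 22, 38])
# upscale() inverts the mapping, restoring a 3-channel pixel-space shape.
assert latent.upscale() == VideoLatentShape(batch=1, channels=3, frames=121, height=704, width=1216)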
class AudioLatentShape(NamedTuple):
    """
    Shape of audio in VAE latent space: (batch, channels, frames, mel_bins).
    mel_bins is the number of frequency bins from the mel-spectrogram encoding.
    """

    batch: int
    channels: int
    frames: int
    mel_bins: int

    def to_torch_shape(self) -> torch.Size:
        return torch.Size([self.batch, self.channels, self.frames, self.mel_bins])

    def mask_shape(self) -> "AudioLatentShape":
        return self._replace(channels=1, mel_bins=1)

    @staticmethod
    def from_torch_shape(shape: torch.Size) -> "AudioLatentShape":
        return AudioLatentShape(
            batch=shape[0],
            channels=shape[1],
            frames=shape[2],
            mel_bins=shape[3],
        )

    @staticmethod
    def from_duration(
        batch: int,
        duration: float,
        channels: int = 8,
        mel_bins: int = 16,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
    ) -> "AudioLatentShape":
        latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor)

        return AudioLatentShape(
            batch=batch,
            channels=channels,
            frames=round(duration * latents_per_second),
            mel_bins=mel_bins,
        )

    @staticmethod
    def from_video_pixel_shape(
        shape: VideoPixelShape,
        channels: int = 8,
        mel_bins: int = 16,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
    ) -> "AudioLatentShape":
        return AudioLatentShape.from_duration(
            batch=shape.batch,
            duration=float(shape.frames) / float(shape.fps),
            channels=channels,
            mel_bins=mel_bins,
            sample_rate=sample_rate,
            hop_length=hop_length,
            audio_latent_downsample_factor=audio_latent_downsample_factor,
        )
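With the defaults, the audio latent rate works out to 16000 / 160 / 4 = 25 latent frames per second. Continuing the hypothetical pixel shape from the previous example:

audio = AudioLatentShape.from_video_pixel_shape(pixel)
# duration = 121 / 24.0 ≈ 5.042 s, so frames = round(5.042 * 25) = 126
assert audio == AudioLatentShape(batch=1, channels=8, frames=126, mel_bins=16)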
@dataclass(frozen=True)
class LatentState:
    """
    State of latents during the diffusion denoising process.

    Attributes:
        latent: The current noisy latent tensor being denoised.
        denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising).
        positions: Positional indices for each latent element, used for positional embeddings.
        clean_latent: Initial state of the latent before denoising; may include conditioning latents.
    """

    latent: torch.Tensor
    denoise_mask: torch.Tensor
    positions: torch.Tensor
    clean_latent: torch.Tensor

    def clone(self) -> "LatentState":
        return LatentState(
            latent=self.latent.clone(),
            denoise_mask=self.denoise_mask.clone(),
            positions=self.positions.clone(),
            clean_latent=self.clean_latent.clone(),
        )
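The denoise_mask convention above (1 = denoise, 0 = keep) is what lets conditioning tokens pass through the sampler unchanged. A toy construction, with all shapes hypothetical:

tokens, dim = 16, 128
state = LatentState(
    latent=torch.randn(1, tokens, dim),
    # First 4 tokens are clean conditioning and are never denoised.
    denoise_mask=torch.cat([torch.zeros(1, 4), torch.ones(1, tokens - 4)], dim=1),
    positions=torch.arange(tokens).expand(1, 3, tokens),
    clean_latent=torch.randn(1, tokens, dim),
)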
class PixelNorm(nn.Module):
    """
    Per-pixel (per-location) RMS normalization layer.

    For each element along the chosen dimension, this layer normalizes the tensor
    by the root-mean-square of its values across that dimension:

        y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
    """

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        """
        Args:
            dim: Dimension along which to compute the RMS (typically channels).
            eps: Small constant added for numerical stability.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply RMS normalization along the configured dimension.
        """
        # Compute the mean of squared values along `dim`, keeping the dimension for broadcasting.
        mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
        # Normalize by the root-mean-square (RMS).
        rms = torch.sqrt(mean_sq + self.eps)
        return x / rms


def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
    """Root-mean-square (RMS) normalize `x` over its last dimension.

    Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
    shape and forwards `weight` and `eps`.
    """
    return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)
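A quick numerical check of the two normalizers (the shapes are hypothetical):

x = torch.randn(2, 128, 16, 22, 38)  # (B, C, F, H, W) video latent
y = PixelNorm(dim=1)(x)              # RMS taken over the channel axis
assert torch.allclose((y**2).mean(dim=1), torch.ones(2, 16, 22, 38), atol=1e-3)

tokens = torch.randn(2, 16, 64)      # (B, T, D) token sequence
t = rms_norm(tokens)                 # RMS taken over the last axis
assert torch.allclose((t**2).mean(dim=-1), torch.ones(2, 16), atol=1e-3)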
@dataclass(frozen=True)
class Modality:
    """
    Input data for a single modality (video or audio) in the transformer.

    Bundles the latent tokens, timestep embeddings, positional information,
    and text conditioning context for processing by the diffusion transformer.
    """

    # Shape: (B, T, D), where B is the batch size, T is the number of tokens, and D is the input dimension.
    latent: torch.Tensor
    # Shape: (B, T), where T is the number of timesteps.
    timesteps: torch.Tensor
    # Shape: (B, 3, T) for video, where 3 is the number of position dimensions and T is the number of tokens.
    positions: torch.Tensor
    context: torch.Tensor
    enabled: bool = True
    context_mask: torch.Tensor | None = None
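A matching construction sketch for a video modality (shapes follow the comments above; the text-context length is hypothetical):

video = Modality(
    latent=torch.randn(1, 16, 128),                      # (B, T, D)
    timesteps=torch.full((1, 16), 0.7),                  # (B, T), per-token noise level
    positions=torch.zeros(1, 3, 16, dtype=torch.long),   # (B, 3, T)
    context=torch.randn(1, 77, 128),                     # text conditioning tokens
)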
def to_denoised(
    sample: torch.Tensor,
    velocity: torch.Tensor,
    sigma: float | torch.Tensor,
    calc_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Convert a noisy sample and its predicted denoising velocity into the
    denoised sample: x0 = sample - velocity * sigma.

    Returns:
        Denoised sample, cast back to the input sample's dtype.
    """
    if isinstance(sigma, torch.Tensor):
        sigma = sigma.to(calc_dtype)
    return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)
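The formula is consistent with a rectified-flow parameterization where x_t = (1 - sigma) * x0 + sigma * noise and the network predicts the velocity v = noise - x0, so that x_t - v * sigma = x0. A small self-check with an ideal velocity (hypothetical latent shape):

x0 = torch.randn(1, 128, 16, 22, 38)
noise = torch.randn_like(x0)
sigma = 0.7
xt = (1 - sigma) * x0 + sigma * noise  # noisy sample at noise level sigma
v = noise - x0                         # ideal velocity prediction
assert torch.allclose(to_denoised(xt, v, sigma), x0, atol=1e-5)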
1442  diffsynth/models/ltx2_dit.py  Normal file
File diff suppressed because it is too large.

1969  diffsynth/models/ltx2_video_vae.py  Normal file
File diff suppressed because it is too large.
9  diffsynth/utils/state_dict_converters/ltx2_dit.py  Normal file
@@ -0,0 +1,9 @@
def LTXModelStateDictConverter(state_dict):
    state_dict_ = {}
    for name in state_dict:
        if name.startswith("model.diffusion_model."):
            new_name = name.replace("model.diffusion_model.", "")
            # The audio/video embedding connectors are not part of this model; skip them.
            if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."):
                continue
            state_dict_[new_name] = state_dict[name]
    return state_dict_
22  diffsynth/utils/state_dict_converters/ltx2_video_vae.py  Normal file
@@ -0,0 +1,22 @@
def LTX2VideoEncoderStateDictConverter(state_dict):
    state_dict_ = {}
    for name in state_dict:
        if name.startswith("vae.encoder."):
            new_name = name.replace("vae.encoder.", "")
            state_dict_[new_name] = state_dict[name]
        elif name.startswith("vae.per_channel_statistics."):
            new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
            state_dict_[new_name] = state_dict[name]
    return state_dict_


def LTX2VideoDecoderStateDictConverter(state_dict):
    state_dict_ = {}
    for name in state_dict:
        if name.startswith("vae.decoder."):
            new_name = name.replace("vae.decoder.", "")
            state_dict_[new_name] = state_dict[name]
        elif name.startswith("vae.per_channel_statistics."):
            new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
            state_dict_[new_name] = state_dict[name]
    return state_dict_
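Both converters strip the submodule prefix and keep the shared per-channel statistics, re-rooted under the module itself. A toy sketch of the key remapping (the parameter names here are hypothetical):

import torch

raw = {
    "vae.encoder.conv_in.weight": torch.zeros(1),
    "vae.per_channel_statistics.mean": torch.zeros(1),
    "vae.decoder.conv_out.weight": torch.zeros(1),  # dropped by the encoder converter
}
converted = LTX2VideoEncoderStateDictConverter(raw)
assert set(converted) == {"conv_in.weight", "per_channel_statistics.mean"}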
22  diffsynth/utils/test/load_model.py  Normal file
@@ -0,0 +1,22 @@
import torch

from diffsynth.core.loader import ModelConfig
from diffsynth.models.model_loader import ModelPool


def test_model_loading(
    model_name,
    model_config: ModelConfig,
    vram_limit: float | None = None,
    device="cpu",
    torch_dtype=torch.bfloat16,
):
    model_pool = ModelPool()
    model_config.download_if_necessary()
    vram_config = model_config.vram_config()
    vram_config["computation_dtype"] = torch_dtype
    vram_config["computation_device"] = device
    model_pool.auto_load_model(
        model_config.path,
        vram_config=vram_config,
        vram_limit=vram_limit,
        clear_parameters=model_config.clear_parameters,
    )
    return model_pool.fetch_model(model_name)
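Usage sketch, reusing the Example ModelConfig from the config hunk at the top; that the single checkpoint file resolves all three LTX-2 entries is an assumption here:

config = ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
# download_if_necessary() is called inside test_model_loading, so no extra setup is needed.
dit = test_model_loading("ltx2_dit", config, device="cuda")
decoder = test_model_loading("ltx2_video_vae_decoder", config, device="cuda")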