Compare commits

..

1 Commits

Author SHA1 Message Date
Artiprocher
0d466231d1 compatibility patch 2026-03-23 11:24:19 +08:00
33 changed files with 148 additions and 3318 deletions

View File

@@ -32,7 +32,7 @@ We believe that a well-developed open-source code framework can lower the thresh
> DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
- **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
- **January 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).

View File

@@ -33,7 +33,7 @@ DiffSynth 目前包括两个开源项目:
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责因此新功能的开发进展会比较缓慢issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
- **2026年3月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。
- **2026年1月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。
- **2026年3月12日** 我们新增了 [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) 音视频生成模型的支持模型支持的功能包括文生音视频、图生音视频、IC-LoRA控制、音频生视频、音视频局部Inpainting框架支持完整的推理和训练功能。详细信息请参考 [文档](/docs/zh/Model_Details/LTX-2.md) 和 [示例代码](/examples/ltx2/)。

View File

@@ -884,40 +884,4 @@ mova_series = [
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
},
]
ace_step_series = [
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors")
"model_hash": "3509bea17b0e8cffc3dd4a15cc7899d0",
"model_name": "ace_step_text_encoder",
"model_class": "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.AceStepTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "51420834e54474986a7f4be0e4d6f687",
"model_name": "ace_step_vae",
"model_class": "diffsynth.models.ace_step_vae.AceStepVAE",
"extra_kwargs": {
"encoder_hidden_size": 128,
"downsampling_ratios": [2, 4, 4, 6, 10],
"channel_multiples": [1, 2, 4, 8, 16],
"decoder_channels": 128,
"decoder_input_channels": 64,
"audio_channels": 2,
"sampling_rate": 48000
}
},
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
"model_name": "ace_step_dit",
"model_class": "diffsynth.models.ace_step_dit.AceStepConditionGenerationModelWrapper",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTStateDictConverter",
"extra_kwargs": {
"config_path": "models/ACE-Step/Ace-Step1.5/acestep-v15-turbo"
}
},
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series + ace_step_series
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series

View File

@@ -339,38 +339,6 @@ class BasePipeline(torch.nn.Module):
noise_pred = noise_pred_posi
return noise_pred
def compile_pipeline(self, mode: str = "default", dynamic: bool = True, fullgraph: bool = False, compile_models: list = None, **kwargs):
"""
compile the pipeline with torch.compile. The models that will be compiled are determined by the `compilable_models` attribute of the pipeline.
If a model has `_repeated_blocks` attribute, we will compile these blocks with regional compilation. Otherwise, we will compile the whole model.
See https://docs.pytorch.org/docs/stable/generated/torch.compile.html#torch.compile for details about compilation arguments.
Args:
mode: The compilation mode, which will be passed to `torch.compile`, options are "default", "reduce-overhead", "max-autotune" and "max-autotune-no-cudagraphs. Default to "default".
dynamic: Whether to enable dynamic graph compilation to support dynamic input shapes, which will be passed to `torch.compile`. Default to True (recommended).
fullgraph: Whether to use full graph compilation, which will be passed to `torch.compile`. Default to False (recommended).
compile_models: The list of model names to be compiled. If None, we will compile the models in `pipeline.compilable_models`. Default to None.
**kwargs: Other arguments for `torch.compile`.
"""
compile_models = compile_models or getattr(self, "compilable_models", [])
if len(compile_models) == 0:
print("No compilable models in the pipeline. Skip compilation.")
return
for name in compile_models:
model = getattr(self, name, None)
if model is None:
print(f"Model '{name}' not found in the pipeline.")
continue
repeated_blocks = getattr(model, "_repeated_blocks", None)
# regional compilation for repeated blocks.
if repeated_blocks is not None:
for submod in model.modules():
if submod.__class__.__name__ in repeated_blocks:
submod.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
# compile the whole model.
else:
model.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
print(f"{name} is compiled with mode={mode}, dynamic={dynamic}, fullgraph={fullgraph}.")
class PipelineUnitGraph:
def __init__(self):

File diff suppressed because it is too large Load Diff

View File

@@ -1,38 +0,0 @@
from transformers import Qwen3Model, Qwen3Config
import torch
class AceStepTextEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
config = Qwen3Config(**{
"architectures": ["Qwen3Model"],
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151643,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 32768,
"max_window_layers": 28,
"model_type": "qwen3",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 1000000,
"sliding_window": None,
"tie_word_embeddings": True,
"torch_dtype": "bfloat16",
"use_cache": True,
"use_sliding_window": False,
"vocab_size": 151669
})
self.model = Qwen3Model(config)
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)

View File

@@ -1,416 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
class Snake1d(nn.Module):
"""
A 1-dimensional Snake activation function module.
"""
def __init__(self, hidden_dim, logscale=True):
super().__init__()
self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))
self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))
self.alpha.requires_grad = True
self.beta.requires_grad = True
self.logscale = logscale
def forward(self, hidden_states):
shape = hidden_states.shape
alpha = self.alpha if not self.logscale else torch.exp(self.alpha)
beta = self.beta if not self.logscale else torch.exp(self.beta)
hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2)
hidden_states = hidden_states.reshape(shape)
return hidden_states
class OobleckResidualUnit(nn.Module):
"""
A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
"""
def __init__(self, dimension: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.snake1 = Snake1d(dimension)
self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad))
self.snake2 = Snake1d(dimension)
self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1))
def forward(self, hidden_state):
output_tensor = hidden_state
output_tensor = self.conv1(self.snake1(output_tensor))
output_tensor = self.conv2(self.snake2(output_tensor))
padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
if padding > 0:
hidden_state = hidden_state[..., padding:-padding]
output_tensor = hidden_state + output_tensor
return output_tensor
class OobleckEncoderBlock(nn.Module):
"""Encoder block used in Oobleck encoder."""
def __init__(self, input_dim, output_dim, stride: int = 1):
super().__init__()
self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1)
self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3)
self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9)
self.snake1 = Snake1d(input_dim)
self.conv1 = weight_norm(
nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
)
def forward(self, hidden_state):
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.snake1(self.res_unit3(hidden_state))
hidden_state = self.conv1(hidden_state)
return hidden_state
class OobleckDecoderBlock(nn.Module):
"""Decoder block used in Oobleck decoder."""
def __init__(self, input_dim, output_dim, stride: int = 1):
super().__init__()
self.snake1 = Snake1d(input_dim)
self.conv_t1 = weight_norm(
nn.ConvTranspose1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
)
)
self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1)
self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3)
self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9)
def forward(self, hidden_state):
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv_t1(hidden_state)
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.res_unit3(hidden_state)
return hidden_state
class OobleckDiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.scale = parameters.chunk(2, dim=1)
self.std = nn.functional.softplus(self.scale) + 1e-4
self.var = self.std * self.std
self.logvar = torch.log(self.var)
self.deterministic = deterministic
def sample(self, generator: torch.Generator = None) -> torch.Tensor:
device = self.parameters.device
dtype = self.parameters.dtype
sample = torch.randn(self.mean.shape, generator=generator, device=device, dtype=dtype)
x = self.mean + self.std * sample
return x
def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
else:
normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
var_ratio = self.var / other.var
logvar_diff = self.logvar - other.logvar
kl = normalized_diff + var_ratio + logvar_diff - 1
kl = kl.sum(1).mean()
return kl
def mode(self) -> torch.Tensor:
return self.mean
@dataclass
class AutoencoderOobleckOutput:
"""
Output of AutoencoderOobleck encoding method.
Args:
latent_dist (`OobleckDiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and standard deviation of
`OobleckDiagonalGaussianDistribution`. `OobleckDiagonalGaussianDistribution` allows for sampling latents
from the distribution.
"""
latent_dist: "OobleckDiagonalGaussianDistribution"
@dataclass
class OobleckDecoderOutput:
r"""
Output of decoding method.
Args:
sample (`torch.Tensor` of shape `(batch_size, audio_channels, sequence_length)`):
The decoded output sample from the last layer of the model.
"""
sample: torch.Tensor
class OobleckEncoder(nn.Module):
"""Oobleck Encoder"""
def __init__(self, encoder_hidden_size, audio_channels, downsampling_ratios, channel_multiples):
super().__init__()
strides = downsampling_ratios
channel_multiples = [1] + channel_multiples
# Create first convolution
self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3))
self.block = []
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride_index, stride in enumerate(strides):
self.block += [
OobleckEncoderBlock(
input_dim=encoder_hidden_size * channel_multiples[stride_index],
output_dim=encoder_hidden_size * channel_multiples[stride_index + 1],
stride=stride,
)
]
self.block = nn.ModuleList(self.block)
d_model = encoder_hidden_size * channel_multiples[-1]
self.snake1 = Snake1d(d_model)
self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1))
def forward(self, hidden_state):
hidden_state = self.conv1(hidden_state)
for module in self.block:
hidden_state = module(hidden_state)
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv2(hidden_state)
return hidden_state
class OobleckDecoder(nn.Module):
"""Oobleck Decoder"""
def __init__(self, channels, input_channels, audio_channels, upsampling_ratios, channel_multiples):
super().__init__()
strides = upsampling_ratios
channel_multiples = [1] + channel_multiples
# Add first conv layer
self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3))
# Add upsampling + MRF blocks
block = []
for stride_index, stride in enumerate(strides):
block += [
OobleckDecoderBlock(
input_dim=channels * channel_multiples[len(strides) - stride_index],
output_dim=channels * channel_multiples[len(strides) - stride_index - 1],
stride=stride,
)
]
self.block = nn.ModuleList(block)
output_dim = channels
self.snake1 = Snake1d(output_dim)
self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False))
def forward(self, hidden_state):
hidden_state = self.conv1(hidden_state)
for layer in self.block:
hidden_state = layer(hidden_state)
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv2(hidden_state)
return hidden_state
class AceStepVAE(nn.Module):
r"""
An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First
introduced in Stable Audio.
Parameters:
encoder_hidden_size (`int`, *optional*, defaults to 128):
Intermediate representation dimension for the encoder.
downsampling_ratios (`list[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`):
Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder.
channel_multiples (`list[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`):
Multiples used to determine the hidden sizes of the hidden layers.
decoder_channels (`int`, *optional*, defaults to 128):
Intermediate representation dimension for the decoder.
decoder_input_channels (`int`, *optional*, defaults to 64):
Input dimension for the decoder. Corresponds to the latent dimension.
audio_channels (`int`, *optional*, defaults to 2):
Number of channels in the audio data. Either 1 for mono or 2 for stereo.
sampling_rate (`int`, *optional*, defaults to 44100):
The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
"""
def __init__(
self,
encoder_hidden_size=128,
downsampling_ratios=[2, 4, 4, 8, 8],
channel_multiples=[1, 2, 4, 8, 16],
decoder_channels=128,
decoder_input_channels=64,
audio_channels=2,
sampling_rate=44100,
):
super().__init__()
self.encoder_hidden_size = encoder_hidden_size
self.downsampling_ratios = downsampling_ratios
self.decoder_channels = decoder_channels
self.upsampling_ratios = downsampling_ratios[::-1]
self.hop_length = int(np.prod(downsampling_ratios))
self.sampling_rate = sampling_rate
self.encoder = OobleckEncoder(
encoder_hidden_size=encoder_hidden_size,
audio_channels=audio_channels,
downsampling_ratios=downsampling_ratios,
channel_multiples=channel_multiples,
)
self.decoder = OobleckDecoder(
channels=decoder_channels,
input_channels=decoder_input_channels,
audio_channels=audio_channels,
upsampling_ratios=self.upsampling_ratios,
channel_multiples=channel_multiples,
)
self.use_slicing = False
def encode(self, x: torch.Tensor, return_dict: bool = True):
"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self.encoder(x)
posterior = OobleckDiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderOobleckOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True):
dec = self.decoder(z)
if not return_dict:
return (dec,)
return OobleckDecoderOutput(sample=dec)
def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None):
"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.OobleckDecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.OobleckDecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.OobleckDecoderOutput`] is returned, otherwise a plain `tuple`
is returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return OobleckDecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: torch.Generator = None,
):
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`OobleckDecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return OobleckDecoderOutput(sample=dec)

View File

@@ -1270,9 +1270,6 @@ class LLMAdapter(nn.Module):
class AnimaDiT(MiniTrainDIT):
_repeated_blocks = ["Block"]
def __init__(self):
kwargs = {'image_model': 'anima', 'max_img_h': 240, 'max_img_w': 240, 'max_frames': 128, 'in_channels': 16, 'out_channels': 16, 'patch_spatial': 2, 'patch_temporal': 1, 'model_channels': 2048, 'concat_padding_mask': True, 'crossattn_emb_channels': 1024, 'pos_emb_cls': 'rope3d', 'pos_emb_learnable': True, 'pos_emb_interpolation': 'crop', 'min_fps': 1, 'max_fps': 30, 'use_adaln_lora': True, 'adaln_lora_dim': 256, 'num_blocks': 28, 'num_heads': 16, 'extra_per_block_abs_pos_emb': False, 'rope_h_extrapolation_ratio': 4.0, 'rope_w_extrapolation_ratio': 4.0, 'rope_t_extrapolation_ratio': 1.0, 'extra_h_extrapolation_ratio': 1.0, 'extra_w_extrapolation_ratio': 1.0, 'extra_t_extrapolation_ratio': 1.0, 'rope_enable_fps_modulation': False, 'dtype': torch.bfloat16, 'device': None, 'operations': torch.nn}
super().__init__(**kwargs)

View File

@@ -879,9 +879,6 @@ class Flux2Modulation(nn.Module):
class Flux2DiT(torch.nn.Module):
_repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
def __init__(
self,
patch_size: int = 1,

View File

@@ -275,9 +275,6 @@ class AdaLayerNormContinuous(torch.nn.Module):
class FluxDiT(torch.nn.Module):
_repeated_blocks = ["FluxJointTransformerBlock", "FluxSingleTransformerBlock"]
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])

View File

@@ -1280,7 +1280,6 @@ class LTXModel(torch.nn.Module):
LTX model transformer implementation.
This class implements the transformer blocks for the LTX model.
"""
_repeated_blocks = ["BasicAVTransformerBlock"]
def __init__( # noqa: PLR0913
self,

View File

@@ -549,9 +549,6 @@ class QwenImageTransformerBlock(nn.Module):
class QwenImageDiT(torch.nn.Module):
_repeated_blocks = ["QwenImageTransformerBlock"]
def __init__(
self,
num_layers: int = 60,

View File

@@ -336,9 +336,6 @@ class WanToDanceInjector(nn.Module):
class WanModel(torch.nn.Module):
_repeated_blocks = ["DiTBlock"]
def __init__(
self,
dim: int,

View File

@@ -326,7 +326,6 @@ class RopeEmbedder:
class ZImageDiT(nn.Module):
_supports_gradient_checkpointing = True
_no_split_modules = ["ZImageTransformerBlock"]
_repeated_blocks = ["ZImageTransformerBlock"]
def __init__(
self,

View File

@@ -1,217 +0,0 @@
import torch, math
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange
import numpy as np
from math import prod
from transformers import AutoTokenizer
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from ..utils.lora.merge import merge_lora
from ..core.device.npu_compatible_device import get_device_type
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline
from ..models.ace_step_text_encoder import AceStepTextEncoder
from ..models.ace_step_vae import AceStepVAE
from ..models.ace_step_dit import AceStepConditionGenerationModelWrapper
class AceStepAudioPipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.text_encoder: AceStepTextEncoder = None
self.dit: AceStepConditionGenerationModelWrapper = None
self.vae: AceStepVAE = None
self.scheduler = FlowMatchScheduler()
self.tokenizer: AutoTokenizer = None
self.in_iteration_models = ("dit",)
self.units = []
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B"),
vram_limit: float = None,
):
# Initialize pipeline
pipe = AceStepAudioPipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("ace_step_text_encoder")
pipe.dit = model_pool.fetch_model("ace_step_dit")
pipe.vae = model_pool.fetch_model("ace_step_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
caption: str,
lyrics: str = "",
duration: float = 160,
bpm: int = None,
keyscale: str = "",
timesignature: str = "",
vocal_language: str = "zh",
instrumental: bool = False,
inference_steps: int = 8,
guidance_scale: float = 3.0,
seed: int = None,
):
# Format text prompt with metadata
text_prompt = self._format_text_prompt(caption, bpm, keyscale, timesignature, duration)
lyrics_text = self._format_lyrics(lyrics, vocal_language, instrumental)
# Tokenize
text_inputs = self.tokenizer(
text_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512,
).to(self.device)
lyrics_inputs = self.tokenizer(
lyrics_text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048,
).to(self.device)
# Encode text and lyrics
text_outputs = self.text_encoder(
input_ids=text_inputs["input_ids"],
attention_mask=text_inputs["attention_mask"],
)
lyrics_outputs = self.text_encoder(
input_ids=lyrics_inputs["input_ids"],
attention_mask=lyrics_inputs["attention_mask"],
)
# Get hidden states
text_hidden_states = text_outputs.last_hidden_state
lyric_hidden_states = lyrics_outputs.last_hidden_state
# Prepare generation parameters
latent_frames = int(duration * 46.875) # 48000 / 1024 ≈ 46.875 Hz
# For text2music task, use silence_latent as src_latents
# silence_latent will be tokenized/detokenized to get lm_hints_25Hz (127 dims)
# which will be used as context for generation
if self.silence_latent is not None:
# Slice or pad silence_latent to match latent_frames
if self.silence_latent.shape[1] >= latent_frames:
src_latents = self.silence_latent[:, :latent_frames, :].to(device=self.device, dtype=self.torch_dtype)
else:
# Pad with zeros if silence_latent is shorter
pad_len = latent_frames - self.silence_latent.shape[1]
src_latents = torch.cat([
self.silence_latent.to(device=self.device, dtype=self.torch_dtype),
torch.zeros(1, pad_len, self.src_latent_channels, device=self.device, dtype=self.torch_dtype)
], dim=1)
else:
# Fallback: create random latents if silence_latent is not loaded
src_latents = torch.randn(1, latent_frames, self.src_latent_channels,
device=self.device, dtype=self.torch_dtype)
# Create attention mask
attention_mask = torch.ones(1, latent_frames, device=self.device, dtype=self.torch_dtype)
# Use silence_latent for the silence_latent parameter as well
silence_latent = src_latents
# Chunk masks and is_covers (for text2music, these are all zeros)
# chunk_masks shape: [batch, latent_frames, 1]
chunk_masks = torch.zeros(1, latent_frames, 1, device=self.device, dtype=self.torch_dtype)
is_covers = torch.zeros(1, device=self.device, dtype=self.torch_dtype)
# Reference audio (empty for text2music)
# For text2music mode, we need empty reference audio
# refer_audio_acoustic_hidden_states_packed: [batch, num_segments, hidden_dim]
# refer_audio_order_mask: [num_segments] - indicates which batch each segment belongs to
refer_audio_acoustic_hidden_states_packed = torch.zeros(1, 1, 64, device=self.device, dtype=self.torch_dtype)
refer_audio_order_mask = torch.zeros(1, device=self.device, dtype=torch.long) # 1-d tensor
# Generate audio latents using DiT model
generation_result = self.dit.model.generate_audio(
text_hidden_states=text_hidden_states,
text_attention_mask=text_inputs["attention_mask"],
lyric_hidden_states=lyric_hidden_states,
lyric_attention_mask=lyrics_inputs["attention_mask"],
refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
refer_audio_order_mask=refer_audio_order_mask,
src_latents=src_latents,
chunk_masks=chunk_masks,
is_covers=is_covers,
silence_latent=silence_latent,
attention_mask=attention_mask,
seed=seed if seed is not None else 42,
fix_nfe=inference_steps,
shift=guidance_scale,
)
# Extract target latents from result dictionary
generated_latents = generation_result["target_latents"]
# Decode latents to audio
# generated_latents shape: [batch, latent_frames, 64]
# VAE expects: [batch, latent_frames, 64]
audio_output = self.vae.decode(generated_latents, return_dict=True)
audio = audio_output.sample
# Post-process audio
audio = self._postprocess_audio(audio)
self.load_models_to_device([])
return audio
def _format_text_prompt(self, caption, bpm, keyscale, timesignature, duration):
"""Format text prompt with metadata"""
prompt = "# Instruction\nFill the audio semantic mask based on the given conditions:\n\n"
prompt += f"# Caption\n{caption}\n\n"
prompt += "# Metas\n"
if bpm:
prompt += f"- bpm: {bpm}\n"
if timesignature:
prompt += f"- timesignature: {timesignature}\n"
if keyscale:
prompt += f"- keyscale: {keyscale}\n"
prompt += f"- duration: {int(duration)} seconds\n"
prompt += "<|endoftext|>"
return prompt
def _format_lyrics(self, lyrics, vocal_language, instrumental):
"""Format lyrics with language"""
if instrumental or not lyrics:
lyrics = "[Instrumental]"
lyrics_text = f"# Languages\n{vocal_language}\n\n# Lyric\n{lyrics}<|endoftext|>"
return lyrics_text
def _postprocess_audio(self, audio):
"""Post-process audio tensor"""
# Ensure audio is on CPU and in float32
audio = audio.to(device="cpu", dtype=torch.float32)
# Normalize to [-1, 1]
max_val = torch.abs(audio).max()
if max_val > 0:
audio = audio / max_val
return audio

View File

@@ -39,7 +39,6 @@ class AnimaImagePipeline(BasePipeline):
AnimaUnit_PromptEmbedder(),
]
self.model_fn = model_fn_anima
self.compilable_models = ["dit"]
@staticmethod

View File

@@ -42,7 +42,6 @@ class Flux2ImagePipeline(BasePipeline):
Flux2Unit_ImageIDs(),
]
self.model_fn = model_fn_flux2
self.compilable_models = ["dit"]
@staticmethod

View File

@@ -103,7 +103,6 @@ class FluxImagePipeline(BasePipeline):
FluxImageUnit_LoRAEncode(),
]
self.model_fn = model_fn_flux_image
self.compilable_models = ["dit"]
self.lora_loader = FluxLoRALoader
def enable_lora_merger(self):

View File

@@ -76,7 +76,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
LTX2AudioVideoUnit_SetScheduleStage2(),
]
self.model_fn = model_fn_ltx2
self.compilable_models = ["dit"]
self.default_negative_prompt = {
"LTX-2": (

View File

@@ -52,7 +52,6 @@ class MovaAudioVideoPipeline(BasePipeline):
MovaAudioVideoUnit_UnifiedSequenceParallel(),
]
self.model_fn = model_fn_mova_audio_video
self.compilable_models = ["video_dit", "video_dit2", "audio_dit"]
def enable_usp(self):
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward

View File

@@ -56,7 +56,6 @@ class QwenImagePipeline(BasePipeline):
QwenImageUnit_BlockwiseControlNet(),
]
self.model_fn = model_fn_qwen_image
self.compilable_models = ["dit"]
@staticmethod

View File

@@ -83,11 +83,10 @@ class WanVideoPipeline(BasePipeline):
WanVideoPostUnit_S2V(),
]
self.model_fn = model_fn_wan_video
self.compilable_models = ["dit", "dit2"]
def enable_usp(self):
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward, usp_vace_forward
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward
for block in self.dit.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
@@ -96,14 +95,6 @@ class WanVideoPipeline(BasePipeline):
for block in self.dit2.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
if self.vace is not None:
for block in self.vace.vace_blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.vace.forward = types.MethodType(usp_vace_forward, self.vace)
if self.vace2 is not None:
for block in self.vace2.vace_blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.vace2.forward = types.MethodType(usp_vace_forward, self.vace2)
self.sp_size = get_sequence_parallel_world_size()
self.use_unified_sequence_parallel = True
@@ -1460,6 +1451,13 @@ def model_fn_wan_video(
else:
tea_cache_update = False
if vace_context is not None:
vace_hints = vace(
x, vace_context, context, t_mod, freqs,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
)
# WanToDance
if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global:
if wantodance_refimage_feature is not None:
@@ -1521,13 +1519,6 @@ def model_fn_wan_video(
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
x = chunks[get_sequence_parallel_rank()]
if vace_context is not None:
vace_hints = vace(
x, vace_context, context, t_mod, freqs,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
)
if tea_cache_update:
x = tea_cache.update(x)
else:
@@ -1570,6 +1561,9 @@ def model_fn_wan_video(
# VACE
if vace_context is not None and block_id in vace.vace_layers_mapping:
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
x = x + current_vace_hint * vace_scale
# Animate

View File

@@ -54,7 +54,6 @@ class ZImagePipeline(BasePipeline):
ZImageUnit_PAIControlNet(),
]
self.model_fn = model_fn_z_image
self.compilable_models = ["dit"]
@staticmethod

View File

@@ -1,15 +0,0 @@
def AceStepDiTStateDictConverter(state_dict):
"""
Convert ACE-Step DiT state dict to add 'model.' prefix for wrapper class.
The wrapper class has self.model = AceStepConditionGenerationModel(config),
so all keys need to be prefixed with 'model.'
"""
state_dict_ = {}
keys = state_dict.keys() if hasattr(state_dict, 'keys') else state_dict
for k in keys:
v = state_dict[k]
if not k.startswith("model."):
k = "model." + k
state_dict_[k] = v
return state_dict_

View File

@@ -1,19 +0,0 @@
def AceStepTextEncoderStateDictConverter(state_dict):
"""
将 ACE-Step Text Encoder 权重添加 model. 前缀
Args:
state_dict: 原始的 state dict可能是 dict 或 DiskMap
Returns:
转换后的 state dict所有 key 添加 "model." 前缀
"""
state_dict_ = {}
# 处理 DiskMap 或普通 dict
keys = state_dict.keys() if hasattr(state_dict, 'keys') else state_dict
for k in keys:
v = state_dict[k]
if not k.startswith("model."):
k = "model." + k
state_dict_[k] = v
return state_dict_

View File

@@ -1 +1 @@
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, usp_vace_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks

View File

@@ -117,39 +117,6 @@ def usp_dit_forward(self,
return x
def usp_vace_forward(
self, x, vace_context, context, t_mod, freqs,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
):
# Compute full sequence length from the sharded x
full_seq_len = x.shape[1] * get_sequence_parallel_world_size()
# Embed vace_context via patch embedding
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = torch.cat([
torch.cat([u, u.new_zeros(1, full_seq_len - u.size(1), u.size(2))],
dim=1) for u in c
])
# Chunk VACE context along sequence dim BEFORE processing through blocks
c = torch.chunk(c, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
# Process through vace_blocks (self_attn already monkey-patched to usp_attn_forward)
for block in self.vace_blocks:
c = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
c, x, context, t_mod, freqs
)
# Hints are already sharded per-rank
hints = torch.unbind(c)[:-1]
return hints
def usp_attn_forward(self, x, freqs):
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))

View File

@@ -16,7 +16,7 @@ For more information about installation, please refer to [Installation Dependenc
## Quick Start
Run the following code to quickly load the [Lightricks/LTX-2.3](https://www.modelscope.cn/models/Lightricks/LTX-2.3) model and perform inference. VRAM management has been enabled, and the framework will automatically control model parameter loading based on remaining VRAM. It can run with a minimum of 8GB VRAM.
Run the following code to quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model and perform inference. VRAM management has been enabled, and the framework will automatically control model parameter loading based on remaining VRAM. It can run with a minimum of 8GB VRAM.
```python
import torch
@@ -24,36 +24,88 @@ from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelCo
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_dtype": torch.float8_e5m2,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"onload_dtype": torch.float8_e5m2,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e5m2,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
"""
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
and avoid redundant memory usage when users only want to use part of the model.
"""
# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading
pipe = LTX2AudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"),
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Two cute orange cats, wearing boxing gloves, stand in a boxing ring and fight each other. They are punching each other fast and yelling: 'I will win!'"
negative_prompt = pipe.default_negative_prompt["LTX-2.3"]
# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2"
# pipe = LTX2AudioVideoPipeline.from_pretrained(
# torch_dtype=torch.bfloat16,
# device="cuda",
# model_configs=[
# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
# ],
# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
# )
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
negative_prompt = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
height, width, num_frames = 512 * 2, 768 * 2, 121
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=43,
height=1024, width=1536, num_frames=121,
tiled=True, use_two_stage_pipeline=True,
height=height,
width=width,
num_frames=num_frames,
tiled=True,
use_two_stage_pipeline=True,
)
write_video_audio_ltx2(
video=video,
audio=audio,
output_path='ltx2_twostage.mp4',
fps=24,
audio_sample_rate=24000,
)
write_video_audio_ltx2(video=video, audio=audio, output_path='video.mp4', fps=24, audio_sample_rate=pipe.audio_vocoder.output_sampling_rate)
```
## Model Overview

View File

@@ -16,7 +16,7 @@ pip install -e .
## 快速开始
运行以下代码可以快速加载 [Lightricks/LTX-2.3](https://www.modelscope.cn/models/Lightricks/LTX-2.3) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。
运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。
```python
import torch
@@ -24,36 +24,88 @@ from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelCo
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_dtype": torch.float8_e5m2,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"onload_dtype": torch.float8_e5m2,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e5m2,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
"""
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
and avoid redundant memory usage when users only want to use part of the model.
"""
# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading
pipe = LTX2AudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"),
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Two cute orange cats, wearing boxing gloves, stand in a boxing ring and fight each other. They are punching each other fast and yelling: 'I will win!'"
negative_prompt = pipe.default_negative_prompt["LTX-2.3"]
# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2"
# pipe = LTX2AudioVideoPipeline.from_pretrained(
# torch_dtype=torch.bfloat16,
# device="cuda",
# model_configs=[
# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
# ],
# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
# )
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
negative_prompt = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
height, width, num_frames = 512 * 2, 768 * 2, 121
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=43,
height=1024, width=1536, num_frames=121,
tiled=True, use_two_stage_pipeline=True,
height=height,
width=width,
num_frames=num_frames,
tiled=True,
use_two_stage_pipeline=True,
)
write_video_audio_ltx2(
video=video,
audio=audio,
output_path='ltx2_twostage.mp4',
fps=24,
audio_sample_rate=24000,
)
write_video_audio_ltx2(video=video, audio=audio, output_path='video.mp4', fps=24, audio_sample_rate=pipe.audio_vocoder.output_sampling_rate)
```
## 模型总览

View File

@@ -1,14 +0,0 @@
from diffsynth.pipelines.ace_step_audio import AceStepAudioPipeline, ModelConfig
import torch
pipe = AceStepAudioPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B"),
)

View File

@@ -1,6 +1,6 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from PIL import Image
vram_config = {
"offload_dtype": torch.bfloat16,
@@ -25,8 +25,3 @@ pipe = Flux2ImagePipeline.from_pretrained(
prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
image.save("image_FLUX.2-dev.jpg")
prompt = "Transform the image into Japanese anime style"
edit_image = [Image.open("image_FLUX.2-dev.jpg")]
image = pipe(prompt, seed=42, rand_device="cuda", edit_image=edit_image, num_inference_steps=50, embedded_guidance=2.5)
image.save("image_FLUX.2-dev_edit.jpg")

View File

@@ -1,6 +1,5 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from PIL import Image
vram_config = {
"offload_dtype": "disk",
@@ -26,8 +25,3 @@ pipe = Flux2ImagePipeline.from_pretrained(
prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene."
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
image.save("image.jpg")
prompt = "Transform the image into Japanese anime style"
edit_image = [Image.open("image.jpg")]
image = pipe(prompt, seed=42, rand_device="cuda", edit_image=edit_image, num_inference_steps=50, embedded_guidance=2.5)
image.save("image_edit.jpg")

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "diffsynth"
version = "2.0.7"
version = "2.0.6"
description = "Enjoy the magic of Diffusion models!"
authors = [{name = "ModelScope Team"}]
license = {text = "Apache-2.0"}