mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
Compare commits
4 Commits
compatibil
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e2a3a987da | ||
|
|
f7b9ae7d57 | ||
|
|
5d198287f0 | ||
|
|
5bccd60c80 |
@@ -86,7 +86,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
|
||||
def enable_usp(self):
|
||||
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward
|
||||
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward, usp_vace_forward
|
||||
|
||||
for block in self.dit.blocks:
|
||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||
@@ -95,6 +95,14 @@ class WanVideoPipeline(BasePipeline):
|
||||
for block in self.dit2.blocks:
|
||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
|
||||
if self.vace is not None:
|
||||
for block in self.vace.vace_blocks:
|
||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||
self.vace.forward = types.MethodType(usp_vace_forward, self.vace)
|
||||
if self.vace2 is not None:
|
||||
for block in self.vace2.vace_blocks:
|
||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||
self.vace2.forward = types.MethodType(usp_vace_forward, self.vace2)
|
||||
self.sp_size = get_sequence_parallel_world_size()
|
||||
self.use_unified_sequence_parallel = True
|
||||
|
||||
@@ -1450,13 +1458,6 @@ def model_fn_wan_video(
|
||||
tea_cache_update = tea_cache.check(dit, x, t_mod)
|
||||
else:
|
||||
tea_cache_update = False
|
||||
|
||||
if vace_context is not None:
|
||||
vace_hints = vace(
|
||||
x, vace_context, context, t_mod, freqs,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
|
||||
)
|
||||
|
||||
# WanToDance
|
||||
if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global:
|
||||
@@ -1519,6 +1520,13 @@ def model_fn_wan_video(
|
||||
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
|
||||
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
|
||||
x = chunks[get_sequence_parallel_rank()]
|
||||
|
||||
if vace_context is not None:
|
||||
vace_hints = vace(
|
||||
x, vace_context, context, t_mod, freqs,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
|
||||
)
|
||||
if tea_cache_update:
|
||||
x = tea_cache.update(x)
|
||||
else:
|
||||
@@ -1561,9 +1569,6 @@ def model_fn_wan_video(
|
||||
# VACE
|
||||
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
||||
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
|
||||
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
||||
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||
current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
|
||||
x = x + current_vace_hint * vace_scale
|
||||
|
||||
# Animate
|
||||
|
||||
@@ -1 +1 @@
|
||||
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks
|
||||
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, usp_vace_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks
|
||||
|
||||
@@ -117,6 +117,39 @@ def usp_dit_forward(self,
|
||||
return x
|
||||
|
||||
|
||||
def usp_vace_forward(
|
||||
self, x, vace_context, context, t_mod, freqs,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
):
|
||||
# Compute full sequence length from the sharded x
|
||||
full_seq_len = x.shape[1] * get_sequence_parallel_world_size()
|
||||
|
||||
# Embed vace_context via patch embedding
|
||||
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||
c = torch.cat([
|
||||
torch.cat([u, u.new_zeros(1, full_seq_len - u.size(1), u.size(2))],
|
||||
dim=1) for u in c
|
||||
])
|
||||
|
||||
# Chunk VACE context along sequence dim BEFORE processing through blocks
|
||||
c = torch.chunk(c, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||
|
||||
# Process through vace_blocks (self_attn already monkey-patched to usp_attn_forward)
|
||||
for block in self.vace_blocks:
|
||||
c = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
c, x, context, t_mod, freqs
|
||||
)
|
||||
|
||||
# Hints are already sharded per-rank
|
||||
hints = torch.unbind(c)[:-1]
|
||||
return hints
|
||||
|
||||
|
||||
def usp_attn_forward(self, x, freqs):
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(x))
|
||||
|
||||
@@ -16,7 +16,7 @@ For more information about installation, please refer to [Installation Dependenc
|
||||
|
||||
## Quick Start
|
||||
|
||||
Run the following code to quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model and perform inference. VRAM management has been enabled, and the framework will automatically control model parameter loading based on remaining VRAM. It can run with a minimum of 8GB VRAM.
|
||||
Run the following code to quickly load the [Lightricks/LTX-2.3](https://www.modelscope.cn/models/Lightricks/LTX-2.3) model and perform inference. VRAM management has been enabled, and the framework will automatically control model parameter loading based on remaining VRAM. It can run with a minimum of 8GB VRAM.
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -24,88 +24,36 @@ from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelCo
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
"""
|
||||
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
|
||||
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
|
||||
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
|
||||
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
|
||||
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
|
||||
and avoid redundant memory usage when users only want to use part of the model.
|
||||
"""
|
||||
# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"),
|
||||
)
|
||||
|
||||
# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2"
|
||||
# pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
# torch_dtype=torch.bfloat16,
|
||||
# device="cuda",
|
||||
# model_configs=[
|
||||
# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
# ],
|
||||
# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||
# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
# )
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
prompt = "Two cute orange cats, wearing boxing gloves, stand in a boxing ring and fight each other. They are punching each other fast and yelling: 'I will win!'"
|
||||
negative_prompt = pipe.default_negative_prompt["LTX-2.3"]
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_two_stage_pipeline=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_twostage.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
height=1024, width=1536, num_frames=121,
|
||||
tiled=True, use_two_stage_pipeline=True,
|
||||
)
|
||||
write_video_audio_ltx2(video=video, audio=audio, output_path='video.mp4', fps=24, audio_sample_rate=pipe.audio_vocoder.output_sampling_rate)
|
||||
```
|
||||
|
||||
## Model Overview
|
||||
|
||||
@@ -16,7 +16,7 @@ pip install -e .
|
||||
|
||||
## 快速开始
|
||||
|
||||
运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。
|
||||
运行以下代码可以快速加载 [Lightricks/LTX-2.3](https://www.modelscope.cn/models/Lightricks/LTX-2.3) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -24,88 +24,36 @@ from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelCo
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
"""
|
||||
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
|
||||
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
|
||||
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
|
||||
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
|
||||
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
|
||||
and avoid redundant memory usage when users only want to use part of the model.
|
||||
"""
|
||||
# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"),
|
||||
)
|
||||
|
||||
# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2"
|
||||
# pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
# torch_dtype=torch.bfloat16,
|
||||
# device="cuda",
|
||||
# model_configs=[
|
||||
# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
# ],
|
||||
# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||
# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
# )
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
prompt = "Two cute orange cats, wearing boxing gloves, stand in a boxing ring and fight each other. They are punching each other fast and yelling: 'I will win!'"
|
||||
negative_prompt = pipe.default_negative_prompt["LTX-2.3"]
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_two_stage_pipeline=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_twostage.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
height=1024, width=1536, num_frames=121,
|
||||
tiled=True, use_two_stage_pipeline=True,
|
||||
)
|
||||
write_video_audio_ltx2(video=video, audio=audio, output_path='video.mp4', fps=24, audio_sample_rate=pipe.audio_vocoder.output_sampling_rate)
|
||||
```
|
||||
|
||||
## 模型总览
|
||||
|
||||
Reference in New Issue
Block a user