Compare commits

..

1 Commits
sd ... zero3

Author SHA1 Message Date
Artiprocher
6fe897883b resolve conflicts 2026-02-04 17:07:30 +08:00
505 changed files with 1240 additions and 34437 deletions

1
.gitignore vendored
View File

@@ -2,7 +2,6 @@
/models
/scripts
/diffusers
/.vscode
*.pkl
*.safetensors
*.pth

486
README.md
View File

@@ -7,14 +7,11 @@
[![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)
[![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
[![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
[![Discord](https://badgen.net//discord/members/Mm9suEeUDc)](https://discord.gg/Mm9suEeUDc)
[切换到中文版](./README_zh.md)
## Introduction
> DiffSynth-Studio Documentation: [中文版](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/)、[English version](https://diffsynth-studio-doc.readthedocs.io/en/latest/)
Welcome to the magical world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by the [ModelScope Community](https://www.modelscope.cn/). We hope to foster technological innovation through framework construction, aggregate the power of the open-source community, and explore the boundaries of generative model technology!
DiffSynth currently includes two open-source projects:
@@ -26,34 +23,15 @@ DiffSynth currently includes two open-source projects:
* ModelScope AIGC Zone (for Chinese users): https://modelscope.cn/aigc/home
* ModelScope Civision (for global users): https://modelscope.ai/civision/home
> DiffSynth-Studio Documentation: [中文版](/docs/zh/README.md)、[English version](/docs/en/README.md)
We believe that a well-developed open-source code framework can lower the threshold for technical exploration. We have achieved many [interesting technologies](#innovative-achievements) based on this codebase. Perhaps you also have many wild ideas, and with DiffSynth-Studio, you can quickly realize these ideas. For this reason, we have prepared detailed documentation for developers. We hope that through these documents, developers can understand the principles of Diffusion models, and we look forward to expanding the boundaries of technology together with you.
## Update History
> DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
- **April 14, 2026** JoyAI-Image open-sourced, welcome a new member to the image editing model family! Support includes instruction-guided image editing, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/JoyAI-Image.md) and [example code](/examples/joyai_image/).
- **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).
- **March 3, 2026**: We released the [DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2) model, which is an updated version of Qwen-Image-Layered-Control. In addition to the originally supported text-guided functionality, it adds brush-controlled layer separation capabilities.
- **March 2, 2026** Added support for [Anima](https://modelscope.cn/models/circlestone-labs/Anima). For details, please refer to the [documentation](docs/en/Model_Details/Anima.md). This is an interesting anime-style image generation model. We look forward to its future updates.
<details>
<summary>More</summary>
- **February 26, 2026** Added full and lora training support for the LTX-2 audio-video generation model. See the [documentation](/docs/en/Model_Details/LTX-2.md) for details.
- **February 10, 2026** Added inference support for the LTX-2 audio-video generation model. See the [documentation](/docs/en/Model_Details/LTX-2.md) for details. Support for model training will be implemented in the future.
- **February 2, 2026** The first document of the Research Tutorial series is now available, guiding you through training a small 0.1B text-to-image model from scratch. For details, see the [documentation](/docs/en/Research_Tutorial/train_from_scratch.md) and [model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). We hope DiffSynth-Studio can evolve into a more powerful training framework for Diffusion models.
- **January 27, 2026**: [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) is released, and our [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) model is released concurrently. You can use it in [ModelScope Studios](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L). For details, see the [documentation](/docs/zh/Model_Details/Z-Image.md).
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
- **January 19, 2026**: Added support for [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.
@@ -74,6 +52,9 @@ We believe that a well-developed open-source code framework can lower the thresh
- [Differential LoRA Training](/docs/zh/Training/Differential_LoRA.md): This is a training technique we used in [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), now available for LoRA training of any model.
- [FP8 Training](/docs/zh/Training/FP8_Precision.md): FP8 can be applied to any non-training model during training, i.e., models with gradients turned off or gradients that only affect LoRA weights.
<details>
<summary>More</summary>
- **November 4, 2025** Supported the [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) model, which is trained based on Wan 2.1 and supports generating corresponding actions based on reference videos.
- **October 30, 2025** Supported the [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which supports text-to-video, image-to-video, and video continuation. This model uses the Wan framework for inference and training in this project.
@@ -288,14 +269,9 @@ image.save("image.jpg")
Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image/)
|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
</details>
@@ -355,60 +331,6 @@ Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)
</details>
#### Anima: [/docs/en/Model_Details/Anima.md](/docs/en/Model_Details/Anima.md)
<details>
<summary>Quick Start</summary>
Run the following code to quickly load the [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 8GB VRAM.
```python
from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": "disk",
"onload_device": "disk",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AnimaImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config),
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config),
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
image = pipe(prompt, seed=0, num_inference_steps=50)
image.save("image.jpg")
```
</details>
<details>
<summary>Examples</summary>
Example code for Anima is located at: [/examples/anima/](/examples/anima/)
| Model ID | Inference | Low VRAM Inference | Full Training | Validation after Full Training | LoRA Training | Validation after LoRA Training |
|-|-|-|-|-|-|-|
|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](/examples/anima/model_inference/anima-preview.py)|[code](/examples/anima/model_inference_low_vram/anima-preview.py)|[code](/examples/anima/model_training/full/anima-preview.sh)|[code](/examples/anima/model_training/validate_full/anima-preview.py)|[code](/examples/anima/model_training/lora/anima-preview.sh)|[code](/examples/anima/model_training/validate_lora/anima-preview.py)|
</details>
#### Qwen-Image: [/docs/en/Model_Details/Qwen-Image.md](/docs/en/Model_Details/Qwen-Image.md)
<details>
@@ -488,12 +410,8 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[FireRedTeam/FireRed-Image-Edit-1.0](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.0)|[code](/examples/qwen_image/model_inference/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.0.py)|
|[FireRedTeam/FireRed-Image-Edit-1.1](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.1)|[code](/examples/qwen_image/model_inference/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.1.sh)|[code](/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.1.sh)|[code](/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.1.py)|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control-V2.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control-V2.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
@@ -600,283 +518,10 @@ Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/)
</details>
#### ERNIE-Image: [/docs/en/Model_Details/ERNIE-Image.md](/docs/en/Model_Details/ERNIE-Image.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
model_configs=[
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")
```
</details>
<details>
<summary>Examples</summary>
Example code for ERNIE-Image is available at: [/examples/ernie_image/](/examples/ernie_image/)
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
</details>
#### JoyAI-Image: [/docs/en/Model_Details/JoyAI-Image.md](/docs/en/Model_Details/JoyAI-Image.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 4GB VRAM.
```python
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
# Download dataset
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# Use first sample from dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
output = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=30,
cfg_scale=5.0,
)
output.save("output_joyai_edit_low_vram.png")
```
</details>
<details>
<summary>Examples</summary>
Example code for JoyAI-Image is available at: [/examples/joyai_image/](/examples/joyai_image/)
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
</details>
### Video Synthesis
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
#### LTX-2: [/docs/en/Model_Details/LTX-2.md](/docs/en/Model_Details/LTX-2.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8GB of VRAM.
```python
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
vram_config = {
"offload_dtype": torch.float8_e5m2,
"offload_device": "cpu",
"onload_dtype": torch.float8_e5m2,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e5m2,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
"""
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
and avoid redundant memory usage when users only want to use part of the model.
"""
# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading
pipe = LTX2AudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2"
# pipe = LTX2AudioVideoPipeline.from_pretrained(
# torch_dtype=torch.bfloat16,
# device="cuda",
# model_configs=[
# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
# ],
# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
# )
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
negative_prompt = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
height, width, num_frames = 512 * 2, 768 * 2, 121
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=43,
height=height,
width=width,
num_frames=num_frames,
tiled=True,
use_two_stage_pipeline=True,
)
write_video_audio_ltx2(
video=video,
audio=audio,
output_path='ltx2_twostage.mp4',
fps=24,
audio_sample_rate=24000,
)
```
</details>
<details>
<summary>Examples</summary>
Example code for LTX-2 is available at: [/examples/ltx2/](/examples/ltx2/)
| Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|-|
|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py)|
|[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2.3: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)|
|[Lightricks/LTX-2.3: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2.3: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2.3: A2V](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](/examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2.3: Retake](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_video`,`retake_video_regions`,`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py)|-|-|-|-|
|[Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|
|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-|
</details>
#### Wan: [/docs/en/Model_Details/Wan.md](/docs/en/Model_Details/Wan.md)
<details>
@@ -976,43 +621,39 @@ graph LR;
Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
| Model ID | Extra Inputs | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|-|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|[openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py)|
|[openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py)|
|[Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py)|
|[Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py)|
| Model ID | Extra Args | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
</details>
@@ -1020,37 +661,6 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.
<details>
<summary>Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation</summary>
- Paper: [Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation
](https://arxiv.org/abs/2602.03208)
- Sample Code: [/docs/en/Research_Tutorial/inference_time_scaling.md](/docs/en/Research_Tutorial/inference_time_scaling.md)
|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES|
|-|-|-|-|
|![Image](https://github.com/user-attachments/assets/5be15dc6-2805-4822-b04c-2573fc0f45f0)|![Image](https://github.com/user-attachments/assets/e71b8c20-1629-41d9-b0ff-185805c1da4e)|![Image](https://github.com/user-attachments/assets/7a73c968-133a-4545-9aa2-205533861cd4)|![Image](https://github.com/user-attachments/assets/c8390b22-14fe-48a0-a6e6-d6556d31235e)|
</details>
<details>
<summary>VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers</summary>
- Paper: [VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers
](https://arxiv.org/abs/2602.03210)
- Sample code: [/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)
|Example 1|Example 2|Query|Output|
|-|-|-|-|
|![Image](https://github.com/user-attachments/assets/380d2670-47bf-41cd-b5c9-37110cc4a943)|![Image](https://github.com/user-attachments/assets/7ceaf345-0992-46e6-b38f-394c2065b165)|![Image](https://github.com/user-attachments/assets/f7c26c21-6894-4d9e-b570-f1d44ca7c1de)|![Image](https://github.com/user-attachments/assets/c2bebe3b-5984-41ba-94bf-9509f6a8a990)|
</details>
<details>
<summary>AttriCtrl: Attribute Intensity Control for Image Generation Models</summary>
@@ -1061,7 +671,7 @@ DiffSynth-Studio is not just an engineered model framework, but also an incubato
|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
|-|-|-|-|-|
|![Image](https://github.com/user-attachments/assets/e74b32a5-5b2e-4c87-9df8-487c0f8366b7)|![Image](https://github.com/user-attachments/assets/bfe8bec2-9e55-493d-9a26-7e9cce28e03d)|![Image](https://github.com/user-attachments/assets/b099dfe3-ff1f-4b96-894c-d48bbe92db7a)|![Image](https://github.com/user-attachments/assets/0a6b2982-deab-4b0d-91ad-888782de01c9)|![Image](https://github.com/user-attachments/assets/fcecb755-7d03-4020-b83a-13ad2b38705c)|
|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.5.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.7.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.9.jpg)|
</details>
@@ -1076,10 +686,10 @@ DiffSynth-Studio is not just an engineered model framework, but also an incubato
||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
|-|-|-|-|-|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![Image](https://github.com/user-attachments/assets/01c54d5a-4f00-4c2e-982a-4ec0a4c6a6e3)|![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/43720a9f-aa27-4918-947d-545389375d46)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/041a3f9a-c7b4-4311-8582-cb71a7226d80)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|![Image](https://github.com/user-attachments/assets/a640fd54-3192-49a0-9281-b43d9ba64f09)|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_0.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_3_3.jpg)|
</details>
@@ -1170,9 +780,3 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
</details>
## Contact Us
|Discordhttps://discord.gg/Mm9suEeUDc|
|-|
|<img width="160" height="160" alt="Image" src="https://github.com/user-attachments/assets/29bdc97b-e35d-4fea-88d6-32e35182e458" />|

View File

@@ -7,14 +7,11 @@
[![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)
[![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
[![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
[![Discord](https://badgen.net//discord/members/Mm9suEeUDc)](https://discord.gg/Mm9suEeUDc)
[Switch to English](./README.md)
## 简介
> DiffSynth-Studio 文档:[中文版](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/)、[English version](https://diffsynth-studio-doc.readthedocs.io/en/latest/)
欢迎来到 Diffusion 模型的魔法世界DiffSynth-Studio 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望以框架建设孵化技术创新,凝聚开源社区的力量,探索生成式模型技术的边界!
DiffSynth 目前包括两个开源项目:
@@ -26,34 +23,15 @@ DiffSynth 目前包括两个开源项目:
* 魔搭社区 AIGC 专区 (面向中国用户): https://modelscope.cn/aigc/home
* ModelScope Civision (for global users): https://modelscope.ai/civision/home
> DiffSynth-Studio 文档:[中文版](/docs/zh/README.md)、[English version](/docs/en/README.md)
我们相信,一个完善的开源代码框架能够降低技术探索的门槛,我们基于这个代码库搞出了不少[有意思的技术](#创新成果)。或许你也有许多天马行空的构想,借助 DiffSynth-Studio你可以快速实现这些想法。为此我们为开发者准备了详细的文档我们希望通过这些文档帮助开发者理解 Diffusion 模型的原理,更期待与你一同拓展技术的边界。
## 更新历史
> DiffSynth-Studio 经历了大版本更新,部分旧功能已停止维护,如需使用旧版功能,请切换到大版本更新前的[最后一个历史版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3)。
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责因此新功能的开发进展会比较缓慢issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
- **2026年4月14日** JoyAI-Image 开源,欢迎加入图像编辑模型家族!支持指令引导的图像编辑推理、低显存推理和训练能力。详情请参考[文档](/docs/zh/Model_Details/JoyAI-Image.md)和[示例代码](/examples/joyai_image/)。
- **2026年3月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。
- **2026年3月12日** 我们新增了 [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) 音视频生成模型的支持模型支持的功能包括文生音视频、图生音视频、IC-LoRA控制、音频生视频、音视频局部Inpainting框架支持完整的推理和训练功能。详细信息请参考 [文档](/docs/zh/Model_Details/LTX-2.md) 和 [示例代码](/examples/ltx2/)。
- **2026年3月3日** 我们发布了 [DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2) 模型,这是 Qwen-Image-Layered-Control 的更新版本。除了原本就支持的文本引导功能,新增了画笔控制的图层拆分能力。
- **2026年3月2日** 新增对[Anima](https://modelscope.cn/models/circlestone-labs/Anima)的支持,详见[文档](docs/zh/Model_Details/Anima.md)。这是一个有趣的动漫风格图像生成模型,我们期待其后续的模型更新。
<details>
<summary>更多</summary>
- **2026年2月26日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型全量微调与LoRA训练支持详见[文档](docs/zh/Model_Details/LTX-2.md)。
- **2026年2月10日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型的推理支持,详见[文档](docs/zh/Model_Details/LTX-2.md),后续将推进模型训练的支持。
- **2026年2月2日** Research Tutorial 的第一篇文档上线,带你从零开始训练一个 0.1B 的小型文生图模型,详见[文档](/docs/zh/Research_Tutorial/train_from_scratch.md)、[模型](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel),我们希望 DiffSynth-Studio 能够成为一个更强大的 Diffusion 模型训练框架。
- **2026年1月27日** [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) 发布,我们的 [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) 模型同步发布,在[魔搭创空间](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L)可直接体验,详见[文档](/docs/zh/Model_Details/Z-Image.md)。
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责因此新功能的开发进展会比较缓慢issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
- **2026年1月19日** 新增对 [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 和 [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/FLUX2.md)和[示例代码](/examples/flux2/)现已可用。
@@ -74,6 +52,9 @@ DiffSynth 目前包括两个开源项目:
- [差分 LoRA 训练](/docs/zh/Training/Differential_LoRA.md):这是我们曾在 [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) 中使用的训练技术,目前已可用于任意模型的 LoRA 训练。
- [FP8 训练](/docs/zh/Training/FP8_Precision.md)FP8 在训练中支持应用到任意非训练模型,即梯度关闭或者梯度仅影响 LoRA 权重的模型。
<details>
<summary>更多</summary>
- **2025年11月4日** 支持了 [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) 模型,该模型基于 Wan 2.1 训练,支持根据参考视频生成相应的动作。
- **2025年10月30日** 支持了 [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) 模型,该模型支持文生视频、图生视频、视频续写。这个模型在本项目中沿用 Wan 的框架进行推理和训练。
@@ -290,12 +271,7 @@ Z-Image 的示例代码位于:[/examples/z_image/](/examples/z_image/)
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
</details>
@@ -355,60 +331,6 @@ FLUX.2 的示例代码位于:[/examples/flux2/](/examples/flux2/)
</details>
#### Anima: [/docs/zh/Model_Details/Anima.md](/docs/zh/Model_Details/Anima.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
```python
from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": "disk",
"onload_device": "disk",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AnimaImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config),
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config),
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
image = pipe(prompt, seed=0, num_inference_steps=50)
image.save("image.jpg")
```
</details>
<details>
<summary>示例代码</summary>
Anima 的示例代码位于:[/examples/anima/](/examples/anima/)
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](/examples/anima/model_inference/anima-preview.py)|[code](/examples/anima/model_inference_low_vram/anima-preview.py)|[code](/examples/anima/model_training/full/anima-preview.sh)|[code](/examples/anima/model_training/validate_full/anima-preview.py)|[code](/examples/anima/model_training/lora/anima-preview.sh)|[code](/examples/anima/model_training/validate_lora/anima-preview.py)|
</details>
#### Qwen-Image: [/docs/zh/Model_Details/Qwen-Image.md](/docs/zh/Model_Details/Qwen-Image.md)
<details>
@@ -488,12 +410,8 @@ Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[FireRedTeam/FireRed-Image-Edit-1.0](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.0)|[code](/examples/qwen_image/model_inference/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.0.py)|
|[FireRedTeam/FireRed-Image-Edit-1.1](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.1)|[code](/examples/qwen_image/model_inference/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.1.sh)|[code](/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.1.sh)|[code](/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.1.py)|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control-V2.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control-V2.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
@@ -600,283 +518,10 @@ FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/)
</details>
#### ERNIE-Image: [/docs/zh/Model_Details/ERNIE-Image.md](/docs/zh/Model_Details/ERNIE-Image.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
model_configs=[
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")
```
</details>
<details>
<summary>示例代码</summary>
ERNIE-Image 的示例代码位于:[/examples/ernie_image/](/examples/ernie_image/)
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|-|-|-|-|-|-|-|
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
</details>
#### JoyAI-Image: [/docs/zh/Model_Details/JoyAI-Image.md](/docs/zh/Model_Details/JoyAI-Image.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 4G 显存即可运行。
```python
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
# Download dataset
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# Use first sample from dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
output = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=30,
cfg_scale=5.0,
)
output.save("output_joyai_edit_low_vram.png")
```
</details>
<details>
<summary>示例代码</summary>
JoyAI-Image 的示例代码位于:[/examples/joyai_image/](/examples/joyai_image/)
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
</details>
### 视频生成模型
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
#### LTX-2: [/docs/zh/Model_Details/LTX-2.md](/docs/zh/Model_Details/LTX-2.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。
```python
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
vram_config = {
"offload_dtype": torch.float8_e5m2,
"offload_device": "cpu",
"onload_dtype": torch.float8_e5m2,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e5m2,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
"""
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
and avoid redundant memory usage when users only want to use part of the model.
"""
# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading
pipe = LTX2AudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2"
# pipe = LTX2AudioVideoPipeline.from_pretrained(
# torch_dtype=torch.bfloat16,
# device="cuda",
# model_configs=[
# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
# ],
# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
# )
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
negative_prompt = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
height, width, num_frames = 512 * 2, 768 * 2, 121
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=43,
height=height,
width=width,
num_frames=num_frames,
tiled=True,
use_two_stage_pipeline=True,
)
write_video_audio_ltx2(
video=video,
audio=audio,
output_path='ltx2_twostage.mp4',
fps=24,
audio_sample_rate=24000,
)
```
</details>
<details>
<summary>示例代码</summary>
LTX-2 的示例代码位于:[/examples/ltx2/](/examples/ltx2/)
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|-|
|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py)|
|[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2.3: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)|
|[Lightricks/LTX-2.3: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2.3: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2.3: A2V](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](/examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2.3: Retake](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_video`,`retake_video_regions`,`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py)|-|-|-|-|
|[Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|
|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-|
</details>
#### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md)
<details>
@@ -976,43 +621,39 @@ graph LR;
Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|[openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py)|
|[openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py)|
|[Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py)|
|[Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py)|
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
</details>
@@ -1020,37 +661,6 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。
<details>
<summary>Spectral Evolution Search: 用于奖励对齐图像生成的高效推理阶段缩放</summary>
- 论文:[Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation
](https://arxiv.org/abs/2602.03208)
- 代码样例:[/docs/en/Research_Tutorial/inference_time_scaling.md](/docs/en/Research_Tutorial/inference_time_scaling.md)
|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES|
|-|-|-|-|
|![Image](https://github.com/user-attachments/assets/5be15dc6-2805-4822-b04c-2573fc0f45f0)|![Image](https://github.com/user-attachments/assets/e71b8c20-1629-41d9-b0ff-185805c1da4e)|![Image](https://github.com/user-attachments/assets/7a73c968-133a-4545-9aa2-205533861cd4)|![Image](https://github.com/user-attachments/assets/c8390b22-14fe-48a0-a6e6-d6556d31235e)|
</details>
<details>
<summary>VIRAL基于DiT模型的类比视觉上下文推理</summary>
- 论文:[VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers
](https://arxiv.org/abs/2602.03210)
- 代码样例:[/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py)
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)
|Example 1|Example 2|Query|Output|
|-|-|-|-|
|![Image](https://github.com/user-attachments/assets/380d2670-47bf-41cd-b5c9-37110cc4a943)|![Image](https://github.com/user-attachments/assets/7ceaf345-0992-46e6-b38f-394c2065b165)|![Image](https://github.com/user-attachments/assets/f7c26c21-6894-4d9e-b570-f1d44ca7c1de)|![Image](https://github.com/user-attachments/assets/c2bebe3b-5984-41ba-94bf-9509f6a8a990)|
</details>
<details>
<summary>AttriCtrl: 图像生成模型的属性强度控制</summary>
@@ -1062,7 +672,7 @@ DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果
|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
|-|-|-|-|-|
|![Image](https://github.com/user-attachments/assets/e74b32a5-5b2e-4c87-9df8-487c0f8366b7)|![Image](https://github.com/user-attachments/assets/bfe8bec2-9e55-493d-9a26-7e9cce28e03d)|![Image](https://github.com/user-attachments/assets/b099dfe3-ff1f-4b96-894c-d48bbe92db7a)|![Image](https://github.com/user-attachments/assets/0a6b2982-deab-4b0d-91ad-888782de01c9)|![Image](https://github.com/user-attachments/assets/fcecb755-7d03-4020-b83a-13ad2b38705c)|
|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.5.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.7.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.9.jpg)|
</details>
@@ -1078,10 +688,10 @@ DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果
||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
|-|-|-|-|-|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![Image](https://github.com/user-attachments/assets/01c54d5a-4f00-4c2e-982a-4ec0a4c6a6e3)|![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/43720a9f-aa27-4918-947d-545389375d46)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/041a3f9a-c7b4-4311-8582-cb71a7226d80)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|![Image](https://github.com/user-attachments/assets/a640fd54-3192-49a0-9281-b43d9ba64f09)|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_0.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_3_3.jpg)|
</details>
@@ -1172,9 +782,3 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
</details>
## 联系我们
|Discordhttps://discord.gg/Mm9suEeUDc|
|-|
|<img width="160" height="160" alt="Image" src="https://github.com/user-attachments/assets/29bdc97b-e35d-4fea-88d6-32e35182e458" />|

View File

@@ -1,2 +1,2 @@
from .model_configs import MODEL_CONFIGS
from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS
from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS

View File

@@ -307,13 +307,6 @@ wan_series = [
"model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors")
"model_hash": "eb18873fc0ba77b541eb7b62dbcd2059",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'wantodance_enable_music_inject': True, 'wantodance_music_inject_layers': [0, 4, 8, 12, 16, 20, 24, 27], 'wantodance_enable_refimage': True, 'has_ref_conv': True, 'wantodance_enable_refface': False, 'wantodance_enable_global': True, 'wantodance_enable_dynamicfps': True, 'wantodance_enable_unimodel': True}
},
]
flux_series = [
@@ -541,22 +534,6 @@ flux2_series = [
},
]
ernie_image_series = [
{
# Example: ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "584c13713849f1af4e03d5f1858b8b7b",
"model_name": "ernie_image_dit",
"model_class": "diffsynth.models.ernie_image_dit.ErnieImageDiT",
},
{
# Example: ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors")
"model_hash": "404ed9f40796a38dd34c1620f1920207",
"model_name": "ernie_image_text_encoder",
"model_class": "diffsynth.models.ernie_image_text_encoder.ErnieImageTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ernie_image_text_encoder.ErnieImageTextEncoderStateDictConverter",
},
]
z_image_series = [
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
@@ -612,370 +589,6 @@ z_image_series = [
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
"extra_kwargs": {"compress_dim": 128},
},
{
# Example: ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors")
"model_hash": "1392adecee344136041e70553f875f31",
"model_name": "z_image_text_encoder",
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
"extra_kwargs": {"model_size": "0.6B"},
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
},
{
# To ensure compatibility with the `model.diffusion_model` prefix introduced by other frameworks.
"model_hash": "8cf241a0d32f93d5de368502a086852f",
"model_name": "z_image_dit",
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_dit.ZImageDiTStateDictConverter",
},
]
"""
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
and avoid redundant memory usage when users only want to use part of the model.
"""
ltx2_series = [
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_dit",
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors")
"model_hash": "c567aaa37d5ed7454c73aa6024458661",
"model_name": "ltx2_dit",
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_video_vae_encoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors")
"model_hash": "7f7e904a53260ec0351b05f32153754b",
"model_name": "ltx2_video_vae_encoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_video_vae_decoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors")
"model_hash": "dc6029ca2825147872b45e35a2dc3a97",
"model_name": "ltx2_video_vae_decoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_audio_vae_decoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors")
"model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
"model_name": "ltx2_audio_vae_decoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_audio_vocoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors")
"model_hash": "f471360f6b24bef702ab73133d9f8bb9",
"model_name": "ltx2_audio_vocoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_audio_vae_encoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
"model_hash": "29338f3b95e7e312a3460a482e4f4554",
"model_name": "ltx2_audio_vae_encoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_text_encoder_post_modules",
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors")
"model_hash": "981629689c8be92a712ab3c5eb4fc3f6",
"model_name": "ltx2_text_encoder_post_modules",
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
},
{
# Example: ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors")
"model_hash": "33917f31c4a79196171154cca39f165e",
"model_name": "ltx2_text_encoder",
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "c79c458c6e99e0e14d47e676761732d2",
"model_name": "ltx2_latent_upsampler",
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
"model_name": "ltx2_dit",
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
"extra_kwargs": {"apply_gated_attention": True, "cross_attention_adaln": True, "caption_channels": None},
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
"model_name": "ltx2_video_vae_encoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
"extra_kwargs": {"encoder_version": "ltx-2.3"},
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
"model_name": "ltx2_video_vae_decoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
"extra_kwargs": {"decoder_version": "ltx-2.3"},
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
"model_name": "ltx2_audio_vae_decoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
"model_name": "ltx2_audio_vocoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2VocoderWithBWE",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
"model_name": "ltx2_audio_vae_encoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
"model_name": "ltx2_text_encoder_post_modules",
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
"extra_kwargs": {"separated_audio_video": True, "embedding_dim_gemma": 3840, "num_layers_gemma": 49, "video_attention_heads": 32, "video_attention_head_dim": 128, "audio_attention_heads": 32, "audio_attention_head_dim": 64, "num_connector_layers": 8, "apply_gated_attention": True},
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
"model_hash": "aed408774d694a2452f69936c32febb5",
"model_name": "ltx2_latent_upsampler",
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
"extra_kwargs": {"rational_resampler": False},
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="transformer.safetensors")
"model_hash": "1c55afad76ed33c112a2978550b524d1",
"model_name": "ltx2_dit",
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
"extra_kwargs": {"apply_gated_attention": True, "cross_attention_adaln": True, "caption_channels": None},
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_encoder.safetensors")
"model_hash": "eecdc07c2ec30863b8a2b8b2134036cf",
"model_name": "ltx2_video_vae_encoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
"extra_kwargs": {"encoder_version": "ltx-2.3"},
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_decoder.safetensors")
"model_hash": "deda2f542e17ee25bc8c38fd605316ea",
"model_name": "ltx2_video_vae_decoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
"extra_kwargs": {"decoder_version": "ltx-2.3"},
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors")
"model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
"model_name": "ltx2_audio_vae_decoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
"model_hash": "29338f3b95e7e312a3460a482e4f4554",
"model_name": "ltx2_audio_vae_encoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors")
"model_hash": "cd436c99e69ec5c80f050f0944f02a15",
"model_name": "ltx2_audio_vocoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2VocoderWithBWE",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors")
"model_hash": "05da2aab1c4b061f72c426311c165a43",
"model_name": "ltx2_text_encoder_post_modules",
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
"extra_kwargs": {"separated_audio_video": True, "embedding_dim_gemma": 3840, "num_layers_gemma": 49, "video_attention_heads": 32, "video_attention_head_dim": 128, "audio_attention_heads": 32, "audio_attention_head_dim": 64, "num_connector_layers": 8, "apply_gated_attention": True},
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
},
]
anima_series = [
{
# Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors")
"model_hash": "a9995952c2d8e63cf82e115005eb61b9",
"model_name": "z_image_text_encoder",
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
"extra_kwargs": {"model_size": "0.6B"},
},
{
# Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors")
"model_hash": "417673936471e79e31ed4d186d7a3f4a",
"model_name": "anima_dit",
"model_class": "diffsynth.models.anima_dit.AnimaDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.anima_dit.AnimaDiTStateDictConverter",
}
]
mova_series = [
# Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors")
{
"model_hash": "8c57e12790e2c45a64817e0ce28cde2f",
"model_name": "mova_audio_dit",
"model_class": "diffsynth.models.mova_audio_dit.MovaAudioDit",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1], 'in_dim': 128, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 128, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
},
# Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors")
{
"model_hash": "418517fb2b4e919d2cac8f314fcf82ac",
"model_name": "mova_audio_vae",
"model_class": "diffsynth.models.mova_audio_vae.DacVAE",
},
# Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors")
{
"model_hash": "d1139dbbc8b4ab53cf4b4243d57bbceb",
"model_name": "mova_dual_tower_bridge",
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
},
]
stable_diffusion_xl_series = [
{
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors")
"model_hash": "142b114f67f5ab3a6d83fb5788f12ded",
"model_name": "stable_diffusion_xl_unet",
"model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel",
"extra_kwargs": {
"attention_head_dim": [5, 10, 20],
"transformer_layers_per_block": [1, 2, 10],
"use_linear_projection": True,
"addition_embed_type": "text_time",
"addition_time_embed_dim": 256,
"projection_class_embeddings_input_dim": 2816,
},
},
{
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors")
"model_hash": "98cc34ccc5b54ae0e56bdea8688dcd5a",
"model_name": "stable_diffusion_xl_text_encoder",
"model_class": "diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_xl_text_encoder.SDXLTextEncoder2StateDictConverter",
},
{
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors")
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
"model_name": "stable_diffusion_text_encoder",
"model_class": "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_text_encoder.SDTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "13115dd45a6e1c39860f91ab073b8a78",
"model_name": "stable_diffusion_xl_vae",
"model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_vae.SDVAEStateDictConverter",
"extra_kwargs": {"scaling_factor": 0.13025, "sample_size": 1024, "force_upcast": True},
},
]
stable_diffusion_series = [
{
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors")
"model_hash": "ffd1737ae9df7fd43f5fbed653bdad67",
"model_name": "stable_diffusion_text_encoder",
"model_class": "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_text_encoder.SDTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "f86d5683ed32433be8ca69969c67ba69",
"model_name": "stable_diffusion_vae",
"model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_vae.SDVAEStateDictConverter",
},
{
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors")
"model_hash": "025a4b86a84829399d89f613e580757b",
"model_name": "stable_diffusion_unet",
"model_class": "diffsynth.models.stable_diffusion_unet.UNet2DConditionModel",
},
]
joyai_image_series = [
{
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth")
"model_hash": "56592ddfd7d0249d3aa527d24161a863",
"model_name": "joyai_image_dit",
"model_class": "diffsynth.models.joyai_image_dit.JoyAIImageDiT",
},
{
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model-*.safetensors")
"model_hash": "2d11bf14bba8b4e87477c8199a895403",
"model_name": "joyai_image_text_encoder",
"model_class": "diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.joyai_image_text_encoder.JoyAIImageTextEncoderStateDictConverter",
},
]
MODEL_CONFIGS = stable_diffusion_xl_series + stable_diffusion_series + qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series

View File

@@ -210,142 +210,4 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.ltx2_dit.LTXModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": {
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": {
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": {
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": {
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_audio_vae.LTX2Vocoder": {
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.anima_dit.AnimaDiT": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.mova_audio_dit.MovaAudioDit": {
"diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
"diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.mova_audio_vae.DacVAE": {
"diffsynth.models.mova_audio_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ernie_image_dit.ErnieImageDiT": {
"diffsynth.models.ernie_image_dit.ErnieImageRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ernie_image_text_encoder.ErnieImageTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.ministral3.modeling_ministral3.Ministral3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.joyai_image_dit.Transformer3DModel": {
"diffsynth.models.joyai_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.joyai_image_dit.ModulateWan": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionModel": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.stable_diffusion_unet.UNet2DConditionModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.stable_diffusion_vae.StableDiffusionVAE": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.stable_diffusion_vae.Upsample2D": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.stable_diffusion_vae.Downsample2D": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPTextTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPEncoderLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPTextTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPEncoderLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
},
}
def QwenImageTextEncoder_Module_Map_Updater():
current = VRAM_MANAGEMENT_MODULE_MAPS["diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder"]
from packaging import version
import transformers
if version.parse(transformers.__version__) >= version.parse("5.2.0"):
# The Qwen2RMSNorm in transformers 5.2.0+ has been renamed to Qwen2_5_VLRMSNorm, so we need to update the module map accordingly
current.pop("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm", None)
current["transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRMSNorm"] = "diffsynth.core.vram.layers.AutoWrappedModule"
return current
VERSION_CHECKER_MAPS = {
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": QwenImageTextEncoder_Module_Map_Updater,
}

View File

@@ -52,7 +52,7 @@ def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="
if k_pattern != required_in_pattern:
k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
if v_pattern != required_in_pattern:
v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
v = rearrange(v, f"{q_pattern} -> {required_in_pattern}", **dims)
return q, k, v

View File

@@ -1,8 +1,6 @@
import math, warnings
import torch, torchvision, imageio, os
import imageio.v3 as iio
from PIL import Image
import torchaudio
class DataProcessingPipeline:
@@ -107,59 +105,27 @@ class ToList(DataProcessingOperator):
return [data]
class FrameSamplerByRateMixin:
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_rate=24, fix_frame_rate=False):
class LoadVideo(DataProcessingOperator):
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
self.num_frames = num_frames
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
self.frame_rate = frame_rate
self.fix_frame_rate = fix_frame_rate
def get_reader(self, data: str):
return imageio.get_reader(data)
def get_available_num_frames(self, reader):
if not self.fix_frame_rate:
return reader.count_frames()
meta_data = reader.get_meta_data()
total_original_frames = int(reader.count_frames())
duration = meta_data["duration"] if "duration" in meta_data else total_original_frames / meta_data['fps']
total_available_frames = math.floor(duration * self.frame_rate)
return int(total_available_frames)
# frame_processor is build in the video loader for high efficiency.
self.frame_processor = frame_processor
def get_num_frames(self, reader):
num_frames = self.num_frames
total_frames = self.get_available_num_frames(reader)
if int(total_frames) < num_frames:
num_frames = total_frames
if int(reader.count_frames()) < num_frames:
num_frames = int(reader.count_frames())
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
num_frames -= 1
return num_frames
def map_single_frame_id(self, new_sequence_id: int, raw_frame_rate: float, total_raw_frames: int) -> int:
if not self.fix_frame_rate:
return new_sequence_id
target_time_in_seconds = new_sequence_id / self.frame_rate
raw_frame_index_float = target_time_in_seconds * raw_frame_rate
frame_id = int(round(raw_frame_index_float))
frame_id = min(frame_id, total_raw_frames - 1)
return frame_id
class LoadVideo(DataProcessingOperator, FrameSamplerByRateMixin):
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x, frame_rate=24, fix_frame_rate=False):
FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)
# frame_processor is build in the video loader for high efficiency.
self.frame_processor = frame_processor
def __call__(self, data: str):
reader = self.get_reader(data)
raw_frame_rate = reader.get_meta_data()['fps']
reader = imageio.get_reader(data)
num_frames = self.get_num_frames(reader)
total_raw_frames = reader.count_frames()
frames = []
for frame_id in range(num_frames):
frame_id = self.map_single_frame_id(frame_id, raw_frame_rate, total_raw_frames)
frame = reader.get_data(frame_id)
frame = Image.fromarray(frame)
frame = self.frame_processor(frame)
@@ -183,7 +149,7 @@ class LoadGIF(DataProcessingOperator):
self.time_division_remainder = time_division_remainder
# frame_processor is build in the video loader for high efficiency.
self.frame_processor = frame_processor
def get_num_frames(self, path):
num_frames = self.num_frames
images = iio.imread(path, mode="RGB")
@@ -252,27 +218,3 @@ class LoadAudio(DataProcessingOperator):
import librosa
input_audio, sample_rate = librosa.load(data, sr=self.sr)
return input_audio
class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin):
def __init__(self, num_frames=121, time_division_factor=8, time_division_remainder=1, frame_rate=24, fix_frame_rate=True):
FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)
def __call__(self, data: str):
try:
reader = self.get_reader(data)
num_frames = self.get_num_frames(reader)
duration = num_frames / self.frame_rate
waveform, sample_rate = torchaudio.load(data)
target_samples = int(duration * sample_rate)
current_samples = waveform.shape[-1]
if current_samples > target_samples:
waveform = waveform[..., :target_samples]
elif current_samples < target_samples:
padding = target_samples - current_samples
waveform = torch.nn.functional.pad(waveform, (0, padding))
return waveform, sample_rate
except:
warnings.warn(f"Cannot load audio in {data}. The audio will be `None`.")
return None

View File

@@ -42,7 +42,6 @@ class UnifiedDataset(torch.utils.data.Dataset):
max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
num_frames=81, time_division_factor=4, time_division_remainder=1,
frame_rate=24, fix_frame_rate=False,
):
return RouteByType(operator_map=[
(str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
@@ -54,7 +53,6 @@ class UnifiedDataset(torch.utils.data.Dataset):
(("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
num_frames, time_division_factor, time_division_remainder,
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
frame_rate=frame_rate, fix_frame_rate=fix_frame_rate,
)),
])),
])

View File

@@ -1,32 +1,12 @@
import torch
try:
import deepspeed
_HAS_DEEPSPEED = True
except ModuleNotFoundError:
_HAS_DEEPSPEED = False
def create_custom_forward(module):
def custom_forward(*inputs, **kwargs):
return module(*inputs, **kwargs)
return custom_forward
def create_custom_forward_use_reentrant(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
def judge_args_requires_grad(*args):
for arg in args:
if isinstance(arg, torch.Tensor) and arg.requires_grad:
return True
return False
def gradient_checkpoint_forward(
model,
use_gradient_checkpointing,
@@ -34,17 +14,6 @@ def gradient_checkpoint_forward(
*args,
**kwargs,
):
if use_gradient_checkpointing and _HAS_DEEPSPEED and deepspeed.checkpointing.is_configured():
all_args = args + tuple(kwargs.values())
if not judge_args_requires_grad(*all_args):
# get the first grad_enabled tensor from un_checkpointed forward
model_output = model(*args, **kwargs)
else:
model_output = deepspeed.checkpointing.checkpoint(
create_custom_forward_use_reentrant(model),
*all_args,
)
return model_output
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
model_output = torch.utils.checkpoint.checkpoint(

View File

@@ -1,5 +1,5 @@
import torch, glob, os
from typing import Optional, Union, Dict
from typing import Optional, Union
from dataclasses import dataclass
from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download
@@ -23,14 +23,13 @@ class ModelConfig:
computation_device: Optional[Union[str, torch.device]] = None
computation_dtype: Optional[torch.dtype] = None
clear_parameters: bool = False
state_dict: Dict[str, torch.Tensor] = None
def check_input(self):
if self.path is None and self.model_id is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
def parse_original_file_pattern(self):
if self.origin_file_pattern in [None, "", "./"]:
if self.origin_file_pattern is None or self.origin_file_pattern == "":
return "*"
elif self.origin_file_pattern.endswith("/"):
return self.origin_file_pattern + "*"
@@ -99,7 +98,7 @@ class ModelConfig:
if self.require_downloading():
self.download()
if self.path is None:
if self.origin_file_pattern in [None, "", "./"]:
if self.origin_file_pattern is None or self.origin_file_pattern == "":
self.path = os.path.join(self.local_model_path, self.model_id)
else:
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))

View File

@@ -2,25 +2,16 @@ from safetensors import safe_open
import torch, hashlib
def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
if isinstance(file_path, list):
state_dict = {}
for file_path_ in file_path:
state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
state_dict.update(load_state_dict(file_path_, torch_dtype, device))
return state_dict
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else:
if verbose >= 1:
print(f"Loading file [started]: {file_path}")
if file_path.endswith(".safetensors"):
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else:
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
# If load state dict in CPU memory, `pin_memory=True` will make `model.to("cuda")` faster.
if pin_memory:
for i in state_dict:
state_dict[i] = state_dict[i].pin_memory()
if verbose >= 1:
print(f"Loading file [done]: {file_path}")
return state_dict
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):

View File

@@ -3,7 +3,6 @@ from ..vram.disk_map import DiskMap
from ..vram.layers import enable_vram_management
from .file import load_state_dict
import torch
from contextlib import contextmanager
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import ContextManagers
@@ -20,7 +19,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
dtype = [d for d in dtypes if d != "disk"][0]
if vram_config["offload_device"] != "disk":
if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
state_dict = DiskMap(path, device, torch_dtype=dtype)
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
else:
@@ -35,9 +34,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
# Sometimes a model file contains multiple models,
# and DiskMap can load only the parameters of a single model,
# avoiding the need to load all parameters in the file.
if state_dict is not None:
pass
elif use_disk_map:
if use_disk_map:
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
else:
state_dict = load_state_dict(path, torch_dtype, device)

View File

@@ -1,30 +0,0 @@
import torch
from ..device.npu_compatible_device import get_device_type
try:
import torch_npu
except:
pass
def rms_norm_forward_npu(self, hidden_states):
"npu rms fused operator for RMSNorm.forward from diffsynth\models\general_modules.py"
if hidden_states.dtype != self.weight.dtype:
hidden_states = hidden_states.to(self.weight.dtype)
return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0]
def rms_norm_forward_transformers_npu(self, hidden_states):
"npu rms fused operator for transformers"
if hidden_states.dtype != self.weight.dtype:
hidden_states = hidden_states.to(self.weight.dtype)
return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor):
"npu rope fused operator for Zimage"
with torch.amp.autocast(get_device_type(), enabled=False):
freqs_cis = freqs_cis.unsqueeze(2)
cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1)
cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)
sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)
return torch_npu.npu_rotary_mul(x_in, cos, sin, rotary_mode="interleave").to(x_in)

View File

@@ -417,7 +417,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
def lora_forward(self, x, out):
if self.lora_merger is None:
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
out = out + x @ lora_A.T.to(device=x.device, dtype=x.dtype) @ lora_B.T.to(device=x.device, dtype=x.dtype)
out = out + x @ lora_A.T @ lora_B.T
else:
lora_output = []
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):

View File

@@ -94,23 +94,20 @@ class BasePipeline(torch.nn.Module):
return self
def check_resize_height_width(self, height, width, num_frames=None, verbose=1):
def check_resize_height_width(self, height, width, num_frames=None):
# Shape check
if height % self.height_division_factor != 0:
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
if verbose > 0:
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
if width % self.width_division_factor != 0:
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
if verbose > 0:
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
if num_frames is None:
return height, width
else:
if num_frames % self.time_division_factor != self.time_division_remainder:
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
if verbose > 0:
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
return height, width, num_frames
@@ -147,12 +144,6 @@ class BasePipeline(torch.nn.Module):
video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
return video
def output_audio_format_check(self, audio_output):
# output standard foramt: [C, T], output dtype: float()
# remove batch dim
if audio_output.ndim == 3:
audio_output = audio_output.squeeze(0)
return audio_output.float()
def load_models_to_device(self, model_names):
if self.vram_management_enabled:
@@ -305,7 +296,6 @@ class BasePipeline(torch.nn.Module):
vram_config=vram_config,
vram_limit=vram_limit,
clear_parameters=model_config.clear_parameters,
state_dict=model_config.state_dict,
)
return model_pool
@@ -327,50 +317,11 @@ class BasePipeline(torch.nn.Module):
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
if isinstance(noise_pred_posi, tuple):
# Separately handling different output types of latents, eg. video and audio latents.
noise_pred = tuple(
n_nega + cfg_scale * (n_posi - n_nega)
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
)
else:
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
return noise_pred
def compile_pipeline(self, mode: str = "default", dynamic: bool = True, fullgraph: bool = False, compile_models: list = None, **kwargs):
"""
compile the pipeline with torch.compile. The models that will be compiled are determined by the `compilable_models` attribute of the pipeline.
If a model has `_repeated_blocks` attribute, we will compile these blocks with regional compilation. Otherwise, we will compile the whole model.
See https://docs.pytorch.org/docs/stable/generated/torch.compile.html#torch.compile for details about compilation arguments.
Args:
mode: The compilation mode, which will be passed to `torch.compile`, options are "default", "reduce-overhead", "max-autotune" and "max-autotune-no-cudagraphs. Default to "default".
dynamic: Whether to enable dynamic graph compilation to support dynamic input shapes, which will be passed to `torch.compile`. Default to True (recommended).
fullgraph: Whether to use full graph compilation, which will be passed to `torch.compile`. Default to False (recommended).
compile_models: The list of model names to be compiled. If None, we will compile the models in `pipeline.compilable_models`. Default to None.
**kwargs: Other arguments for `torch.compile`.
"""
compile_models = compile_models or getattr(self, "compilable_models", [])
if len(compile_models) == 0:
print("No compilable models in the pipeline. Skip compilation.")
return
for name in compile_models:
model = getattr(self, name, None)
if model is None:
print(f"Model '{name}' not found in the pipeline.")
continue
repeated_blocks = getattr(model, "_repeated_blocks", None)
# regional compilation for repeated blocks.
if repeated_blocks is not None:
for submod in model.modules():
if submod.__class__.__name__ in repeated_blocks:
submod.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
# compile the whole model.
else:
model.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
print(f"{name} is compiled with mode={mode}, dynamic={dynamic}, fullgraph={fullgraph}.")
class PipelineUnitGraph:
def __init__(self):

View File

@@ -1,107 +0,0 @@
import torch, math
class DDIMScheduler():
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon", rescale_zero_terminal_snr=False):
self.num_train_timesteps = num_train_timesteps
if beta_schedule == "scaled_linear":
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
elif beta_schedule == "linear":
betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
if rescale_zero_terminal_snr:
self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
self.alphas_cumprod = self.alphas_cumprod.tolist()
self.set_timesteps(10)
self.prediction_type = prediction_type
self.training = False
def rescale_zero_terminal_snr(self, alphas_cumprod):
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt.square() # Revert sqrt
return alphas_bar
def set_timesteps(self, num_inference_steps, denoising_strength=1.0, training=False, **kwargs):
# The timesteps are aligned to 999...0, which is different from other implementations,
# but I think this implementation is more reasonable in theory.
max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
num_inference_steps = min(num_inference_steps, max_timestep + 1)
if num_inference_steps == 1:
self.timesteps = torch.Tensor([max_timestep])
else:
step_length = max_timestep / (num_inference_steps - 1)
self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])
self.training = training
def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
if self.prediction_type == "epsilon":
weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
prev_sample = sample * weight_x + model_output * weight_e
elif self.prediction_type == "v_prediction":
weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
prev_sample = sample * weight_x + model_output * weight_e
else:
raise NotImplementedError(f"{self.prediction_type} is not implemented")
return prev_sample
def step(self, model_output, timestep, sample, to_final=False):
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
if to_final or timestep_id + 1 >= len(self.timesteps):
alpha_prod_t_prev = 1.0
else:
timestep_prev = int(self.timesteps[timestep_id + 1])
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
def return_to_timestep(self, timestep, sample, sample_stablized):
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
return noise_pred
def add_noise(self, original_samples, noise, timestep):
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def training_target(self, sample, noise, timestep):
if self.prediction_type == "epsilon":
return noise
else:
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return target
def training_weight(self, timestep):
return 1.0

View File

@@ -4,16 +4,13 @@ from typing_extensions import Literal
class FlowMatchScheduler():
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image"] = "FLUX.1"):
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"):
self.set_timesteps_fn = {
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
"Wan": FlowMatchScheduler.set_timesteps_wan,
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
"ERNIE-Image": FlowMatchScheduler.set_timesteps_ernie_image,
}.get(template, FlowMatchScheduler.set_timesteps_flux)
self.num_train_timesteps = 1000
@@ -73,28 +70,6 @@ class FlowMatchScheduler():
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
sigma_min = 0.0
sigma_max = 1.0
num_train_timesteps = 1000
base_shift = math.log(3)
max_shift = math.log(3)
# Sigmas
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
# Mu
if exponential_shift_mu is not None:
mu = exponential_shift_mu
elif dynamic_shift_len is not None:
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift)
else:
mu = 0.8
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
# Timesteps
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def compute_empirical_mu(image_seq_len, num_steps):
a1, b1 = 8.73809524e-05, 1.89833333
@@ -130,18 +105,6 @@ class FlowMatchScheduler():
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_ernie_image(num_inference_steps=50, denoising_strength=1.0, shift=3.0):
sigma_min = 0.0
sigma_max = 1.0
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
if shift is not None and shift != 1.0:
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
sigma_min = 0.0
@@ -158,47 +121,7 @@ class FlowMatchScheduler():
timestep_id = torch.argmin((timesteps - timestep).abs())
timesteps[timestep_id] = timestep
return sigmas, timesteps
@staticmethod
def set_timesteps_joyai_image(num_inference_steps=100, denoising_strength=1.0, shift=None):
sigma_min = 0.0
sigma_max = 1.0
shift = 4.0 if shift is None else shift
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
num_train_timesteps = 1000
if special_case == "stage2":
sigmas = torch.Tensor([0.909375, 0.725, 0.421875])
elif special_case == "ditilled_stage1":
sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875])
else:
dynamic_shift_len = dynamic_shift_len or 4096
sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image(
image_seq_len=dynamic_shift_len,
base_seq_len=1024,
max_seq_len=4096,
base_shift=0.95,
max_shift=2.05,
)
sigma_min = 0.0
sigma_max = 1.0
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1))
# Shift terminal
one_minus_z = 1.0 - sigmas
scale_factor = one_minus_z[-1] / (1 - terminal)
sigmas = 1.0 - (one_minus_z / scale_factor)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
def set_training_weight(self):
steps = 1000
x = self.timesteps
@@ -210,7 +133,7 @@ class FlowMatchScheduler():
bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
self.linear_timesteps_weights = bsmntw_weighing
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
self.sigmas, self.timesteps = self.set_timesteps_fn(
num_inference_steps=num_inference_steps,

View File

@@ -13,51 +13,14 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
if "first_frame_latents" in inputs:
inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
if "first_frame_latents" in inputs:
noise_pred = noise_pred[:, :, 1:]
training_target = training_target[:, :, 1:]
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * pipe.scheduler.training_weight(timestep)
return loss
def FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs):
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
# video
noise = torch.randn_like(inputs["input_latents"])
inputs["video_latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
# audio
if inputs.get("audio_input_latents") is not None:
audio_noise = torch.randn_like(inputs["audio_input_latents"])
inputs["audio_latents"] = pipe.scheduler.add_noise(inputs["audio_input_latents"], audio_noise, timestep)
training_target_audio = pipe.scheduler.training_target(inputs["audio_input_latents"], audio_noise, timestep)
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * pipe.scheduler.training_weight(timestep)
if inputs.get("audio_input_latents") is not None:
loss_audio = torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float())
loss_audio = loss_audio * pipe.scheduler.training_weight(timestep)
loss = loss + loss_audio
return loss
def DirectDistillLoss(pipe: BasePipeline, **inputs):
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
pipe.scheduler.training = True
@@ -121,9 +84,7 @@ class TrajectoryImitationLoss(torch.nn.Module):
progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
latents_ = trajectory_teacher[progress_id_teacher]
denom = sigma_ - sigma
denom = torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6)
target = (latents_ - inputs_shared["latents"]) / denom
target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
return loss

View File

@@ -29,19 +29,19 @@ def launch_training_task(
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
model.to(device=accelerator.device)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
initialize_deepspeed_gradient_checkpointing(accelerator)
for epoch_id in range(num_epochs):
for data in tqdm(dataloader):
with accelerator.accumulate(model):
optimizer.zero_grad()
if dataset.load_from_cache:
loss = model({}, inputs=data)
else:
loss = model(data)
accelerator.backward(loss)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
scheduler.step()
if save_steps is None:
model_logger.on_epoch_end(accelerator, model, epoch_id)
model_logger.on_training_end(accelerator, model, save_steps)
@@ -70,19 +70,3 @@ def launch_data_process_task(
save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
data = model(data)
torch.save(data, save_path)
def initialize_deepspeed_gradient_checkpointing(accelerator: Accelerator):
if getattr(accelerator.state, "deepspeed_plugin", None) is not None:
ds_config = accelerator.state.deepspeed_plugin.deepspeed_config
if "activation_checkpointing" in ds_config:
import deepspeed
act_config = ds_config["activation_checkpointing"]
deepspeed.checkpointing.configure(
mpu_=None,
partition_activations=act_config.get("partition_activations", False),
checkpoint_in_cpu=act_config.get("cpu_checkpointing", False),
contiguous_checkpointing=act_config.get("contiguous_memory_optimization", False)
)
else:
print("Do not find activation_checkpointing config in deepspeed config, skip initializing deepspeed gradient checkpointing.")

View File

@@ -1,32 +1,9 @@
import torch, json, os, inspect
import torch, json, os
from ..core import ModelConfig, load_state_dict
from ..utils.controlnet import ControlNetInput
from .base_pipeline import PipelineUnit
from peft import LoraConfig, inject_adapter_in_model
class GeneralUnit_RemoveCache(PipelineUnit):
def __init__(self, required_params=tuple(), force_remove_params_shared=tuple(), force_remove_params_posi=tuple(), force_remove_params_nega=tuple()):
super().__init__(take_over=True)
self.required_params = required_params
self.force_remove_params_shared = force_remove_params_shared
self.force_remove_params_posi = force_remove_params_posi
self.force_remove_params_nega = force_remove_params_nega
def process_params(self, inputs, required_params, force_remove_params):
inputs_ = {}
for name, param in inputs.items():
if name in required_params and name not in force_remove_params:
inputs_[name] = param
return inputs_
def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
inputs_shared = self.process_params(inputs_shared, self.required_params, self.force_remove_params_shared)
inputs_posi = self.process_params(inputs_posi, self.required_params, self.force_remove_params_posi)
inputs_nega = self.process_params(inputs_nega, self.required_params, self.force_remove_params_nega)
return inputs_shared, inputs_posi, inputs_nega
class DiffusionTrainingModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -254,30 +231,14 @@ class DiffusionTrainingModule(torch.nn.Module):
setattr(pipe, lora_base_model, model)
def split_pipeline_units(
self, task, pipe,
trainable_models=None, lora_base_model=None,
# TODO: set `remove_unnecessary_params` to `True` by default
remove_unnecessary_params=False,
# TODO: move `loss_required_params` to `loss.py`
loss_required_params=("input_latents", "max_timestep_boundary", "min_timestep_boundary", "first_frame_latents", "video_latents", "audio_input_latents", "num_inference_steps"),
force_remove_params_shared=tuple(),
force_remove_params_posi=tuple(),
force_remove_params_nega=tuple(),
):
def split_pipeline_units(self, task, pipe, trainable_models=None, lora_base_model=None):
models_require_backward = []
if trainable_models is not None:
models_require_backward += trainable_models.split(",")
if lora_base_model is not None:
models_require_backward += [lora_base_model]
if task.endswith(":data_process"):
other_units, pipe.units = pipe.split_pipeline_units(models_require_backward)
if remove_unnecessary_params:
required_params = list(loss_required_params) + [i for i in inspect.signature(self.pipe.model_fn).parameters]
for unit in other_units:
required_params.extend(unit.fetch_input_params())
required_params = sorted(list(set(required_params)))
pipe.units.append(GeneralUnit_RemoveCache(required_params, force_remove_params_shared, force_remove_params_posi, force_remove_params_nega))
_, pipe.units = pipe.split_pipeline_units(models_require_backward)
elif task.endswith(":train"):
pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
return pipe

File diff suppressed because it is too large Load Diff

View File

@@ -1,362 +0,0 @@
"""
Ernie-Image DiT for DiffSynth-Studio.
Refactored from diffusers ErnieImageTransformer2DModel to use DiffSynth core modules.
Default parameters from actual checkpoint config.json (PaddlePaddle/ERNIE-Image transformer).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward
from .flux2_dit import Timesteps, TimestepEmbedding
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta ** scale)
out = torch.einsum("...n,d->...nd", pos, omega)
return out.float()
class ErnieImageEmbedND3(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = list(axes_dim)
def forward(self, ids: torch.Tensor) -> torch.Tensor:
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
emb = emb.unsqueeze(2)
return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1)
class ErnieImagePatchEmbedDynamic(nn.Module):
def __init__(self, in_channels: int, embed_dim: int, patch_size: int):
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
batch_size, dim, height, width = x.shape
return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous()
class ErnieImageSingleStreamAttnProcessor:
def __call__(
self,
attn: "ErnieImageAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
rot_dim = freqs_cis.shape[-1]
x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
cos_ = torch.cos(freqs_cis).to(x.dtype)
sin_ = torch.sin(freqs_cis).to(x.dtype)
x1, x2 = x.chunk(2, dim=-1)
x_rotated = torch.cat((-x2, x1), dim=-1)
return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
if freqs_cis is not None:
query = apply_rotary_emb(query, freqs_cis)
key = apply_rotary_emb(key, freqs_cis)
if attention_mask is not None and attention_mask.ndim == 2:
attention_mask = attention_mask[:, None, None, :]
hidden_states = attention_forward(
query, key, value,
q_pattern="b s n d",
k_pattern="b s n d",
v_pattern="b s n d",
out_pattern="b s n d",
attn_mask=attention_mask,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
output = attn.to_out[0](hidden_states)
return output
class ErnieImageAttention(nn.Module):
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
qk_norm: str = "rms_norm",
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
elementwise_affine: bool = True,
):
super().__init__()
self.head_dim = dim_head
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.out_dim = out_dim if out_dim is not None else query_dim
self.heads = out_dim // dim_head if out_dim is not None else heads
self.use_bias = bias
self.dropout = dropout
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
if qk_norm == "layer_norm":
self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
elif qk_norm == "rms_norm":
self.norm_q = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
else:
raise ValueError(
f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'rms_norm'."
)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.processor = ErnieImageSingleStreamAttnProcessor()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.processor(self, hidden_states, attention_mask, image_rotary_emb)
class ErnieImageFeedForward(nn.Module):
def __init__(self, hidden_size: int, ffn_hidden_size: int):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
self.up_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
self.linear_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x)))
class ErnieImageRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = hidden_states * self.weight
return hidden_states.to(input_dtype)
class ErnieImageSharedAdaLNBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
ffn_hidden_size: int,
eps: float = 1e-6,
qk_layernorm: bool = True,
):
super().__init__()
self.adaLN_sa_ln = ErnieImageRMSNorm(hidden_size, eps=eps)
self.self_attention = ErnieImageAttention(
query_dim=hidden_size,
dim_head=hidden_size // num_heads,
heads=num_heads,
qk_norm="rms_norm" if qk_layernorm else None,
eps=eps,
bias=False,
out_bias=False,
)
self.adaLN_mlp_ln = ErnieImageRMSNorm(hidden_size, eps=eps)
self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size)
def forward(
self,
x: torch.Tensor,
rotary_pos_emb: torch.Tensor,
temb: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb
residual = x
x = self.adaLN_sa_ln(x)
x = (x.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
x_bsh = x.permute(1, 0, 2)
attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
attn_out = attn_out.permute(1, 0, 2)
x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
residual = x
x = self.adaLN_mlp_ln(x)
x = (x.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
return residual + (gate_mlp.float() * self.mlp(x).float()).to(x.dtype)
class ErnieImageAdaLNContinuous(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=eps)
self.linear = nn.Linear(hidden_size, hidden_size * 2)
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
scale, shift = self.linear(conditioning).chunk(2, dim=-1)
x = self.norm(x)
x = x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0)
return x
class ErnieImageDiT(nn.Module):
"""
Ernie-Image DiT model for DiffSynth-Studio.
Architecture: SharedAdaLN + RoPE 3D + Joint Image-Text Attention.
Internal format: [S, B, H] for transformer blocks, [B, S, H] for attention.
"""
def __init__(
self,
hidden_size: int = 4096,
num_attention_heads: int = 32,
num_layers: int = 36,
ffn_hidden_size: int = 12288,
in_channels: int = 128,
out_channels: int = 128,
patch_size: int = 1,
text_in_dim: int = 3072,
rope_theta: int = 256,
rope_axes_dim: Tuple[int, int, int] = (32, 48, 48),
eps: float = 1e-6,
qk_layernorm: bool = True,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_attention_heads
self.head_dim = hidden_size // num_attention_heads
self.num_layers = num_layers
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = out_channels
self.text_in_dim = text_in_dim
self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size)
self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None
self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0)
self.time_embedding = TimestepEmbedding(hidden_size, hidden_size)
self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size))
nn.init.zeros_(self.adaLN_modulation[-1].weight)
nn.init.zeros_(self.adaLN_modulation[-1].bias)
self.layers = nn.ModuleList([
ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm)
for _ in range(num_layers)
])
self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps)
self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels)
nn.init.zeros_(self.final_linear.weight)
nn.init.zeros_(self.final_linear.bias)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
text_bth: torch.Tensor,
text_lens: torch.Tensor,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
) -> torch.Tensor:
device, dtype = hidden_states.device, hidden_states.dtype
B, C, H, W = hidden_states.shape
p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size
N_img = Hp * Wp
img_sbh = self.x_embedder(hidden_states).transpose(0, 1).contiguous()
if self.text_proj is not None and text_bth.numel() > 0:
text_bth = self.text_proj(text_bth)
Tmax = text_bth.shape[1]
text_sbh = text_bth.transpose(0, 1).contiguous()
x = torch.cat([img_sbh, text_sbh], dim=0)
S = x.shape[0]
text_ids = torch.cat([
torch.arange(Tmax, device=device, dtype=torch.float32).view(1, Tmax, 1).expand(B, -1, -1),
torch.zeros((B, Tmax, 2), device=device)
], dim=-1) if Tmax > 0 else torch.zeros((B, 0, 3), device=device)
grid_yx = torch.stack(
torch.meshgrid(torch.arange(Hp, device=device, dtype=torch.float32),
torch.arange(Wp, device=device, dtype=torch.float32), indexing="ij"),
dim=-1
).reshape(-1, 2)
image_ids = torch.cat([
text_lens.float().view(B, 1, 1).expand(-1, N_img, -1),
grid_yx.view(1, N_img, 2).expand(B, -1, -1)
], dim=-1)
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))
valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) if Tmax > 0 else torch.zeros((B, 0), device=device, dtype=torch.bool)
attention_mask = torch.cat([
torch.ones((B, N_img), device=device, dtype=torch.bool),
valid_text
], dim=1)[:, None, None, :]
sample = self.time_proj(timestep.to(dtype))
sample = sample.to(self.time_embedding.linear_1.weight.dtype)
c = self.time_embedding(sample)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
t.unsqueeze(0).expand(S, -1, -1).contiguous()
for t in self.adaLN_modulation(c).chunk(6, dim=-1)
]
for layer in self.layers:
temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
if torch.is_grad_enabled() and use_gradient_checkpointing:
x = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x,
rotary_pos_emb,
temb,
attention_mask,
)
else:
x = layer(x, rotary_pos_emb, temb, attention_mask)
x = self.final_norm(x, c).type_as(x)
patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous()
output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W)
return output

View File

@@ -1,76 +0,0 @@
"""
Ernie-Image TextEncoder for DiffSynth-Studio.
Wraps transformers Ministral3Model to output text embeddings.
Pattern: lazy import + manual config dict + torch.nn.Module wrapper.
Only loads the text (language) model, ignoring vision components.
"""
import torch
class ErnieImageTextEncoder(torch.nn.Module):
"""
Text encoder using Ministral3Model (transformers).
Only the text_config portion of the full Mistral3Model checkpoint.
Uses the base model (no lm_head) since the checkpoint only has embeddings.
"""
def __init__(self):
super().__init__()
from transformers import Ministral3Config, Ministral3Model
text_config = {
"attention_dropout": 0.0,
"bos_token_id": 1,
"dtype": "bfloat16",
"eos_token_id": 2,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 3072,
"initializer_range": 0.02,
"intermediate_size": 9216,
"max_position_embeddings": 262144,
"model_type": "ministral3",
"num_attention_heads": 32,
"num_hidden_layers": 26,
"num_key_value_heads": 8,
"pad_token_id": 11,
"rms_norm_eps": 1e-05,
"rope_parameters": {
"beta_fast": 32.0,
"beta_slow": 1.0,
"factor": 16.0,
"llama_4_scaling_beta": 0.1,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 16384,
"rope_theta": 1000000.0,
"rope_type": "yarn",
"type": "yarn",
},
"sliding_window": None,
"tie_word_embeddings": True,
"use_cache": True,
"vocab_size": 131072,
}
config = Ministral3Config(**text_config)
self.model = Ministral3Model(config)
self.config = config
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
**kwargs,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_hidden_states=True,
return_dict=True,
**kwargs,
)
return (outputs.hidden_states,)

View File

@@ -407,7 +407,6 @@ class Flux2AttnProcessor:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)
hidden_states = attention_forward(
query,
key,
@@ -537,7 +536,6 @@ class Flux2ParallelSelfAttnProcessor:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)
hidden_states = attention_forward(
query,
key,
@@ -879,9 +877,6 @@ class Flux2Modulation(nn.Module):
class Flux2DiT(torch.nn.Module):
_repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
def __init__(
self,
patch_size: int = 1,

View File

@@ -275,9 +275,6 @@ class AdaLayerNormContinuous(torch.nn.Module):
class FluxDiT(torch.nn.Module):
_repeated_blocks = ["FluxJointTransformerBlock", "FluxSingleTransformerBlock"]
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])

View File

@@ -1,636 +0,0 @@
import math
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
) -> torch.Tensor:
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
emb = scale * emb
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
return get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
)
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None
self.act = nn.SiLU()
time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
self.post_act = nn.SiLU() if post_act_fn == "silu" else None
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class PixArtAlphaTextProjection(nn.Module):
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu":
self.act_1 = nn.SiLU()
else:
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
self.approximate = approximate
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj(hidden_states)
hidden_states = F.gelu(hidden_states, approximate=self.approximate)
return hidden_states
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
inner_dim=None,
bias: bool = True,
):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
# Build activation + projection matching diffusers pattern
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim, bias=bias)
elif activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
else:
act_fn = GELU(dim, inner_dim, bias=bias)
self.net = nn.ModuleList([])
self.net.append(act_fn)
self.net.append(nn.Dropout(dropout))
self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
def _to_tuple(x, dim=2):
if isinstance(x, int):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start, *args, dim=2):
if len(args) == 0:
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
elif len(args) == 1:
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = [stop[i] - start[i] for i in range(dim)]
elif len(args) == 2:
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = _to_tuple(args[1], dim=dim)
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij")
grid = torch.stack(grid, dim=0)
return grid
def reshape_for_broadcast(freqs_cis, x, head_first=False):
ndim = x.ndim
assert 0 <= 1 < ndim
if isinstance(freqs_cis, tuple):
if head_first:
assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1])
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis[0].shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
else:
if head_first:
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotate_half(x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(xq, xk, freqs_cis, head_first=False):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
cos, sin = cos.to(xq.device), sin.to(xq.device)
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
return xq_out, xk_out
def get_1d_rotary_pos_embed(dim, pos, theta=10000.0, use_real=False, theta_rescale_factor=1.0, interpolation_factor=1.0):
if isinstance(pos, int):
pos = torch.arange(pos).float()
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
freqs = torch.outer(pos * interpolation_factor, freqs)
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1)
freqs_sin = freqs.sin().repeat_interleave(2, dim=1)
return freqs_cos, freqs_sin
else:
return torch.polar(torch.ones_like(freqs), freqs)
def get_nd_rotary_pos_embed(rope_dim_list, start, *args, theta=10000.0, use_real=False,
txt_rope_size=None, theta_rescale_factor=1.0, interpolation_factor=1.0):
grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))
if isinstance(theta_rescale_factor, (int, float)):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
if isinstance(interpolation_factor, (int, float)):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i], grid[i].reshape(-1), theta,
use_real=use_real, theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
)
embs.append(emb)
if use_real:
vis_emb = (torch.cat([emb[0] for emb in embs], dim=1), torch.cat([emb[1] for emb in embs], dim=1))
else:
vis_emb = torch.cat(embs, dim=1)
if txt_rope_size is not None:
embs_txt = []
vis_max_ids = grid.view(-1).max().item()
grid_txt = torch.arange(txt_rope_size) + vis_max_ids + 1
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i], grid_txt, theta,
use_real=use_real, theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
)
embs_txt.append(emb)
if use_real:
txt_emb = (torch.cat([emb[0] for emb in embs_txt], dim=1), torch.cat([emb[1] for emb in embs_txt], dim=1))
else:
txt_emb = torch.cat(embs_txt, dim=1)
else:
txt_emb = None
return vis_emb, txt_emb
class ModulateWan(nn.Module):
def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
super().__init__()
self.factor = factor
factory_kwargs = {"dtype": dtype, "device": device}
self.modulate_table = nn.Parameter(
torch.zeros(1, factor, hidden_size, **factory_kwargs) / hidden_size**0.5,
requires_grad=True
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if len(x.shape) != 3:
x = x.unsqueeze(1)
return [o.squeeze(1) for o in (self.modulate_table + x).chunk(self.factor, dim=1)]
def modulate(x, shift=None, scale=None):
if scale is None and shift is None:
return x
elif shift is None:
return x * (1 + scale.unsqueeze(1))
elif scale is None:
return x + shift.unsqueeze(1)
else:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def apply_gate(x, gate=None, tanh=False):
if gate is None:
return x
if tanh:
return x * gate.unsqueeze(1).tanh()
else:
return x * gate.unsqueeze(1)
def load_modulation(modulate_type: str, hidden_size: int, factor: int, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
if modulate_type == 'wanx':
return ModulateWan(hidden_size, factor, **factory_kwargs)
raise ValueError(f"Unknown modulation type: {modulate_type}. Only 'wanx' is supported.")
class RMSNorm(nn.Module):
def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
class MMDoubleStreamBlock(nn.Module):
"""
A multimodal dit block with separate modulation for
text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float,
mlp_act_type: str = "gelu_tanh",
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
dit_modulation_type: Optional[str] = "wanx",
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.dit_modulation_type = dit_modulation_type
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.img_mod = load_modulation(
modulate_type=self.dit_modulation_type,
hidden_size=hidden_size, factor=6, **factory_kwargs,
)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")
self.txt_mod = load_modulation(
modulate_type=self.dit_modulation_type,
hidden_size=hidden_size, factor=6, **factory_kwargs,
)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
vis_freqs_cis: tuple = None,
txt_freqs_cis: tuple = None,
attn_kwargs: Optional[dict] = {},
) -> Tuple[torch.Tensor, torch.Tensor]:
(
img_mod1_shift, img_mod1_scale, img_mod1_gate,
img_mod2_shift, img_mod2_scale, img_mod2_gate,
) = self.img_mod(vec)
(
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate,
txt_mod2_shift, txt_mod2_scale, txt_mod2_gate,
) = self.txt_mod(vec)
img_modulated = self.img_norm1(img)
img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
img_qq, img_kk = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)
img_q, img_k = img_qq, img_kk
txt_modulated = self.txt_norm1(txt)
txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
raise NotImplementedError("RoPE text is not supported for inference")
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
# Use DiffSynth unified attention
attn_out = attention_forward(
q, k, v,
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
)
attn_out = attn_out.flatten(2, 3)
img_attn, txt_attn = attn_out[:, : img.shape[1]], attn_out[:, img.shape[1]:]
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
img = img + apply_gate(
self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
gate=img_mod2_gate,
)
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
txt = txt + apply_gate(
self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
gate=txt_mod2_gate,
)
return img, txt
class WanTimeTextImageEmbedding(nn.Module):
def __init__(
self,
dim: int,
time_freq_dim: int,
time_proj_dim: int,
text_embed_dim: int,
image_embed_dim: Optional[int] = None,
pos_embed_seq_len: Optional[int] = None,
):
super().__init__()
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
self.act_fn = nn.SiLU()
self.time_proj = nn.Linear(dim, time_proj_dim)
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor):
timestep = self.timesteps_proj(timestep)
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
timestep_proj = self.time_proj(self.act_fn(temb))
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
return temb, timestep_proj, encoder_hidden_states
class JoyAIImageDiT(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
patch_size: list = [1, 2, 2],
in_channels: int = 16,
out_channels: int = 16,
hidden_size: int = 4096,
heads_num: int = 32,
text_states_dim: int = 4096,
mlp_width_ratio: float = 4.0,
mm_double_blocks_depth: int = 40,
rope_dim_list: List[int] = [16, 56, 56],
rope_type: str = 'rope',
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
dit_modulation_type: str = "wanx",
theta: int = 10000,
):
super().__init__()
self.out_channels = out_channels or in_channels
self.patch_size = patch_size
self.hidden_size = hidden_size
self.heads_num = heads_num
self.rope_dim_list = rope_dim_list
self.dit_modulation_type = dit_modulation_type
self.mm_double_blocks_depth = mm_double_blocks_depth
self.rope_type = rope_type
self.theta = theta
factory_kwargs = {"device": device, "dtype": dtype}
if hidden_size % heads_num != 0:
raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
self.condition_embedder = WanTimeTextImageEmbedding(
dim=hidden_size,
time_freq_dim=256,
time_proj_dim=hidden_size * 6,
text_embed_dim=text_states_dim,
)
self.double_blocks = nn.ModuleList([
MMDoubleStreamBlock(
self.hidden_size, self.heads_num,
mlp_width_ratio=mlp_width_ratio,
dit_modulation_type=self.dit_modulation_type,
**factory_kwargs,
)
for _ in range(mm_double_blocks_depth)
])
self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(hidden_size, self.out_channels * math.prod(patch_size), **factory_kwargs)
def get_rotary_pos_embed(self, vis_rope_size, txt_rope_size=None):
target_ndim = 3
if len(vis_rope_size) != target_ndim:
vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size
head_dim = self.hidden_size // self.heads_num
rope_dim_list = self.rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim
vis_freqs, txt_freqs = get_nd_rotary_pos_embed(
rope_dim_list, vis_rope_size,
txt_rope_size=txt_rope_size if self.rope_type == 'mrope' else None,
theta=self.theta, use_real=True, theta_rescale_factor=1,
)
return vis_freqs, txt_freqs
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
encoder_hidden_states_mask: torch.Tensor = None,
return_dict: bool = True,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
is_multi_item = (len(hidden_states.shape) == 6)
num_items = 0
if is_multi_item:
num_items = hidden_states.shape[1]
if num_items > 1:
assert self.patch_size[0] == 1, "For multi-item input, patch_size[0] must be 1"
hidden_states = torch.cat([hidden_states[:, -1:], hidden_states[:, :-1]], dim=1)
hidden_states = rearrange(hidden_states, 'b n c t h w -> b c (n t) h w')
batch_size, _, ot, oh, ow = hidden_states.shape
tt, th, tw = ot // self.patch_size[0], oh // self.patch_size[1], ow // self.patch_size[2]
if encoder_hidden_states_mask is None:
encoder_hidden_states_mask = torch.ones(
(encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]),
dtype=torch.bool,
).to(encoder_hidden_states.device)
img = self.img_in(hidden_states).flatten(2).transpose(1, 2)
temb, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
if vec.shape[-1] > self.hidden_size:
vec = vec.unflatten(1, (6, -1))
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
vis_freqs_cis, txt_freqs_cis = self.get_rotary_pos_embed(
vis_rope_size=(tt, th, tw),
txt_rope_size=txt_seq_len if self.rope_type == 'mrope' else None,
)
for block in self.double_blocks:
img, txt = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
img=img, txt=txt, vec=vec,
vis_freqs_cis=vis_freqs_cis, txt_freqs_cis=txt_freqs_cis,
attn_kwargs={},
)
img_len = img.shape[1]
x = torch.cat((img, txt), 1)
img = x[:, :img_len, ...]
img = self.proj_out(self.norm_out(img))
img = self.unpatchify(img, tt, th, tw)
if is_multi_item:
img = rearrange(img, 'b c (n t) h w -> b n c t h w', n=num_items)
if num_items > 1:
img = torch.cat([img[:, 1:], img[:, :1]], dim=1)
return img
def unpatchify(self, x, t, h, w):
c = self.out_channels
pt, ph, pw = self.patch_size
assert t * h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = torch.einsum("nthwopqc->nctohpwq", x)
return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))

View File

@@ -1,82 +0,0 @@
import torch
from typing import Optional
class JoyAIImageTextEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
config = Qwen3VLConfig(
text_config={
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 12288,
"max_position_embeddings": 262144,
"model_type": "qwen3_vl_text",
"num_attention_heads": 32,
"num_hidden_layers": 36,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-6,
"rope_scaling": {
"mrope_interleaved": True,
"mrope_section": [24, 20, 20],
"rope_type": "default",
},
"rope_theta": 5000000,
"use_cache": True,
"vocab_size": 151936,
},
vision_config={
"deepstack_visual_indexes": [8, 16, 24],
"depth": 27,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"in_channels": 3,
"initializer_range": 0.02,
"intermediate_size": 4304,
"model_type": "qwen3_vl",
"num_heads": 16,
"num_position_embeddings": 2304,
"out_hidden_size": 4096,
"patch_size": 16,
"spatial_merge_size": 2,
"temporal_patch_size": 2,
},
image_token_id=151655,
video_token_id=151656,
vision_start_token_id=151652,
vision_end_token_id=151653,
tie_word_embeddings=False,
)
self.model = Qwen3VLForConditionalGeneration(config)
self.config = config
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
pixel_values: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
**kwargs,
):
pre_norm_output = [None]
def hook_fn(module, args, kwargs_output=None):
pre_norm_output[0] = args[0]
self.model.model.language_model.norm.register_forward_hook(hook_fn)
_ = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
attention_mask=attention_mask,
output_hidden_states=True,
**kwargs,
)
return pre_norm_output[0]

File diff suppressed because it is too large Load Diff

View File

@@ -1,388 +0,0 @@
from dataclasses import dataclass
from typing import NamedTuple, Protocol, Tuple
import torch
from torch import nn
from enum import Enum
class VideoPixelShape(NamedTuple):
"""
Shape of the tensor representing the video pixel array. Assumes BGR channel format.
"""
batch: int
frames: int
height: int
width: int
fps: float
class SpatioTemporalScaleFactors(NamedTuple):
"""
Describes the spatiotemporal downscaling between decoded video space and
the corresponding VAE latent grid.
"""
time: int
width: int
height: int
@classmethod
def default(cls) -> "SpatioTemporalScaleFactors":
return cls(time=8, width=32, height=32)
VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
class VideoLatentShape(NamedTuple):
"""
Shape of the tensor representing video in VAE latent space.
The latent representation is a 5D tensor with dimensions ordered as
(batch, channels, frames, height, width). Spatial and temporal dimensions
are downscaled relative to pixel space according to the VAE's scale factors.
"""
batch: int
channels: int
frames: int
height: int
width: int
def to_torch_shape(self) -> torch.Size:
return torch.Size([self.batch, self.channels, self.frames, self.height, self.width])
@staticmethod
def from_torch_shape(shape: torch.Size) -> "VideoLatentShape":
return VideoLatentShape(
batch=shape[0],
channels=shape[1],
frames=shape[2],
height=shape[3],
width=shape[4],
)
def mask_shape(self) -> "VideoLatentShape":
return self._replace(channels=1)
@staticmethod
def from_pixel_shape(
shape: VideoPixelShape,
latent_channels: int = 128,
scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS,
) -> "VideoLatentShape":
frames = (shape.frames - 1) // scale_factors[0] + 1
height = shape.height // scale_factors[1]
width = shape.width // scale_factors[2]
return VideoLatentShape(
batch=shape.batch,
channels=latent_channels,
frames=frames,
height=height,
width=width,
)
def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape":
return self._replace(
channels=3,
frames=(self.frames - 1) * scale_factors.time + 1,
height=self.height * scale_factors.height,
width=self.width * scale_factors.width,
)
class AudioLatentShape(NamedTuple):
"""
Shape of audio in VAE latent space: (batch, channels, frames, mel_bins).
mel_bins is the number of frequency bins from the mel-spectrogram encoding.
"""
batch: int
channels: int
frames: int
mel_bins: int
def to_torch_shape(self) -> torch.Size:
return torch.Size([self.batch, self.channels, self.frames, self.mel_bins])
def mask_shape(self) -> "AudioLatentShape":
return self._replace(channels=1, mel_bins=1)
@staticmethod
def from_torch_shape(shape: torch.Size) -> "AudioLatentShape":
return AudioLatentShape(
batch=shape[0],
channels=shape[1],
frames=shape[2],
mel_bins=shape[3],
)
@staticmethod
def from_duration(
batch: int,
duration: float,
channels: int = 8,
mel_bins: int = 16,
sample_rate: int = 16000,
hop_length: int = 160,
audio_latent_downsample_factor: int = 4,
) -> "AudioLatentShape":
latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor)
return AudioLatentShape(
batch=batch,
channels=channels,
frames=round(duration * latents_per_second),
mel_bins=mel_bins,
)
@staticmethod
def from_video_pixel_shape(
shape: VideoPixelShape,
channels: int = 8,
mel_bins: int = 16,
sample_rate: int = 16000,
hop_length: int = 160,
audio_latent_downsample_factor: int = 4,
) -> "AudioLatentShape":
return AudioLatentShape.from_duration(
batch=shape.batch,
duration=float(shape.frames) / float(shape.fps),
channels=channels,
mel_bins=mel_bins,
sample_rate=sample_rate,
hop_length=hop_length,
audio_latent_downsample_factor=audio_latent_downsample_factor,
)
@dataclass(frozen=True)
class LatentState:
"""
State of latents during the diffusion denoising process.
Attributes:
latent: The current noisy latent tensor being denoised.
denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising).
positions: Positional indices for each latent element, used for positional embeddings.
clean_latent: Initial state of the latent before denoising, may include conditioning latents.
"""
latent: torch.Tensor
denoise_mask: torch.Tensor
positions: torch.Tensor
clean_latent: torch.Tensor
def clone(self) -> "LatentState":
return LatentState(
latent=self.latent.clone(),
denoise_mask=self.denoise_mask.clone(),
positions=self.positions.clone(),
clean_latent=self.clean_latent.clone(),
)
class NormType(Enum):
"""Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""
GROUP = "group"
PIXEL = "pixel"
class PixelNorm(nn.Module):
"""
Per-pixel (per-location) RMS normalization layer.
For each element along the chosen dimension, this layer normalizes the tensor
by the root-mean-square of its values across that dimension:
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
"""
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
"""
Args:
dim: Dimension along which to compute the RMS (typically channels).
eps: Small constant added for numerical stability.
"""
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply RMS normalization along the configured dimension.
"""
# Compute mean of squared values along `dim`, keep dimensions for broadcasting.
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
# Normalize by the root-mean-square (RMS).
rms = torch.sqrt(mean_sq + self.eps)
return x / rms
def build_normalization_layer(
in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
) -> nn.Module:
"""
Create a normalization layer based on the normalization type.
Args:
in_channels: Number of input channels
num_groups: Number of groups for group normalization
normtype: Type of normalization: "group" or "pixel"
Returns:
A normalization layer
"""
if normtype == NormType.GROUP:
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if normtype == NormType.PIXEL:
return PixelNorm(dim=1, eps=1e-6)
raise ValueError(f"Invalid normalization type: {normtype}")
def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
"""Root-mean-square (RMS) normalize `x` over its last dimension.
Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
shape and forwards `weight` and `eps`.
"""
return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)
@dataclass(frozen=True)
class Modality:
"""
Input data for a single modality (video or audio) in the transformer.
Bundles the latent tokens, timestep embeddings, positional information,
and text conditioning context for processing by the diffusion transformer.
Attributes:
latent: Patchified latent tokens, shape ``(B, T, D)`` where *B* is
the batch size, *T* is the total number of tokens (noisy +
conditioning), and *D* is the input dimension.
timesteps: Per-token timestep embeddings, shape ``(B, T)``.
positions: Positional coordinates, shape ``(B, 3, T)`` for video
(time, height, width) or ``(B, 1, T)`` for audio.
context: Text conditioning embeddings from the prompt encoder.
enabled: Whether this modality is active in the current forward pass.
context_mask: Optional mask for the text context tokens.
attention_mask: Optional 2-D self-attention mask, shape ``(B, T, T)``.
Values in ``[0, 1]`` where ``1`` = full attention and ``0`` = no
attention. ``None`` means unrestricted (full) attention between
all tokens. Built incrementally by conditioning items; see
:class:`~ltx_core.conditioning.types.attention_strength_wrapper.ConditioningItemAttentionStrengthWrapper`.
"""
latent: (
torch.Tensor
) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
sigma: torch.Tensor # Shape: (B,). Current sigma value, used for cross-attention timestep calculation.
timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps
positions: (
torch.Tensor
) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens
context: torch.Tensor
enabled: bool = True
context_mask: torch.Tensor | None = None
attention_mask: torch.Tensor | None = None
def to_denoised(
sample: torch.Tensor,
velocity: torch.Tensor,
sigma: float | torch.Tensor,
calc_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Convert the sample and its denoising velocity to denoised sample.
Returns:
Denoised sample
"""
if isinstance(sigma, torch.Tensor):
sigma = sigma.to(calc_dtype)
return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)
class Patchifier(Protocol):
"""
Protocol for patchifiers that convert latent tensors into patches and assemble them back.
"""
def patchify(
self,
latents: torch.Tensor,
) -> torch.Tensor:
...
"""
Convert latent tensors into flattened patch tokens.
Args:
latents: Latent tensor to patchify.
Returns:
Flattened patch tokens tensor.
"""
def unpatchify(
self,
latents: torch.Tensor,
output_shape: AudioLatentShape | VideoLatentShape,
) -> torch.Tensor:
"""
Converts latent tensors between spatio-temporal formats and flattened sequence representations.
Args:
latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
VideoLatentShape.
Returns:
Dense latent tensor restored from the flattened representation.
"""
@property
def patch_size(self) -> Tuple[int, int, int]:
...
"""
Returns the patch size as a tuple of (temporal, height, width) dimensions
"""
def get_patch_grid_bounds(
self,
output_shape: AudioLatentShape | VideoLatentShape,
device: torch.device | None = None,
) -> torch.Tensor:
...
"""
Compute metadata describing where each latent patch resides within the
grid specified by `output_shape`.
Args:
output_shape: Target grid layout for the patches.
device: Target device for the returned tensor.
Returns:
Tensor containing patch coordinate metadata such as spatial or temporal intervals.
"""
def get_pixel_coords(
latent_coords: torch.Tensor,
scale_factors: SpatioTemporalScaleFactors,
causal_fix: bool = False,
) -> torch.Tensor:
"""
Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
Args:
latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied
per axis.
causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
that treat frame zero differently still yield non-negative timestamps.
"""
# Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
broadcast_shape = [1] * latent_coords.ndim
broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width)
scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)
# Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
pixel_coords = latent_coords * scale_tensor
if causal_fix:
# VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
# Shift and clamp to keep the first-frame timestamps causal and non-negative.
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
return pixel_coords

File diff suppressed because it is too large Load Diff

View File

@@ -1,549 +0,0 @@
import math
import torch
import torch.nn as nn
from einops import rearrange
from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer
from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention,
FeedForward)
from .ltx2_common import rms_norm
class LTX2TextEncoder(Gemma3ForConditionalGeneration):
def __init__(self):
config = Gemma3Config(
**{
"architectures": ["Gemma3ForConditionalGeneration"],
"boi_token_index": 255999,
"dtype": "bfloat16",
"eoi_token_index": 256000,
"eos_token_id": [1, 106],
"image_token_index": 262144,
"initializer_range": 0.02,
"mm_tokens_per_image": 256,
"model_type": "gemma3",
"text_config": {
"_sliding_window_pattern": 6,
"attention_bias": False,
"attention_dropout": 0.0,
"attn_logit_softcapping": None,
"cache_implementation": "hybrid",
"dtype": "bfloat16",
"final_logit_softcapping": None,
"head_dim": 256,
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 3840,
"initializer_range": 0.02,
"intermediate_size": 15360,
"layer_types": [
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention"
],
"max_position_embeddings": 131072,
"model_type": "gemma3_text",
"num_attention_heads": 16,
"num_hidden_layers": 48,
"num_key_value_heads": 8,
"query_pre_attn_scalar": 256,
"rms_norm_eps": 1e-06,
"rope_local_base_freq": 10000,
"rope_scaling": {
"factor": 8.0,
"rope_type": "linear"
},
"rope_theta": 1000000,
"sliding_window": 1024,
"sliding_window_pattern": 6,
"use_bidirectional_attention": False,
"use_cache": True,
"vocab_size": 262208
},
"transformers_version": "4.57.3",
"vision_config": {
"attention_dropout": 0.0,
"dtype": "bfloat16",
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 896,
"intermediate_size": 4304,
"layer_norm_eps": 1e-06,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 27,
"patch_size": 14,
"vision_use_head": False
}
})
super().__init__(config)
class LTXVGemmaTokenizer:
"""
Tokenizer wrapper for Gemma models compatible with LTXV processes.
This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders,
ensuring correct settings and output formatting for downstream consumption.
"""
def __init__(self, tokenizer_path: str, max_length: int = 1024):
"""
Initialize the tokenizer.
Args:
tokenizer_path (str): Path to the pretrained tokenizer files or model directory.
max_length (int, optional): Max sequence length for encoding. Defaults to 256.
"""
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, local_files_only=True, model_max_length=max_length
)
# Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much.
self.tokenizer.padding_side = "left"
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.max_length = max_length
def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]:
"""
Tokenize the given text and return token IDs and attention weights.
Args:
text (str): The input string to tokenize.
return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples.
If False (default), omits the indices.
Returns:
dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]:
A dictionary with a "gemma" key mapping to:
- a list of (token_id, attention_mask) tuples if return_word_ids is False;
- a list of (token_id, attention_mask, index) tuples if return_word_ids is True.
Example:
>>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8)
>>> tokenizer.tokenize_with_weights("hello world")
{'gemma': [(1234, 1), (5678, 1), (2, 0), ...]}
"""
text = text.strip()
encoded = self.tokenizer(
text,
padding="max_length",
max_length=self.max_length,
truncation=True,
return_tensors="pt",
)
input_ids = encoded.input_ids
attention_mask = encoded.attention_mask
tuples = [
(token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0], strict=True))
]
out = {"gemma": tuples}
if not return_word_ids:
# Return only (token_id, attention_mask) pairs, omitting token position
out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()}
return out
class GemmaFeaturesExtractorProjLinear(nn.Module):
"""
Feature extractor module for Gemma models.
This module applies a single linear projection to the input tensor.
It expects a flattened feature tensor of shape (batch_size, 3840*49).
The linear layer maps this to a (batch_size, 3840) embedding.
Attributes:
aggregate_embed (nn.Linear): Linear projection layer.
"""
def __init__(self) -> None:
"""
Initialize the GemmaFeaturesExtractorProjLinear module.
The input dimension is expected to be 3840 * 49, and the output is 3840.
"""
super().__init__()
self.aggregate_embed = nn.Linear(3840 * 49, 3840, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
padding_side: str = "left",
) -> tuple[torch.Tensor, torch.Tensor | None]:
encoded = torch.stack(hidden_states, dim=-1) if isinstance(hidden_states, (list, tuple)) else hidden_states
dtype = encoded.dtype
sequence_lengths = attention_mask.sum(dim=-1)
normed = _norm_and_concat_padded_batch(encoded, sequence_lengths, padding_side)
features = self.aggregate_embed(normed.to(dtype))
return features, features
class GemmaSeperatedFeaturesExtractorProjLinear(nn.Module):
"""22B: per-token RMS norm → rescale → dual aggregate embeds"""
def __init__(
self,
num_layers: int,
embedding_dim: int,
video_inner_dim: int,
audio_inner_dim: int,
):
super().__init__()
in_dim = embedding_dim * num_layers
self.video_aggregate_embed = torch.nn.Linear(in_dim, video_inner_dim, bias=True)
self.audio_aggregate_embed = torch.nn.Linear(in_dim, audio_inner_dim, bias=True)
self.embedding_dim = embedding_dim
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
padding_side: str = "left", # noqa: ARG002
) -> tuple[torch.Tensor, torch.Tensor | None]:
encoded = torch.stack(hidden_states, dim=-1) if isinstance(hidden_states, (list, tuple)) else hidden_states
normed = norm_and_concat_per_token_rms(encoded, attention_mask)
normed = normed.to(encoded.dtype)
v_dim = self.video_aggregate_embed.out_features
video = self.video_aggregate_embed(_rescale_norm(normed, v_dim, self.embedding_dim))
audio = None
if self.audio_aggregate_embed is not None:
a_dim = self.audio_aggregate_embed.out_features
audio = self.audio_aggregate_embed(_rescale_norm(normed, a_dim, self.embedding_dim))
return video, audio
class _BasicTransformerBlock1D(nn.Module):
def __init__(
self,
dim: int,
heads: int,
dim_head: int,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
apply_gated_attention: bool = False,
):
super().__init__()
self.attn1 = Attention(
query_dim=dim,
heads=heads,
dim_head=dim_head,
rope_type=rope_type,
apply_gated_attention=apply_gated_attention,
)
self.ff = FeedForward(
dim,
dim_out=dim,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
pe: torch.Tensor | None = None,
) -> torch.Tensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Normalization Before Self-Attention
norm_hidden_states = rms_norm(hidden_states)
norm_hidden_states = norm_hidden_states.squeeze(1)
# 2. Self-Attention
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 3. Normalization before Feed-Forward
norm_hidden_states = rms_norm(hidden_states)
# 4. Feed-forward
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class Embeddings1DConnector(nn.Module):
"""
Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can
substitute padded positions with learnable registers. The module is highly configurable for head size, number of
layers, and register usage.
Args:
attention_head_dim (int): Dimension of each attention head (default=128).
num_attention_heads (int): Number of attention heads (default=30).
num_layers (int): Number of transformer layers (default=2).
positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0).
positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[1]).
causal_temporal_positioning (bool): If True, uses causal attention (default=False).
num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables
register replacement. (default=128)
rope_type (LTXRopeType): The RoPE variant to use (default=DEFAULT_ROPE_TYPE).
double_precision_rope (bool): Use double precision rope calculation (default=False).
"""
_supports_gradient_checkpointing = True
def __init__(
self,
attention_head_dim: int = 128,
num_attention_heads: int = 30,
num_layers: int = 2,
positional_embedding_theta: float = 10000.0,
positional_embedding_max_pos: list[int] | None = [4096],
causal_temporal_positioning: bool = False,
num_learnable_registers: int | None = 128,
rope_type: LTXRopeType = LTXRopeType.SPLIT,
double_precision_rope: bool = True,
apply_gated_attention: bool = False,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.positional_embedding_theta = positional_embedding_theta
self.positional_embedding_max_pos = (
positional_embedding_max_pos if positional_embedding_max_pos is not None else [1]
)
self.rope_type = rope_type
self.double_precision_rope = double_precision_rope
self.transformer_1d_blocks = nn.ModuleList(
[
_BasicTransformerBlock1D(
dim=self.inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rope_type=rope_type,
apply_gated_attention=apply_gated_attention,
)
for _ in range(num_layers)
]
)
self.num_learnable_registers = num_learnable_registers
if self.num_learnable_registers:
self.learnable_registers = nn.Parameter(
torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
)
def _replace_padded_with_learnable_registers(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[1] % self.num_learnable_registers == 0, (
f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers "
f"{self.num_learnable_registers}."
)
num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers
learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1))
attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int()
non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
non_zero_nums = non_zero_hidden_states.shape[1]
pad_length = hidden_states.shape[1] - non_zero_nums
adjusted_hidden_states = nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
flipped_mask = torch.flip(attention_mask_binary, dims=[1])
hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers
attention_mask = torch.full_like(
attention_mask,
0.0,
dtype=attention_mask.dtype,
device=attention_mask.device,
)
return hidden_states, attention_mask
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass of Embeddings1DConnector.
Args:
hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]).
attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states).
Returns:
tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask.
"""
if self.num_learnable_registers:
hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)
indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device)
indices_grid = indices_grid[None, None, :]
freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
freqs_cis = precompute_freqs_cis(
indices_grid=indices_grid,
dim=self.inner_dim,
out_dtype=hidden_states.dtype,
theta=self.positional_embedding_theta,
max_pos=self.positional_embedding_max_pos,
num_attention_heads=self.num_attention_heads,
rope_type=self.rope_type,
freq_grid_generator=freq_grid_generator,
)
for block in self.transformer_1d_blocks:
hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis)
hidden_states = rms_norm(hidden_states)
return hidden_states, attention_mask
class LTX2TextEncoderPostModules(nn.Module):
def __init__(
self,
separated_audio_video: bool = False,
embedding_dim_gemma: int = 3840,
num_layers_gemma: int = 49,
video_attention_heads: int = 32,
video_attention_head_dim: int = 128,
audio_attention_heads: int = 32,
audio_attention_head_dim: int = 64,
num_connector_layers: int = 2,
apply_gated_attention: bool = False,
):
super().__init__()
if not separated_audio_video:
self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
self.embeddings_connector = Embeddings1DConnector()
self.audio_embeddings_connector = Embeddings1DConnector()
else:
# LTX-2.3
self.feature_extractor_linear = GemmaSeperatedFeaturesExtractorProjLinear(
num_layers_gemma, embedding_dim_gemma, video_attention_heads * video_attention_head_dim,
audio_attention_heads * audio_attention_head_dim)
self.embeddings_connector = Embeddings1DConnector(
attention_head_dim=video_attention_head_dim,
num_attention_heads=video_attention_heads,
num_layers=num_connector_layers,
apply_gated_attention=apply_gated_attention,
)
self.audio_embeddings_connector = Embeddings1DConnector(
attention_head_dim=audio_attention_head_dim,
num_attention_heads=audio_attention_heads,
num_layers=num_connector_layers,
apply_gated_attention=apply_gated_attention,
)
def create_embeddings(
self,
video_features: torch.Tensor,
audio_features: torch.Tensor | None,
additive_attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
video_encoded, video_mask = self.embeddings_connector(video_features, additive_attention_mask)
video_encoded, binary_mask = _to_binary_mask(video_encoded, video_mask)
audio_encoded, _ = self.audio_embeddings_connector(audio_features, additive_attention_mask)
return video_encoded, audio_encoded, binary_mask
def process_hidden_states(
self,
hidden_states: tuple[torch.Tensor, ...],
attention_mask: torch.Tensor,
padding_side: str = "left",
):
video_feats, audio_feats = self.feature_extractor_linear(hidden_states, attention_mask, padding_side)
additive_mask = _convert_to_additive_mask(attention_mask, video_feats.dtype)
video_enc, audio_enc, binary_mask = self.create_embeddings(video_feats, audio_feats, additive_mask)
return video_enc, audio_enc, binary_mask
def _norm_and_concat_padded_batch(
encoded_text: torch.Tensor,
sequence_lengths: torch.Tensor,
padding_side: str = "right",
) -> torch.Tensor:
"""Normalize and flatten multi-layer hidden states, respecting padding.
Performs per-batch, per-layer normalization using masked mean and range,
then concatenates across the layer dimension.
Args:
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
sequence_lengths: Number of valid (non-padded) tokens per batch item.
padding_side: Whether padding is on "left" or "right".
Returns:
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
with padded positions zeroed out.
"""
b, t, d, l = encoded_text.shape # noqa: E741
device = encoded_text.device
# Build mask: [B, T, 1, 1]
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
if padding_side == "right":
# For right padding, valid tokens are from 0 to sequence_length-1
mask = token_indices < sequence_lengths[:, None] # [B, T]
elif padding_side == "left":
# For left padding, valid tokens are from (T - sequence_length) to T-1
start_indices = t - sequence_lengths[:, None] # [B, 1]
mask = token_indices >= start_indices # [B, T]
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
mask = rearrange(mask, "b t -> b t 1 1")
eps = 1e-6
# Compute masked mean: [B, 1, 1, L]
masked = encoded_text.masked_fill(~mask, 0.0)
denom = (sequence_lengths * d).view(b, 1, 1, 1)
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
# Compute masked min/max: [B, 1, 1, L]
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
range_ = x_max - x_min
# Normalize only the valid tokens
normed = 8 * (encoded_text - mean) / (range_ + eps)
# concat to be [Batch, T, D * L] - this preserves the original structure
normed = normed.reshape(b, t, -1) # [B, T, D * L]
# Apply mask to preserve original padding (set padded positions to 0)
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
normed = normed.masked_fill(~mask_flattened, 0.0)
return normed
def _convert_to_additive_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return (attention_mask - 1).to(dtype).reshape(
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
def _to_binary_mask(encoded: torch.Tensor, encoded_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert connector output mask to binary mask and apply to encoded tensor."""
binary_mask = (encoded_mask < 0.000001).to(torch.int64)
binary_mask = binary_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
encoded = encoded * binary_mask
return encoded, binary_mask
def norm_and_concat_per_token_rms(
encoded_text: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
"""Per-token RMSNorm normalization for V2 models.
Args:
encoded_text: [B, T, D, L]
attention_mask: [B, T] binary mask
Returns:
[B, T, D*L] normalized tensor with padding zeroed out.
"""
B, T, D, L = encoded_text.shape # noqa: N806
variance = torch.mean(encoded_text**2, dim=2, keepdim=True) # [B,T,1,L]
normed = encoded_text * torch.rsqrt(variance + 1e-6)
normed = normed.reshape(B, T, D * L)
mask_3d = attention_mask.bool().unsqueeze(-1) # [B, T, 1]
return torch.where(mask_3d, normed, torch.zeros_like(normed))
def _rescale_norm(x: torch.Tensor, target_dim: int, source_dim: int) -> torch.Tensor:
"""Rescale normalization: x * sqrt(target_dim / source_dim)."""
return x * math.sqrt(target_dim / source_dim)

View File

@@ -1,313 +0,0 @@
import math
from typing import Optional, Tuple
import torch
from einops import rearrange
import torch.nn.functional as F
from .ltx2_video_vae import LTX2VideoEncoder
class PixelShuffleND(torch.nn.Module):
"""
N-dimensional pixel shuffle operation for upsampling tensors.
Args:
dims (int): Number of dimensions to apply pixel shuffle to.
- 1: Temporal (e.g., frames)
- 2: Spatial (e.g., height and width)
- 3: Spatiotemporal (e.g., depth, height, width)
upscale_factors (tuple[int, int, int], optional): Upscaling factors for each dimension.
For dims=1, only the first value is used.
For dims=2, the first two values are used.
For dims=3, all three values are used.
The input tensor is rearranged so that the channel dimension is split into
smaller channels and upscaling factors, and the upscaling factors are moved
into the corresponding spatial/temporal dimensions.
Note:
This operation is equivalent to the patchifier operation in for the models. Consider
using this class instead.
"""
def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)):
super().__init__()
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
self.dims = dims
self.upscale_factors = upscale_factors
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.dims == 3:
return rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.upscale_factors[0],
p2=self.upscale_factors[1],
p3=self.upscale_factors[2],
)
elif self.dims == 2:
return rearrange(
x,
"b (c p1 p2) h w -> b c (h p1) (w p2)",
p1=self.upscale_factors[0],
p2=self.upscale_factors[1],
)
elif self.dims == 1:
return rearrange(
x,
"b (c p1) f h w -> b c (f p1) h w",
p1=self.upscale_factors[0],
)
else:
raise ValueError(f"Unsupported dims: {self.dims}")
class ResBlock(torch.nn.Module):
"""
Residual block with two convolutional layers, group normalization, and SiLU activation.
Args:
channels (int): Number of input and output channels.
mid_channels (Optional[int]): Number of channels in the intermediate convolution layer. Defaults to `channels`
if not specified.
dims (int): Dimensionality of the convolution (2 for Conv2d, 3 for Conv3d). Defaults to 3.
"""
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
super().__init__()
if mid_channels is None:
mid_channels = channels
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.conv1 = conv(channels, mid_channels, kernel_size=3, padding=1)
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
self.conv2 = conv(mid_channels, channels, kernel_size=3, padding=1)
self.norm2 = torch.nn.GroupNorm(32, channels)
self.activation = torch.nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = self.activation(x)
x = self.conv2(x)
x = self.norm2(x)
x = self.activation(x + residual)
return x
class BlurDownsample(torch.nn.Module):
"""
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
"""
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
super().__init__()
assert dims in (2, 3)
assert isinstance(stride, int)
assert stride >= 1
assert kernel_size >= 3
assert kernel_size % 2 == 1
self.dims = dims
self.stride = stride
self.kernel_size = kernel_size
# 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from
# the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and
# provides a smooth approximation of a Gaussian filter (often called a "binomial filter").
# The 2D kernel is constructed as the outer product and normalized.
k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
k2d = k[:, None] @ k[None, :]
k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size)
self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.stride == 1:
return x
if self.dims == 2:
return self._apply_2d(x)
else:
# dims == 3: apply per-frame on H,W
b, _, f, _, _ = x.shape
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self._apply_2d(x)
h2, w2 = x.shape[-2:]
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
return x
def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor:
c = x2d.shape[1]
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
x2d = F.conv2d(x2d, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
return x2d
def _rational_for_scale(scale: float) -> Tuple[int, int]:
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
if float(scale) not in mapping:
raise ValueError(f"Unsupported scale {scale}. Choose from {list(mapping.keys())}")
return mapping[float(scale)]
class SpatialRationalResampler(torch.nn.Module):
"""
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
downsample by 'den' using fixed blur + stride. Operates on H,W only.
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
Args:
mid_channels (`int`): Number of intermediate channels for the convolution layer
scale (`float`): Spatial scaling factor. Supported values are:
- 0.75: Downsample by 3/4 (reduce spatial size)
- 1.5: Upsample by 3/2 (increase spatial size)
- 2.0: Upsample by 2x (double spatial size)
- 4.0: Upsample by 4x (quadruple spatial size)
Any other value will raise a ValueError.
"""
def __init__(self, mid_channels: int, scale: float):
super().__init__()
self.scale = float(scale)
self.num, self.den = _rational_for_scale(self.scale)
self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
self.blur_down = BlurDownsample(dims=2, stride=self.den)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, _, f, _, _ = x.shape
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.blur_down(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
return x
class LTX2LatentUpsampler(torch.nn.Module):
"""
Model to upsample VAE latents spatially and/or temporally.
Args:
in_channels (`int`): Number of channels in the input latent
mid_channels (`int`): Number of channels in the middle layers
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
dims (`int`): Number of dimensions for convolutions (2 or 3)
spatial_upsample (`bool`): Whether to spatially upsample the latent
temporal_upsample (`bool`): Whether to temporally upsample the latent
spatial_scale (`float`): Scale factor for spatial upsampling
rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling
"""
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 1024,
num_blocks_per_stage: int = 4,
dims: int = 3,
spatial_upsample: bool = True,
temporal_upsample: bool = False,
spatial_scale: float = 2.0,
rational_resampler: bool = True,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.num_blocks_per_stage = num_blocks_per_stage
self.dims = dims
self.spatial_upsample = spatial_upsample
self.temporal_upsample = temporal_upsample
self.spatial_scale = float(spatial_scale)
self.rational_resampler = rational_resampler
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1)
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
self.initial_activation = torch.nn.SiLU()
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
if spatial_upsample and temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(3),
)
elif spatial_upsample:
if rational_resampler:
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale)
else:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(2),
)
elif temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(1),
)
else:
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
self.post_upsample_res_blocks = torch.nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1)
def forward(self, latent: torch.Tensor) -> torch.Tensor:
b, _, f, _, _ = latent.shape
if self.dims == 2:
x = rearrange(latent, "b c f h w -> (b f) c h w")
x = self.initial_conv(x)
x = self.initial_norm(x)
x = self.initial_activation(x)
for block in self.res_blocks:
x = block(x)
x = self.upsampler(x)
for block in self.post_upsample_res_blocks:
x = block(x)
x = self.final_conv(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
else:
x = self.initial_conv(latent)
x = self.initial_norm(x)
x = self.initial_activation(x)
for block in self.res_blocks:
x = block(x)
if self.temporal_upsample:
x = self.upsampler(x)
# remove the first frame after upsampling.
# This is done because the first frame encodes one pixel frame.
x = x[:, :, 1:, :, :]
elif isinstance(self.upsampler, SpatialRationalResampler):
x = self.upsampler(x)
else:
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self.upsampler(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
for block in self.post_upsample_res_blocks:
x = block(x)
x = self.final_conv(x)
return x
def upsample_video(latent: torch.Tensor, video_encoder: LTX2VideoEncoder, upsampler: "LTX2LatentUpsampler") -> torch.Tensor:
"""
Apply upsampling to the latent representation using the provided upsampler,
with normalization and un-normalization based on the video encoder's per-channel statistics.
Args:
latent: Input latent tensor of shape [B, C, F, H, W].
video_encoder: VideoEncoder with per_channel_statistics for normalization.
upsampler: LTX2LatentUpsampler module to perform upsampling.
Returns:
torch.Tensor: Upsampled and re-normalized latent tensor.
"""
latent = video_encoder.per_channel_statistics.un_normalize(latent)
latent = upsampler(latent)
latent = video_encoder.per_channel_statistics.normalize(latent)
return latent

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
from ..core.loader import load_model, hash_model_file
from ..core.vram import AutoWrappedModule
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS
import importlib, json, torch
@@ -22,15 +22,14 @@ class ModelPool:
def fetch_module_map(self, model_class, vram_config):
if self.need_to_enable_vram_management(vram_config):
if model_class in VRAM_MANAGEMENT_MODULE_MAPS:
vram_module_map = VRAM_MANAGEMENT_MODULE_MAPS[model_class] if model_class not in VERSION_CHECKER_MAPS else VERSION_CHECKER_MAPS[model_class]()
module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in vram_module_map.items()}
module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in VRAM_MANAGEMENT_MODULE_MAPS[model_class].items()}
else:
module_map = {self.import_model_class(model_class): AutoWrappedModule}
else:
module_map = None
return module_map
def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
def load_model_file(self, config, path, vram_config, vram_limit=None):
model_class = self.import_model_class(config["model_class"])
model_config = config.get("extra_kwargs", {})
if "state_dict_converter" in config:
@@ -44,7 +43,6 @@ class ModelPool:
state_dict_converter,
use_disk_map=True,
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
state_dict=state_dict,
)
return model
@@ -61,7 +59,7 @@ class ModelPool:
}
return vram_config
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False):
print(f"Loading models from: {json.dumps(path, indent=4)}")
if vram_config is None:
vram_config = self.default_vram_config()
@@ -69,7 +67,7 @@ class ModelPool:
loaded = False
for config in MODEL_CONFIGS:
if config["model_hash"] == model_hash:
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit)
if clear_parameters: self.clear_parameters(model)
self.model.append(model)
model_name = config["model_name"]

View File

@@ -1,57 +0,0 @@
import torch
import torch.nn as nn
from .wan_video_dit import WanModel, precompute_freqs_cis, sinusoidal_embedding_1d
from einops import rearrange
from ..core import gradient_checkpoint_forward
def precompute_freqs_cis_1d(dim: int, end: int = 16384, theta: float = 10000.0):
f_freqs_cis = precompute_freqs_cis(dim, end, theta)
return f_freqs_cis.chunk(3, dim=-1)
class MovaAudioDit(WanModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
head_dim = kwargs.get("dim", 1536) // kwargs.get("num_heads", 12)
self.freqs = precompute_freqs_cis_1d(head_dim)
self.patch_embedding = nn.Conv1d(
kwargs.get("in_dim", 128), kwargs.get("dim", 1536), kernel_size=[1], stride=[1]
)
def precompute_freqs_cis(self, dim: int, end: int = 16384, theta: float = 10000.0):
self.f_freqs_cis = precompute_freqs_cis_1d(dim, end, theta)
def forward(self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
):
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
context = self.text_embedding(context)
x, (f, ) = self.patchify(x)
freqs = torch.cat([
self.freqs[0][:f].view(f, -1).expand(f, -1),
self.freqs[1][:f].view(f, -1).expand(f, -1),
self.freqs[2][:f].view(f, -1).expand(f, -1),
], dim=-1).reshape(f, 1, -1).to(x.device)
for block in self.blocks:
x = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x, context, t_mod, freqs,
)
x = self.head(x, t)
x = self.unpatchify(x, (f, ))
return x
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
return rearrange(
x, 'b f (p c) -> b c (f p)',
f=grid_size[0],
p=self.patch_size[0]
)

View File

@@ -1,796 +0,0 @@
import math
from typing import List, Union
import numpy as np
import torch
from torch import nn
from torch.nn.utils import weight_norm
import torch.nn.functional as F
from einops import rearrange
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x):
return snake(x, self.alpha)
class VectorQuantize(nn.Module):
"""
Implementation of VQ similar to Karpathy's repo:
https://github.com/karpathy/deep-vector-quantization
Additionally uses following tricks from Improved VQGAN
(https://arxiv.org/pdf/2110.04627.pdf):
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
for improved codebook usage
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
improves training stability
"""
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
super().__init__()
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
self.codebook = nn.Embedding(codebook_size, codebook_dim)
def forward(self, z):
"""Quantized the input tensor using a fixed codebook and returns
the corresponding codebook vectors
Parameters
----------
z : Tensor[B x D x T]
Returns
-------
Tensor[B x D x T]
Quantized continuous representation of input
Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
Tensor[1]
Codebook loss to update the codebook
Tensor[B x T]
Codebook indices (quantized discrete representation of input)
Tensor[B x D x T]
Projected latents (continuous representation of input before quantization)
"""
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
z_e = self.in_proj(z) # z_e : (B x D x T)
z_q, indices = self.decode_latents(z_e)
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
z_q = (
z_e + (z_q - z_e).detach()
) # noop in forward pass, straight-through gradient estimator in backward pass
z_q = self.out_proj(z_q)
return z_q, commitment_loss, codebook_loss, indices, z_e
def embed_code(self, embed_id):
return F.embedding(embed_id, self.codebook.weight)
def decode_code(self, embed_id):
return self.embed_code(embed_id).transpose(1, 2)
def decode_latents(self, latents):
encodings = rearrange(latents, "b d t -> (b t) d")
codebook = self.codebook.weight # codebook: (N x D)
# L2 normalize encodings and codebook (ViT-VQGAN)
encodings = F.normalize(encodings)
codebook = F.normalize(codebook)
# Compute euclidean distance with codebook
dist = (
encodings.pow(2).sum(1, keepdim=True)
- 2 * encodings @ codebook.t()
+ codebook.pow(2).sum(1, keepdim=True).t()
)
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
z_q = self.decode_code(indices)
return z_q, indices
class ResidualVectorQuantize(nn.Module):
"""
Introduced in SoundStream: An end2end neural audio codec
https://arxiv.org/abs/2107.03312
"""
def __init__(
self,
input_dim: int = 512,
n_codebooks: int = 9,
codebook_size: int = 1024,
codebook_dim: Union[int, list] = 8,
quantizer_dropout: float = 0.0,
):
super().__init__()
if isinstance(codebook_dim, int):
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
self.n_codebooks = n_codebooks
self.codebook_dim = codebook_dim
self.codebook_size = codebook_size
self.quantizers = nn.ModuleList(
[
VectorQuantize(input_dim, codebook_size, codebook_dim[i])
for i in range(n_codebooks)
]
)
self.quantizer_dropout = quantizer_dropout
def forward(self, z, n_quantizers: int = None):
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
the corresponding codebook vectors
Parameters
----------
z : Tensor[B x D x T]
n_quantizers : int, optional
No. of quantizers to use
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
Note: if `self.quantizer_dropout` is True, this argument is ignored
when in training mode, and a random number of quantizers is used.
Returns
-------
dict
A dictionary with the following keys:
"z" : Tensor[B x D x T]
Quantized continuous representation of input
"codes" : Tensor[B x N x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"latents" : Tensor[B x N*D x T]
Projected latents (continuous representation of input before quantization)
"vq/commitment_loss" : Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
"vq/codebook_loss" : Tensor[1]
Codebook loss to update the codebook
"""
z_q = 0
residual = z
commitment_loss = 0
codebook_loss = 0
codebook_indices = []
latents = []
if n_quantizers is None:
n_quantizers = self.n_codebooks
if self.training:
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
n_dropout = int(z.shape[0] * self.quantizer_dropout)
n_quantizers[:n_dropout] = dropout[:n_dropout]
n_quantizers = n_quantizers.to(z.device)
for i, quantizer in enumerate(self.quantizers):
if self.training is False and i >= n_quantizers:
break
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
residual
)
# Create mask to apply quantizer dropout
mask = (
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
)
z_q = z_q + z_q_i * mask[:, None, None]
residual = residual - z_q_i
# Sum losses
commitment_loss += (commitment_loss_i * mask).mean()
codebook_loss += (codebook_loss_i * mask).mean()
codebook_indices.append(indices_i)
latents.append(z_e_i)
codes = torch.stack(codebook_indices, dim=1)
latents = torch.cat(latents, dim=1)
return z_q, codes, latents, commitment_loss, codebook_loss
def from_codes(self, codes: torch.Tensor):
"""Given the quantized codes, reconstruct the continuous representation
Parameters
----------
codes : Tensor[B x N x T]
Quantized discrete representation of input
Returns
-------
Tensor[B x D x T]
Quantized continuous representation of input
"""
z_q = 0.0
z_p = []
n_codebooks = codes.shape[1]
for i in range(n_codebooks):
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
z_p.append(z_p_i)
z_q_i = self.quantizers[i].out_proj(z_p_i)
z_q = z_q + z_q_i
return z_q, torch.cat(z_p, dim=1), codes
def from_latents(self, latents: torch.Tensor):
"""Given the unquantized latents, reconstruct the
continuous representation after quantization.
Parameters
----------
latents : Tensor[B x N x T]
Continuous representation of input after projection
Returns
-------
Tensor[B x D x T]
Quantized representation of full-projected space
Tensor[B x D x T]
Quantized representation of latent space
"""
z_q = 0
z_p = []
codes = []
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
0
]
for i in range(n_codebooks):
j, k = dims[i], dims[i + 1]
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
z_p.append(z_p_i)
codes.append(codes_i)
z_q_i = self.quantizers[i].out_proj(z_p_i)
z_q = z_q + z_q_i
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.mean(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2],
)
else:
return 0.5 * torch.mean(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2],
)
def nll(self, sample, dims=[1, 2]):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
return 0.5 * (
-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
class ResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Snake1d(dim),
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
Snake1d(dim),
WNConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
y = self.block(x)
pad = (x.shape[-1] - y.shape[-1]) // 2
if pad > 0:
x = x[..., pad:-pad]
return x + y
class EncoderBlock(nn.Module):
def __init__(self, dim: int = 16, stride: int = 1):
super().__init__()
self.block = nn.Sequential(
ResidualUnit(dim // 2, dilation=1),
ResidualUnit(dim // 2, dilation=3),
ResidualUnit(dim // 2, dilation=9),
Snake1d(dim // 2),
WNConv1d(
dim // 2,
dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
),
)
def forward(self, x):
return self.block(x)
class Encoder(nn.Module):
def __init__(
self,
d_model: int = 64,
strides: list = [2, 4, 8, 8],
d_latent: int = 64,
):
super().__init__()
# Create first convolution
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride in strides:
d_model *= 2
self.block += [EncoderBlock(d_model, stride=stride)]
# Create last convolution
self.block += [
Snake1d(d_model),
WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
]
# Wrap black into nn.Sequential
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
def forward(self, x):
return self.block(x)
class DecoderBlock(nn.Module):
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
super().__init__()
self.block = nn.Sequential(
Snake1d(input_dim),
WNConvTranspose1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
output_padding=stride % 2,
),
ResidualUnit(output_dim, dilation=1),
ResidualUnit(output_dim, dilation=3),
ResidualUnit(output_dim, dilation=9),
)
def forward(self, x):
return self.block(x)
class Decoder(nn.Module):
def __init__(
self,
input_channel,
channels,
rates,
d_out: int = 1,
):
super().__init__()
# Add first conv layer
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
for i, stride in enumerate(rates):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
layers += [DecoderBlock(input_dim, output_dim, stride)]
# Add final conv layer
layers += [
Snake1d(output_dim),
WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
nn.Tanh(),
]
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class DacVAE(nn.Module):
def __init__(
self,
encoder_dim: int = 128,
encoder_rates: List[int] = [2, 3, 4, 5, 8],
latent_dim: int = 128,
decoder_dim: int = 2048,
decoder_rates: List[int] = [8, 5, 4, 3, 2],
n_codebooks: int = 9,
codebook_size: int = 1024,
codebook_dim: Union[int, list] = 8,
quantizer_dropout: bool = False,
sample_rate: int = 48000,
continuous: bool = True,
use_weight_norm: bool = False,
):
super().__init__()
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim
self.decoder_rates = decoder_rates
self.sample_rate = sample_rate
self.continuous = continuous
self.use_weight_norm = use_weight_norm
if latent_dim is None:
latent_dim = encoder_dim * (2 ** len(encoder_rates))
self.latent_dim = latent_dim
self.hop_length = np.prod(encoder_rates)
self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
if not continuous:
self.n_codebooks = n_codebooks
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.quantizer = ResidualVectorQuantize(
input_dim=latent_dim,
n_codebooks=n_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
else:
self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)
self.decoder = Decoder(
latent_dim,
decoder_dim,
decoder_rates,
)
self.sample_rate = sample_rate
self.apply(init_weights)
self.delay = self.get_delay()
if not self.use_weight_norm:
self.remove_weight_norm()
def get_delay(self):
# Any number works here, delay is invariant to input length
l_out = self.get_output_length(0)
L = l_out
layers = []
for layer in self.modules():
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
layers.append(layer)
for layer in reversed(layers):
d = layer.dilation[0]
k = layer.kernel_size[0]
s = layer.stride[0]
if isinstance(layer, nn.ConvTranspose1d):
L = ((L - d * (k - 1) - 1) / s) + 1
elif isinstance(layer, nn.Conv1d):
L = (L - 1) * s + d * (k - 1) + 1
L = math.ceil(L)
l_in = L
return (l_in - l_out) // 2
def get_output_length(self, input_length):
L = input_length
# Calculate output length
for layer in self.modules():
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
d = layer.dilation[0]
k = layer.kernel_size[0]
s = layer.stride[0]
if isinstance(layer, nn.Conv1d):
L = ((L - d * (k - 1) - 1) / s) + 1
elif isinstance(layer, nn.ConvTranspose1d):
L = (L - 1) * s + d * (k - 1) + 1
L = math.floor(L)
return L
@property
def dtype(self):
"""Get the dtype of the model parameters."""
# Return the dtype of the first parameter found
for param in self.parameters():
return param.dtype
return torch.float32 # fallback
@property
def device(self):
"""Get the device of the model parameters."""
# Return the device of the first parameter found
for param in self.parameters():
return param.device
return torch.device('cpu') # fallback
def preprocess(self, audio_data, sample_rate):
if sample_rate is None:
sample_rate = self.sample_rate
assert sample_rate == self.sample_rate
length = audio_data.shape[-1]
right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
audio_data = nn.functional.pad(audio_data, (0, right_pad))
return audio_data
def encode(
self,
audio_data: torch.Tensor,
n_quantizers: int = None,
):
"""Encode given audio data and return quantized latent codes
Parameters
----------
audio_data : Tensor[B x 1 x T]
Audio data to encode
n_quantizers : int, optional
Number of quantizers to use, by default None
If None, all quantizers are used.
Returns
-------
dict
A dictionary with the following keys:
"z" : Tensor[B x D x T]
Quantized continuous representation of input
"codes" : Tensor[B x N x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"latents" : Tensor[B x N*D x T]
Projected latents (continuous representation of input before quantization)
"vq/commitment_loss" : Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
"vq/codebook_loss" : Tensor[1]
Codebook loss to update the codebook
"length" : int
Number of samples in input audio
"""
z = self.encoder(audio_data) # [B x D x T]
if not self.continuous:
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
else:
z = self.quant_conv(z) # [B x 2D x T]
z = DiagonalGaussianDistribution(z)
codes, latents, commitment_loss, codebook_loss = None, None, 0, 0
return z, codes, latents, commitment_loss, codebook_loss
def decode(self, z: torch.Tensor):
"""Decode given latent codes and return audio data
Parameters
----------
z : Tensor[B x D x T]
Quantized continuous representation of input
length : int, optional
Number of samples in output audio, by default None
Returns
-------
dict
A dictionary with the following keys:
"audio" : Tensor[B x 1 x length]
Decoded audio data.
"""
if not self.continuous:
audio = self.decoder(z)
else:
z = self.post_quant_conv(z)
audio = self.decoder(z)
return audio
def forward(
self,
audio_data: torch.Tensor,
sample_rate: int = None,
n_quantizers: int = None,
):
"""Model forward pass
Parameters
----------
audio_data : Tensor[B x 1 x T]
Audio data to encode
sample_rate : int, optional
Sample rate of audio data in Hz, by default None
If None, defaults to `self.sample_rate`
n_quantizers : int, optional
Number of quantizers to use, by default None.
If None, all quantizers are used.
Returns
-------
dict
A dictionary with the following keys:
"z" : Tensor[B x D x T]
Quantized continuous representation of input
"codes" : Tensor[B x N x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"latents" : Tensor[B x N*D x T]
Projected latents (continuous representation of input before quantization)
"vq/commitment_loss" : Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
"vq/codebook_loss" : Tensor[1]
Codebook loss to update the codebook
"length" : int
Number of samples in input audio
"audio" : Tensor[B x 1 x length]
Decoded audio data.
"""
length = audio_data.shape[-1]
audio_data = self.preprocess(audio_data, sample_rate)
if not self.continuous:
z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
x = self.decode(z)
return {
"audio": x[..., :length],
"z": z,
"codes": codes,
"latents": latents,
"vq/commitment_loss": commitment_loss,
"vq/codebook_loss": codebook_loss,
}
else:
posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
z = posterior.sample()
x = self.decode(z)
kl_loss = posterior.kl()
kl_loss = kl_loss.mean()
return {
"audio": x[..., :length],
"z": z,
"kl_loss": kl_loss,
}
def remove_weight_norm(self):
"""
Remove weight_norm from all modules in the model.
This fuses the weight_g and weight_v parameters into a single weight parameter.
Should be called before inference for better performance.
Returns:
self: The model with weight_norm removed
"""
from torch.nn.utils import remove_weight_norm
num_removed = 0
for name, module in list(self.named_modules()):
if hasattr(module, "_forward_pre_hooks"):
for hook_id, hook in list(module._forward_pre_hooks.items()):
if "WeightNorm" in str(type(hook)):
try:
remove_weight_norm(module)
num_removed += 1
# print(f"Removed weight_norm from: {name}")
except ValueError as e:
print(f"Failed to remove weight_norm from {name}: {e}")
if num_removed > 0:
# print(f"Successfully removed weight_norm from {num_removed} modules")
self.use_weight_norm = False
else:
print("No weight_norm found in the model")
return self

View File

@@ -1,595 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional
from einops import rearrange
from .wan_video_dit import AttentionModule, RMSNorm
from ..core import gradient_checkpoint_forward
class RotaryEmbedding(nn.Module):
inv_freq: torch.Tensor # fix linting for `register_buffer`
def __init__(self, base: float, dim: int, device=None):
super().__init__()
self.base = base
self.dim = dim
self.attention_scaling = 1.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
@torch.compile(fullgraph=True)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class PerFrameAttentionPooling(nn.Module):
"""
Per-frame multi-head attention pooling.
Given a flattened token sequence [B, L, D] and grid size (T, H, W), perform a
single-query attention pooling over the H*W tokens for each time frame, producing
[B, T, D].
Inspired by SigLIP's Multihead Attention Pooling head (without MLP/residual stack).
"""
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
super().__init__()
assert dim % num_heads == 0, "dim must be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.probe = nn.Parameter(torch.randn(1, 1, dim))
nn.init.normal_(self.probe, std=0.02)
self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
self.layernorm = nn.LayerNorm(dim, eps=eps)
def forward(self, x: torch.Tensor, grid_size: Tuple[int, int, int]) -> torch.Tensor:
"""
Args:
x: [B, L, D], where L = T*H*W
grid_size: (T, H, W)
Returns:
pooled: [B, T, D]
"""
B, L, D = x.shape
T, H, W = grid_size
assert D == self.dim, f"Channel dimension mismatch: D={D} vs dim={self.dim}"
assert L == T * H * W, f"Flattened length mismatch: L={L} vs T*H*W={T*H*W}"
S = H * W
# Re-arrange tokens grouped by frame.
x_bt_s_d = x.view(B, T, S, D).contiguous().view(B * T, S, D) # [B*T, S, D]
# A learnable probe as the query (one query per frame).
probe = self.probe.expand(B * T, -1, -1) # [B*T, 1, D]
# Attention pooling: query=probe, key/value=H*W tokens within the frame.
pooled_bt_1_d = self.attention(probe, x_bt_s_d, x_bt_s_d, need_weights=False)[0] # [B*T, 1, D]
pooled_bt_d = pooled_bt_1_d.squeeze(1) # [B*T, D]
# Restore to [B, T, D].
pooled = pooled_bt_d.view(B, T, D)
pooled = self.layernorm(pooled)
return pooled
class CrossModalInteractionController:
"""
Strategy class that controls interactions between two towers.
Manages the interaction mapping between visual DiT (e.g. 30 layers) and audio DiT (e.g. 30 layers).
"""
def __init__(self, visual_layers: int = 30, audio_layers: int = 30):
self.visual_layers = visual_layers
self.audio_layers = audio_layers
self.min_layers = min(visual_layers, audio_layers)
def get_interaction_layers(self, strategy: str = "shallow_focus") -> Dict[str, List[Tuple[int, int]]]:
"""
Get interaction layer mappings.
Args:
strategy: interaction strategy
- "shallow_focus": emphasize shallow layers to avoid deep-layer asymmetry
- "distributed": distributed interactions across the network
- "progressive": dense shallow interactions, sparse deeper interactions
- "custom": custom interaction layers
Returns:
A dict containing mappings for 'v2a' (visual -> audio) and 'a2v' (audio -> visual).
"""
if strategy == "shallow_focus":
# Emphasize the first ~1/3 layers to avoid deep-layer asymmetry.
num_interact = min(10, self.min_layers // 3)
interact_layers = list(range(0, num_interact))
elif strategy == "distributed":
# Distribute interactions across the network (every few layers).
step = 3
interact_layers = list(range(0, self.min_layers, step))
elif strategy == "progressive":
# Progressive: dense shallow interactions, sparse deeper interactions.
shallow = list(range(0, min(8, self.min_layers))) # Dense for the first 8 layers.
if self.min_layers > 8:
deep = list(range(8, self.min_layers, 3)) # Every 3 layers afterwards.
interact_layers = shallow + deep
else:
interact_layers = shallow
elif strategy == "custom":
# Custom strategy: adjust as needed.
interact_layers = [0, 2, 4, 6, 8, 12, 16, 20] # Explicit layer indices.
interact_layers = [i for i in interact_layers if i < self.min_layers]
elif strategy == "full":
interact_layers = list(range(0, self.min_layers))
else:
raise ValueError(f"Unknown interaction strategy: {strategy}")
# Build bidirectional mapping.
mapping = {
'v2a': [(i, i) for i in interact_layers], # visual layer i -> audio layer i
'a2v': [(i, i) for i in interact_layers] # audio layer i -> visual layer i
}
return mapping
def should_interact(self, layer_idx: int, direction: str, interaction_mapping: Dict) -> bool:
"""
Check whether a given layer should interact.
Args:
layer_idx: current layer index
direction: interaction direction ('v2a' or 'a2v')
interaction_mapping: interaction mapping table
Returns:
bool: whether to interact
"""
if direction not in interaction_mapping:
return False
return any(src == layer_idx for src, _ in interaction_mapping[direction])
class ConditionalCrossAttention(nn.Module):
def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6):
super().__init__()
self.q_dim = dim
self.kv_dim = kv_dim
self.num_heads = num_heads
self.head_dim = self.q_dim // num_heads
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(kv_dim, dim)
self.v = nn.Linear(kv_dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
self.attn = AttentionModule(self.num_heads)
def forward(self, x: torch.Tensor, y: torch.Tensor, x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
ctx = y
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(ctx))
v = self.v(ctx)
if x_freqs is not None:
x_cos, x_sin = x_freqs
B, L, _ = q.shape
q_view = rearrange(q, 'b l (h d) -> b l h d', d=self.head_dim)
x_cos = x_cos.to(q_view.dtype).to(q_view.device)
x_sin = x_sin.to(q_view.dtype).to(q_view.device)
# Expect x_cos/x_sin shape: [B or 1, L, head_dim]
q_view, _ = apply_rotary_pos_emb(q_view, q_view, x_cos, x_sin, unsqueeze_dim=2)
q = rearrange(q_view, 'b l h d -> b l (h d)')
if y_freqs is not None:
y_cos, y_sin = y_freqs
Bc, Lc, _ = k.shape
k_view = rearrange(k, 'b l (h d) -> b l h d', d=self.head_dim)
y_cos = y_cos.to(k_view.dtype).to(k_view.device)
y_sin = y_sin.to(k_view.dtype).to(k_view.device)
# Expect y_cos/y_sin shape: [B or 1, L, head_dim]
_, k_view = apply_rotary_pos_emb(k_view, k_view, y_cos, y_sin, unsqueeze_dim=2)
k = rearrange(k_view, 'b l h d -> b l (h d)')
x = self.attn(q, k, v)
return self.o(x)
# from diffusers.models.attention import AdaLayerNorm
class AdaLayerNorm(nn.Module):
r"""
Norm layer modified to incorporate timestep embeddings.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
output_dim (`int`, *optional*):
norm_elementwise_affine (`bool`, defaults to `False):
norm_eps (`bool`, defaults to `False`):
chunk_dim (`int`, defaults to `0`):
"""
def __init__(
self,
embedding_dim: int,
num_embeddings: Optional[int] = None,
output_dim: Optional[int] = None,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
chunk_dim: int = 0,
):
super().__init__()
self.chunk_dim = chunk_dim
output_dim = output_dim or embedding_dim * 2
if num_embeddings is not None:
self.emb = nn.Embedding(num_embeddings, embedding_dim)
else:
self.emb = None
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
) -> torch.Tensor:
if self.emb is not None:
temb = self.emb(timestep)
temb = self.linear(self.silu(temb))
if self.chunk_dim == 2:
scale, shift = temb.chunk(2, dim=2)
# print(f"{x.shape = }, {scale.shape = }, {shift.shape = }")
elif self.chunk_dim == 1:
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
# other if-branch. This branch is specific to CogVideoX and OmniGen for now.
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
else:
scale, shift = temb.chunk(2, dim=0)
x = self.norm(x) * (1 + scale) + shift
return x
class ConditionalCrossAttentionBlock(nn.Module):
"""
A thin wrapper around ConditionalCrossAttention.
Applies LayerNorm to the conditioning input `y` before cross-attention.
"""
def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6, pooled_adaln: bool = False):
super().__init__()
self.y_norm = nn.LayerNorm(kv_dim, eps=eps)
self.inner = ConditionalCrossAttention(dim=dim, kv_dim=kv_dim, num_heads=num_heads, eps=eps)
self.pooled_adaln = pooled_adaln
if pooled_adaln:
self.per_frame_pooling = PerFrameAttentionPooling(kv_dim, num_heads=num_heads, eps=eps)
self.adaln = AdaLayerNorm(kv_dim, output_dim=dim*2, chunk_dim=2)
def forward(
self,
x: torch.Tensor,
y: torch.Tensor,
x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
video_grid_size: Optional[Tuple[int, int, int]] = None,
) -> torch.Tensor:
if self.pooled_adaln:
assert video_grid_size is not None, "video_grid_size must not be None"
pooled_y = self.per_frame_pooling(y, video_grid_size)
# Interpolate pooled_y along its temporal dimension to match x's sequence length.
if pooled_y.shape[1] != x.shape[1]:
pooled_y = F.interpolate(
pooled_y.permute(0, 2, 1), # [B, C, T]
size=x.shape[1],
mode='linear',
align_corners=False,
).permute(0, 2, 1) # [B, T, C]
x = self.adaln(x, temb=pooled_y)
y = self.y_norm(y)
return self.inner(x=x, y=y, x_freqs=x_freqs, y_freqs=y_freqs)
class DualTowerConditionalBridge(nn.Module):
"""
Dual-tower conditional bridge.
"""
def __init__(self,
visual_layers: int = 40,
audio_layers: int = 30,
visual_hidden_dim: int = 5120, # visual DiT hidden state dimension
audio_hidden_dim: int = 1536, # audio DiT hidden state dimension
audio_fps: float = 50.0,
head_dim: int = 128, # attention head dimension
interaction_strategy: str = "full",
apply_cross_rope: bool = True, # whether to apply RoPE in cross-attention
apply_first_frame_bias_in_rope: bool = False, # whether to account for 1/video_fps bias for the first frame in RoPE alignment
trainable_condition_scale: bool = False,
pooled_adaln: bool = False,
):
super().__init__()
self.visual_hidden_dim = visual_hidden_dim
self.audio_hidden_dim = audio_hidden_dim
self.audio_fps = audio_fps
self.head_dim = head_dim
self.apply_cross_rope = apply_cross_rope
self.apply_first_frame_bias_in_rope = apply_first_frame_bias_in_rope
self.trainable_condition_scale = trainable_condition_scale
self.pooled_adaln = pooled_adaln
if self.trainable_condition_scale:
self.condition_scale = nn.Parameter(torch.tensor([1.0], dtype=torch.float32))
else:
self.condition_scale = 1.0
self.controller = CrossModalInteractionController(visual_layers, audio_layers)
self.interaction_mapping = self.controller.get_interaction_layers(interaction_strategy)
# Conditional cross-attention modules operating at the DiT hidden-state level.
self.audio_to_video_conditioners = nn.ModuleDict() # audio hidden states -> visual DiT conditioning
self.video_to_audio_conditioners = nn.ModuleDict() # visual hidden states -> audio DiT conditioning
# Build conditioners for layers that should interact.
# audio hidden states condition the visual DiT
self.rotary = RotaryEmbedding(base=10000.0, dim=head_dim)
for v_layer, _ in self.interaction_mapping['a2v']:
self.audio_to_video_conditioners[str(v_layer)] = ConditionalCrossAttentionBlock(
dim=visual_hidden_dim, # 3072 (visual DiT hidden states)
kv_dim=audio_hidden_dim, # 1536 (audio DiT hidden states)
num_heads=visual_hidden_dim // head_dim, # derive number of heads from hidden dim
pooled_adaln=False # a2v typically does not need pooled AdaLN
)
# visual hidden states condition the audio DiT
for a_layer, _ in self.interaction_mapping['v2a']:
self.video_to_audio_conditioners[str(a_layer)] = ConditionalCrossAttentionBlock(
dim=audio_hidden_dim, # 1536 (audio DiT hidden states)
kv_dim=visual_hidden_dim, # 3072 (visual DiT hidden states)
num_heads=audio_hidden_dim // head_dim, # safe head count derivation
pooled_adaln=self.pooled_adaln
)
@torch.no_grad()
def build_aligned_freqs(self,
video_fps: float,
grid_size: Tuple[int, int, int],
audio_steps: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
"""
Build aligned RoPE (cos, sin) pairs based on video fps, video grid size (f_v, h, w),
and audio sequence length `audio_steps` (with fixed audio fps = 44100/2048).
Returns:
visual_freqs: (cos_v, sin_v), shape [1, f_v*h*w, head_dim]
audio_freqs: (cos_a, sin_a), shape [1, audio_steps, head_dim]
"""
f_v, h, w = grid_size
L_v = f_v * h * w
L_a = int(audio_steps)
device = device or next(self.parameters()).device
dtype = dtype or torch.float32
# Audio positions: 0,1,2,...,L_a-1 (audio as reference).
audio_pos = torch.arange(L_a, device=device, dtype=torch.float32).unsqueeze(0)
# Video positions: align video frames to audio-step units.
# FIXME(dhyu): hard-coded VAE temporal stride = 4
if self.apply_first_frame_bias_in_rope:
# Account for the "first frame lasts 1/video_fps" bias.
video_effective_fps = float(video_fps) / 4.0
if f_v > 0:
t_starts = torch.zeros((f_v,), device=device, dtype=torch.float32)
if f_v > 1:
t_starts[1:] = (1.0 / float(video_fps)) + torch.arange(f_v - 1, device=device, dtype=torch.float32) * (1.0 / video_effective_fps)
else:
t_starts = torch.zeros((0,), device=device, dtype=torch.float32)
# Convert to audio-step units.
video_pos_per_frame = t_starts * float(self.audio_fps)
else:
# No first-frame bias: uniform alignment.
scale = float(self.audio_fps) / float(video_fps / 4.0)
video_pos_per_frame = torch.arange(f_v, device=device, dtype=torch.float32) * scale
# Flatten to f*h*w; tokens within the same frame share the same time position.
video_pos = video_pos_per_frame.repeat_interleave(h * w).unsqueeze(0)
# print(f"video fps: {video_fps}, audio fps: {self.audio_fps}, scale: {scale}")
# print(f"video pos: {video_pos.shape}, audio pos: {audio_pos.shape}")
# Build dummy x to produce cos/sin, dim=head_dim.
dummy_v = torch.zeros((1, L_v, self.head_dim), device=device, dtype=dtype)
dummy_a = torch.zeros((1, L_a, self.head_dim), device=device, dtype=dtype)
cos_v, sin_v = self.rotary(dummy_v, position_ids=video_pos)
cos_a, sin_a = self.rotary(dummy_a, position_ids=audio_pos)
return (cos_v, sin_v), (cos_a, sin_a)
def should_interact(self, layer_idx: int, direction: str) -> bool:
return self.controller.should_interact(layer_idx, direction, self.interaction_mapping)
def apply_conditional_control(
self,
layer_idx: int,
direction: str,
primary_hidden_states: torch.Tensor,
condition_hidden_states: torch.Tensor,
x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
condition_scale: Optional[float] = None,
video_grid_size: Optional[Tuple[int, int, int]] = None,
use_gradient_checkpointing: Optional[bool] = False,
use_gradient_checkpointing_offload: Optional[bool] = False,
) -> torch.Tensor:
"""
Apply conditional control (at the DiT hidden-state level).
Args:
layer_idx: current layer index
direction: conditioning direction
- 'a2v': audio hidden states -> visual DiT
- 'v2a': visual hidden states -> audio DiT
primary_hidden_states: primary DiT hidden states [B, L, hidden_dim]
condition_hidden_states: condition DiT hidden states [B, L, hidden_dim]
condition_scale: conditioning strength (similar to CFG scale)
Returns:
Conditioned primary DiT hidden states [B, L, hidden_dim]
"""
if not self.controller.should_interact(layer_idx, direction, self.interaction_mapping):
return primary_hidden_states
if direction == 'a2v':
# audio hidden states condition the visual DiT
conditioner = self.audio_to_video_conditioners[str(layer_idx)]
elif direction == 'v2a':
# visual hidden states condition the audio DiT
conditioner = self.video_to_audio_conditioners[str(layer_idx)]
else:
raise ValueError(f"Invalid direction: {direction}")
conditioned_features = gradient_checkpoint_forward(
conditioner,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x=primary_hidden_states,
y=condition_hidden_states,
x_freqs=x_freqs,
y_freqs=y_freqs,
video_grid_size=video_grid_size,
)
if self.trainable_condition_scale and condition_scale is not None:
print(
"[WARN] This model has a trainable condition_scale, but an external "
f"condition_scale={condition_scale} was provided. The trainable condition_scale "
"will be ignored in favor of the external value."
)
scale = condition_scale if condition_scale is not None else self.condition_scale
primary_hidden_states = primary_hidden_states + conditioned_features * scale
return primary_hidden_states
def forward(
self,
layer_idx: int,
visual_hidden_states: torch.Tensor,
audio_hidden_states: torch.Tensor,
*,
x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
a2v_condition_scale: Optional[float] = None,
v2a_condition_scale: Optional[float] = None,
condition_scale: Optional[float] = None,
video_grid_size: Optional[Tuple[int, int, int]] = None,
use_gradient_checkpointing: Optional[bool] = False,
use_gradient_checkpointing_offload: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply bidirectional conditional control to both visual/audio towers.
Args:
layer_idx: current layer index
visual_hidden_states: visual DiT hidden states
audio_hidden_states: audio DiT hidden states
x_freqs / y_freqs: cross-modal RoPE (cos, sin) pairs.
If provided, x_freqs is assumed to correspond to the primary tower and y_freqs
to the conditioning tower.
a2v_condition_scale: audio->visual conditioning strength (overrides global condition_scale)
v2a_condition_scale: visual->audio conditioning strength (overrides global condition_scale)
condition_scale: fallback conditioning strength when per-direction scale is None
video_grid_size: (F, H, W), used on the audio side when pooled_adaln is enabled
Returns:
(visual_hidden_states, audio_hidden_states), both conditioned in their respective directions.
"""
visual_conditioned = self.apply_conditional_control(
layer_idx=layer_idx,
direction="a2v",
primary_hidden_states=visual_hidden_states,
condition_hidden_states=audio_hidden_states,
x_freqs=x_freqs,
y_freqs=y_freqs,
condition_scale=a2v_condition_scale if a2v_condition_scale is not None else condition_scale,
video_grid_size=video_grid_size,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
audio_conditioned = self.apply_conditional_control(
layer_idx=layer_idx,
direction="v2a",
primary_hidden_states=audio_hidden_states,
condition_hidden_states=visual_hidden_states,
x_freqs=y_freqs,
y_freqs=x_freqs,
condition_scale=v2a_condition_scale if v2a_condition_scale is not None else condition_scale,
video_grid_size=video_grid_size,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
return visual_conditioned, audio_conditioned

View File

@@ -549,9 +549,6 @@ class QwenImageTransformerBlock(nn.Module):
class QwenImageDiT(torch.nn.Module):
_repeated_blocks = ["QwenImageTransformerBlock"]
def __init__(
self,
num_layers: int = 60,

View File

@@ -1,78 +0,0 @@
import torch
class SDTextEncoder(torch.nn.Module):
def __init__(
self,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
max_position_embeddings=77,
vocab_size=49408,
layer_norm_eps=1e-05,
hidden_act="quick_gelu",
initializer_factor=1.0,
initializer_range=0.02,
bos_token_id=0,
eos_token_id=2,
pad_token_id=1,
projection_dim=768,
):
super().__init__()
from transformers import CLIPConfig, CLIPTextModel
config = CLIPConfig(
text_config={
"hidden_size": hidden_size,
"intermediate_size": intermediate_size,
"num_hidden_layers": num_hidden_layers,
"num_attention_heads": num_attention_heads,
"max_position_embeddings": max_position_embeddings,
"vocab_size": vocab_size,
"layer_norm_eps": layer_norm_eps,
"hidden_act": hidden_act,
"initializer_factor": initializer_factor,
"initializer_range": initializer_range,
"bos_token_id": bos_token_id,
"eos_token_id": eos_token_id,
"pad_token_id": pad_token_id,
"projection_dim": projection_dim,
"dropout": 0.0,
},
vision_config={
"hidden_size": hidden_size,
"intermediate_size": intermediate_size,
"num_hidden_layers": num_hidden_layers,
"num_attention_heads": num_attention_heads,
"max_position_embeddings": max_position_embeddings,
"layer_norm_eps": layer_norm_eps,
"hidden_act": hidden_act,
"initializer_factor": initializer_factor,
"initializer_range": initializer_range,
"projection_dim": projection_dim,
},
projection_dim=projection_dim,
)
self.model = CLIPTextModel(config.text_config)
self.config = config
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
output_hidden_states=True,
**kwargs,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_hidden_states=output_hidden_states,
return_dict=True,
**kwargs,
)
if output_hidden_states:
return outputs.last_hidden_state, outputs.hidden_states
return outputs.last_hidden_state

View File

@@ -1,912 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional
# ===== Time Embedding =====
class Timesteps(nn.Module):
def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.freq_shift = freq_shift
def forward(self, timesteps):
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / half_dim + self.freq_shift
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
if self.flip_sin_to_cos:
emb = torch.cat([cos_emb, sin_emb], dim=-1)
else:
emb = torch.cat([sin_emb, cos_emb], dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
out_dim = out_dim if out_dim is not None else time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, out_dim)
def forward(self, sample):
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
# ===== ResNet Blocks =====
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
):
super().__init__()
self.pre_norm = pre_norm
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nn.SiLU()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
elif non_linearity == "gelu":
self.nonlinearity = nn.GELU()
elif non_linearity == "relu":
self.nonlinearity = nn.ReLU()
self.use_conv_shortcut = conv_shortcut
self.conv_shortcut = None
if conv_shortcut:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
def forward(self, input_tensor, temb=None):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
# ===== Transformer Blocks =====
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, dropout=0.0):
super().__init__()
self.net = nn.ModuleList([
GEGLU(dim, dim * 4),
nn.Dropout(dropout),
nn.Linear(dim * 4, dim if dim_out is None else dim_out),
])
def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class Attention(nn.Module):
"""Attention block matching diffusers checkpoint key format.
Keys: to_q.weight, to_k.weight, to_v.weight, to_out.0.weight, to_out.0.bias
"""
def __init__(
self,
query_dim,
heads=8,
dim_head=64,
dropout=0.0,
bias=False,
upcast_attention=False,
cross_attention_dim=None,
):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.inner_dim = inner_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
self.to_out = nn.ModuleList([
nn.Linear(inner_dim, query_dim, bias=True),
nn.Dropout(dropout),
])
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
# Query
query = self.to_q(hidden_states)
batch_size, seq_len, _ = query.shape
# Key/Value
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
# Reshape for multi-head attention
head_dim = self.inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
# Scaled dot-product attention
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Reshape back
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim)
hidden_states = hidden_states.to(query.dtype)
# Output projection
hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[1](hidden_states)
return hidden_states
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
cross_attention_dim=None,
upcast_attention=False,
):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
)
self.norm2 = nn.LayerNorm(dim)
self.attn2 = Attention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
cross_attention_dim=cross_attention_dim,
)
self.norm3 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, dropout=dropout)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
# Self-attention
attn_output = self.attn1(self.norm1(hidden_states))
hidden_states = attn_output + hidden_states
# Cross-attention
attn_output = self.attn2(self.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states
# Feed-forward
ff_output = self.ff(self.norm3(hidden_states))
hidden_states = ff_output + hidden_states
return hidden_states
class Transformer2DModel(nn.Module):
"""2D Transformer block wrapper matching diffusers checkpoint structure.
Keys: norm.weight/bias, proj_in.weight/bias, transformer_blocks.X.*, proj_out.weight/bias
"""
def __init__(
self,
num_attention_heads=16,
attention_head_dim=64,
in_channels=320,
num_layers=1,
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
upcast_attention=False,
):
super().__init__()
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6)
self.proj_in = nn.Conv2d(in_channels, num_attention_heads * attention_head_dim, kernel_size=1, bias=True)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
dim=num_attention_heads * attention_head_dim,
n_heads=num_attention_heads,
d_head=attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
)
for _ in range(num_layers)
])
self.proj_out = nn.Conv2d(num_attention_heads * attention_head_dim, in_channels, kernel_size=1, bias=True)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch, channel, height, width = hidden_states.shape
residual = hidden_states
# Normalize and project to sequence
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
# Transformer blocks
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
# Project back to 2D
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
# ===== Down/Up Blocks =====
class CrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
downsample=True,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=out_channels // attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=1)
])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = []
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
output_states.append(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states.append(hidden_states)
return hidden_states, tuple(output_states)
class DownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
downsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=1)
])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = []
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
output_states.append(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states.append(hidden_states)
return hidden_states, tuple(output_states)
class CrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
upsample=True,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=out_channels // attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
for resnet, attn in zip(self.resnets, self.attentions):
# Pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
return hidden_states
class UpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels=1280,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
upsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
for resnet in self.resnets:
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
return hidden_states
# ===== UNet Mid Block =====
class UNetMidBlock2DCrossAttn(nn.Module):
def __init__(
self,
in_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# There is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
]
attentions = []
for _ in range(num_layers):
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=in_channels // attention_head_dim,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
)
)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
# ===== Downsample / Upsample =====
class Downsample2D(nn.Module):
def __init__(self, in_channels, out_channels, padding=1):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
self.padding = padding
def forward(self, hidden_states):
if self.padding == 0:
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
return self.conv(hidden_states)
class Upsample2D(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, hidden_states, upsample_size=None):
if upsample_size is not None:
hidden_states = F.interpolate(hidden_states, size=upsample_size, mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
return self.conv(hidden_states)
# ===== UNet2DConditionModel =====
class UNet2DConditionModel(nn.Module):
"""Stable Diffusion UNet with cross-attention conditioning.
state_dict keys match the diffusers UNet2DConditionModel checkpoint format.
"""
def __init__(
self,
sample_size=64,
in_channels=4,
out_channels=4,
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
block_out_channels=(320, 640, 1280, 1280),
layers_per_block=2,
cross_attention_dim=768,
attention_head_dim=8,
norm_num_groups=32,
norm_eps=1e-5,
dropout=0.0,
act_fn="silu",
time_embedding_type="positional",
flip_sin_to_cos=True,
freq_shift=0,
time_embedding_dim=None,
resnet_time_scale_shift="default",
upcast_attention=False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.sample_size = sample_size
# Time embedding
timestep_embedding_dim = time_embedding_dim or block_out_channels[0]
self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
time_embed_dim = block_out_channels[0] * 4
self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim)
# Input
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
# Down blocks
self.down_blocks = nn.ModuleList()
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
if "CrossAttn" in down_block_type:
down_block = CrossAttnDownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block,
transformer_layers_per_block=1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
downsample=not is_final_block,
)
else:
down_block = DownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
downsample=not is_final_block,
)
self.down_blocks.append(down_block)
# Mid block
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
)
# Up blocks
self.up_blocks = nn.ModuleList()
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
# in_channels for up blocks: diffusers uses reversed_block_out_channels[min(i+1, len-1)]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
if "CrossAttn" in up_block_type:
up_block = CrossAttnUpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block + 1,
transformer_layers_per_block=1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
upsample=not is_final_block,
)
else:
up_block = UpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block + 1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
upsample=not is_final_block,
)
self.up_blocks.append(up_block)
# Output
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True):
# 1. Time embedding
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = self.time_proj(timesteps)
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb)
# 2. Pre-process
sample = self.conv_in(sample)
# 3. Down
down_block_res_samples = (sample,)
for down_block in self.down_blocks:
sample, res_samples = down_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
down_block_res_samples += res_samples
# 4. Mid
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
# 5. Up
for up_block in self.up_blocks:
res_samples = down_block_res_samples[-len(up_block.resnets):]
down_block_res_samples = down_block_res_samples[:-len(up_block.resnets)]
upsample_size = down_block_res_samples[-1].shape[2:] if down_block_res_samples else None
sample = up_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
)
# 6. Post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return sample

View File

@@ -1,642 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from typing import Optional
class DiagonalGaussianDistribution:
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
# randn_like doesn't accept generator on all torch versions
sample = torch.randn(self.mean.shape, generator=generator,
device=self.parameters.device, dtype=self.parameters.dtype)
return self.mean + self.std * sample
def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
if self.deterministic:
return torch.tensor([0.0])
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3],
)
def mode(self) -> torch.Tensor:
return self.mean
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
):
super().__init__()
self.pre_norm = pre_norm
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nn.SiLU()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
elif non_linearity == "gelu":
self.nonlinearity = nn.GELU()
elif non_linearity == "relu":
self.nonlinearity = nn.ReLU()
else:
raise ValueError(f"Unsupported non_linearity: {non_linearity}")
self.use_conv_shortcut = conv_shortcut
self.conv_shortcut = None
if conv_shortcut:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
def forward(self, input_tensor, temb=None):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
class DownEncoderBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=downsample_padding)
])
else:
self.downsamplers = None
def forward(self, hidden_states, *args, **kwargs):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class UpDecoderBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
output_scale_factor=1.0,
add_upsample=True,
temb_channels=None,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, temb=None):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class UNetMidBlock2D(nn.Module):
def __init__(
self,
in_channels,
temb_channels=None,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
add_attention=True,
attention_head_dim=1,
output_scale_factor=1.0,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.add_attention = add_attention
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
if attention_head_dim is None:
attention_head_dim = in_channels
for _ in range(num_layers):
if self.add_attention:
attentions.append(
AttentionBlock(
in_channels,
num_groups=resnet_groups,
eps=resnet_eps,
)
)
else:
attentions.append(None)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
class AttentionBlock(nn.Module):
"""Simple attention block for VAE mid block.
Mirrors diffusers Attention class with AttnProcessor2_0 for VAE use case.
Uses modern key names (to_q, to_k, to_v, to_out) matching in-memory diffusers structure.
Checkpoint uses deprecated keys (query, key, value, proj_attn) — mapped via converter.
"""
def __init__(self, channels, num_groups=32, eps=1e-6):
super().__init__()
self.channels = channels
self.eps = eps
self.heads = 1
self.rescale_output_factor = 1.0
self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=eps, affine=True)
self.to_q = nn.Linear(channels, channels, bias=True)
self.to_k = nn.Linear(channels, channels, bias=True)
self.to_v = nn.Linear(channels, channels, bias=True)
self.to_out = nn.ModuleList([
nn.Linear(channels, channels, bias=True),
nn.Dropout(0.0),
])
def forward(self, hidden_states):
residual = hidden_states
# Group norm
hidden_states = self.group_norm(hidden_states)
# Flatten spatial dims: (B, C, H, W) -> (B, H*W, C)
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# QKV projection
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
# Reshape for attention: (B, seq, dim) -> (B, heads, seq, head_dim)
inner_dim = key.shape[-1]
head_dim = inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
# Scaled dot-product attention
hidden_states = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Reshape back: (B, heads, seq, head_dim) -> (B, seq, heads*head_dim)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# Output projection + dropout
hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[1](hidden_states)
# Reshape back to 4D and add residual
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
hidden_states = hidden_states + residual
# Rescale output factor
hidden_states = hidden_states / self.rescale_output_factor
return hidden_states
class Downsample2D(nn.Module):
"""Downsampling layer matching diffusers Downsample2D with use_conv=True.
Key names: conv.weight/bias.
When padding=0, applies explicit F.pad before conv to match dimension.
"""
def __init__(self, in_channels, out_channels, padding=1):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.padding = padding
def forward(self, hidden_states):
if self.padding == 0:
import torch.nn.functional as F
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
return self.conv(hidden_states)
class Upsample2D(nn.Module):
"""Upsampling layer with key names matching diffusers checkpoint: conv.weight/bias."""
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, hidden_states):
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
return self.conv(hidden_states)
class Encoder(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
double_z=True,
mid_block_add_attention=True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
self.down_blocks = nn.ModuleList([])
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = DownEncoderBlock2D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=self.layers_per_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
add_downsample=not is_final_block,
downsample_padding=0,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
add_attention=mid_block_add_attention,
)
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
def forward(self, sample):
sample = self.conv_in(sample)
for down_block in self.down_blocks:
sample = down_block(sample)
sample = self.mid_block(sample)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class Decoder(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
norm_type="group",
mid_block_add_attention=True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
self.up_blocks = nn.ModuleList([])
temb_channels = in_channels if norm_type == "spatial" else None
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
add_attention=mid_block_add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = UpDecoderBlock2D(
in_channels=prev_output_channel,
out_channels=output_channel,
num_layers=self.layers_per_block + 1,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
add_upsample=not is_final_block,
temb_channels=temb_channels,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def forward(self, sample, latent_embeds=None):
sample = self.conv_in(sample)
sample = self.mid_block(sample, latent_embeds)
for up_block in self.up_blocks:
sample = up_block(sample, latent_embeds)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class StableDiffusionVAE(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"),
up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
block_out_channels=(128, 256, 512, 512),
layers_per_block=2,
act_fn="silu",
latent_channels=4,
norm_num_groups=32,
sample_size=512,
scaling_factor=0.18215,
shift_factor=None,
latents_mean=None,
latents_std=None,
force_upcast=True,
use_quant_conv=True,
use_post_quant_conv=True,
mid_block_add_attention=True,
):
super().__init__()
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
double_z=True,
mid_block_add_attention=mid_block_add_attention,
)
self.decoder = Decoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
mid_block_add_attention=mid_block_add_attention,
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
self.latents_mean = latents_mean
self.latents_std = latents_std
self.scaling_factor = scaling_factor
self.shift_factor = shift_factor
self.sample_size = sample_size
self.force_upcast = force_upcast
def _encode(self, x):
h = self.encoder(x)
if self.quant_conv is not None:
h = self.quant_conv(h)
return h
def encode(self, x):
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
return posterior
def _decode(self, z):
if self.post_quant_conv is not None:
z = self.post_quant_conv(z)
return self.decoder(z)
def decode(self, z):
return self._decode(z)
def forward(self, sample, sample_posterior=True, return_dict=True, generator=None):
posterior = self.encode(sample)
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
# Scale latent
z = z * self.scaling_factor
decode = self.decode(z)
if return_dict:
return {"sample": decode, "posterior": posterior, "latent_sample": z}
return decode, posterior

View File

@@ -1,62 +0,0 @@
import torch
class SDXLTextEncoder2(torch.nn.Module):
def __init__(
self,
hidden_size=1280,
intermediate_size=5120,
num_hidden_layers=32,
num_attention_heads=20,
max_position_embeddings=77,
vocab_size=49408,
layer_norm_eps=1e-05,
hidden_act="gelu",
initializer_factor=1.0,
initializer_range=0.02,
bos_token_id=0,
eos_token_id=2,
pad_token_id=1,
projection_dim=1280,
):
super().__init__()
from transformers import CLIPTextConfig, CLIPTextModelWithProjection
config = CLIPTextConfig(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
max_position_embeddings=max_position_embeddings,
vocab_size=vocab_size,
layer_norm_eps=layer_norm_eps,
hidden_act=hidden_act,
initializer_factor=initializer_factor,
initializer_range=initializer_range,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
projection_dim=projection_dim,
)
self.model = CLIPTextModelWithProjection(config)
self.config = config
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
output_hidden_states=True,
**kwargs,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_hidden_states=output_hidden_states,
return_dict=True,
**kwargs,
)
if output_hidden_states:
return outputs.text_embeds, outputs.hidden_states
return outputs.text_embeds

View File

@@ -1,922 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional
# ===== Time Embedding =====
class Timesteps(nn.Module):
def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.freq_shift = freq_shift
def forward(self, timesteps):
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / half_dim + self.freq_shift
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
if self.flip_sin_to_cos:
emb = torch.cat([cos_emb, sin_emb], dim=-1)
else:
emb = torch.cat([sin_emb, cos_emb], dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
out_dim = out_dim if out_dim is not None else time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, out_dim)
def forward(self, sample):
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
# ===== ResNet Blocks =====
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
):
super().__init__()
self.pre_norm = pre_norm
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nn.SiLU()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
elif non_linearity == "gelu":
self.nonlinearity = nn.GELU()
elif non_linearity == "relu":
self.nonlinearity = nn.ReLU()
self.use_conv_shortcut = conv_shortcut
self.conv_shortcut = None
if conv_shortcut:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
def forward(self, input_tensor, temb=None):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
# ===== Transformer Blocks =====
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, dropout=0.0):
super().__init__()
self.net = nn.ModuleList([
GEGLU(dim, dim * 4),
nn.Dropout(dropout),
nn.Linear(dim * 4, dim if dim_out is None else dim_out),
])
def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class Attention(nn.Module):
def __init__(
self,
query_dim,
heads=8,
dim_head=64,
dropout=0.0,
bias=False,
upcast_attention=False,
cross_attention_dim=None,
):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.inner_dim = inner_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
self.to_out = nn.ModuleList([
nn.Linear(inner_dim, query_dim, bias=True),
nn.Dropout(dropout),
])
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
query = self.to_q(hidden_states)
batch_size, seq_len, _ = query.shape
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
head_dim = self.inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim)
hidden_states = hidden_states.to(query.dtype)
hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[1](hidden_states)
return hidden_states
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
cross_attention_dim=None,
upcast_attention=False,
):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
)
self.norm2 = nn.LayerNorm(dim)
self.attn2 = Attention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
cross_attention_dim=cross_attention_dim,
)
self.norm3 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, dropout=dropout)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
attn_output = self.attn1(self.norm1(hidden_states))
hidden_states = attn_output + hidden_states
attn_output = self.attn2(self.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states
ff_output = self.ff(self.norm3(hidden_states))
hidden_states = ff_output + hidden_states
return hidden_states
class Transformer2DModel(nn.Module):
def __init__(
self,
num_attention_heads=16,
attention_head_dim=64,
in_channels=320,
num_layers=1,
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
upcast_attention=False,
use_linear_projection=False,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.use_linear_projection = use_linear_projection
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim, bias=True)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, bias=True)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
dim=inner_dim,
n_heads=num_attention_heads,
d_head=attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
)
for _ in range(num_layers)
])
if use_linear_projection:
self.proj_out = nn.Linear(inner_dim, in_channels, bias=True)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, bias=True)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if self.use_linear_projection:
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
hidden_states = self.proj_in(hidden_states)
else:
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
if self.use_linear_projection:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
else:
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
# ===== Down/Up Blocks =====
class CrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
downsample=True,
use_linear_projection=False,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=out_channels // attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=1)
])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = []
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
output_states.append(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states.append(hidden_states)
return hidden_states, tuple(output_states)
class DownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
downsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=1)
])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = []
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
output_states.append(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states.append(hidden_states)
return hidden_states, tuple(output_states)
class CrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
upsample=True,
use_linear_projection=False,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=out_channels // attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
for resnet, attn in zip(self.resnets, self.attentions):
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
return hidden_states
class UpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels=1280,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
upsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
for resnet in self.resnets:
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
return hidden_states
# ===== UNet Mid Block =====
class UNetMidBlock2DCrossAttn(nn.Module):
def __init__(
self,
in_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
use_linear_projection=False,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
]
attentions = []
for _ in range(num_layers):
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=in_channels // attention_head_dim,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
)
)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
# ===== Downsample / Upsample =====
class Downsample2D(nn.Module):
def __init__(self, in_channels, out_channels, padding=1):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
self.padding = padding
def forward(self, hidden_states):
if self.padding == 0:
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
return self.conv(hidden_states)
class Upsample2D(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, hidden_states, upsample_size=None):
if upsample_size is not None:
hidden_states = F.interpolate(hidden_states, size=upsample_size, mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
return self.conv(hidden_states)
# ===== SDXL UNet2DConditionModel =====
class SDXLUNet2DConditionModel(nn.Module):
def __init__(
self,
sample_size=128,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
block_out_channels=(320, 640, 1280),
layers_per_block=2,
cross_attention_dim=2048,
attention_head_dim=5,
transformer_layers_per_block=1,
norm_num_groups=32,
norm_eps=1e-5,
dropout=0.0,
act_fn="silu",
time_embedding_type="positional",
flip_sin_to_cos=True,
freq_shift=0,
time_embedding_dim=None,
resnet_time_scale_shift="default",
upcast_attention=False,
use_linear_projection=False,
addition_embed_type=None,
addition_time_embed_dim=None,
projection_class_embeddings_input_dim=None,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.sample_size = sample_size
self.addition_embed_type = addition_embed_type
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
timestep_embedding_dim = time_embedding_dim or block_out_channels[0]
self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
time_embed_dim = block_out_channels[0] * 4
self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim)
if addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
self.down_blocks = nn.ModuleList()
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
if "CrossAttn" in down_block_type:
down_block = CrossAttnDownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim[i],
downsample=not is_final_block,
use_linear_projection=use_linear_projection,
)
else:
down_block = DownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
downsample=not is_final_block,
)
self.down_blocks.append(down_block)
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=1,
transformer_layers_per_block=transformer_layers_per_block[-1],
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim[-1],
use_linear_projection=use_linear_projection,
)
self.up_blocks = nn.ModuleList()
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
if "CrossAttn" in up_block_type:
up_block = CrossAttnUpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block + 1,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=reversed_attention_head_dim[i],
upsample=not is_final_block,
use_linear_projection=use_linear_projection,
)
else:
up_block = UpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block + 1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
upsample=not is_final_block,
)
self.up_blocks.append(up_block)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True):
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = self.time_proj(timesteps)
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb)
if self.addition_embed_type == "text_time":
text_embeds = added_cond_kwargs.get("text_embeds")
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb + aug_emb
sample = self.conv_in(sample)
down_block_res_samples = (sample,)
for down_block in self.down_blocks:
sample, res_samples = down_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
down_block_res_samples += res_samples
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
for up_block in self.up_blocks:
res_samples = down_block_res_samples[-len(up_block.resnets):]
down_block_res_samples = down_block_res_samples[:-len(up_block.resnets)]
upsample_size = down_block_res_samples[-1].shape[2:] if down_block_res_samples else None
sample = up_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return sample

View File

@@ -6,7 +6,6 @@ from typing import Tuple, Optional
from einops import rearrange
from .wan_video_camera_controller import SimpleAdapter
from ..core.gradient import gradient_checkpoint_forward
from .wantodance import WanToDanceRotaryEmbedding, WanToDanceMusicEncoderLayer
try:
import flash_attn_interface
@@ -95,35 +94,23 @@ def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
x.shape[0], x.shape[1], x.shape[2], -1, 2))
freqs = freqs.to(torch.complex64) if freqs.device.type == "npu" else freqs
freqs = freqs.to(torch.complex64) if freqs.device == "npu" else freqs
x_out = torch.view_as_real(x_out * freqs).flatten(2)
return x_out.to(x.dtype)
def set_to_torch_norm(models):
for model in models:
for module in model.modules():
if isinstance(module, RMSNorm):
module.use_torch_norm = True
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.use_torch_norm = False
self.normalized_shape = (dim,)
def norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def forward(self, x):
dtype = x.dtype
if self.use_torch_norm:
return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
else:
return self.norm(x.float()).to(dtype) * self.weight
return self.norm(x.float()).to(dtype) * self.weight
class AttentionModule(nn.Module):
@@ -284,61 +271,7 @@ class Head(nn.Module):
return x
def wantodance_torch_dfs(model: nn.Module, parent_name='root'):
module_names, modules = [], []
current_name = parent_name if parent_name else 'root'
module_names.append(current_name)
modules.append(model)
for name, child in model.named_children():
if parent_name:
child_name = f'{parent_name}.{name}'
else:
child_name = name
child_modules, child_names = wantodance_torch_dfs(child, child_name)
module_names += child_names
modules += child_modules
return modules, module_names
class WanToDanceInjector(nn.Module):
def __init__(self, all_modules, all_modules_names, dim=2048, num_heads=32, inject_layer=[0, 27]):
super().__init__()
self.injected_block_id = {}
injector_id = 0
for mod_name, mod in zip(all_modules_names, all_modules):
if isinstance(mod, DiTBlock):
for inject_id in inject_layer:
if f'root.transformer_blocks.{inject_id}' == mod_name:
self.injected_block_id[inject_id] = injector_id
injector_id += 1
self.injector = nn.ModuleList(
[
CrossAttention(
dim=dim,
num_heads=num_heads,
)
for _ in range(injector_id)
]
)
self.injector_pre_norm_feat = nn.ModuleList(
[
nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6,)
for _ in range(injector_id)
]
)
self.injector_pre_norm_vec = nn.ModuleList(
[
nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6,)
for _ in range(injector_id)
]
)
class WanModel(torch.nn.Module):
_repeated_blocks = ["DiTBlock"]
def __init__(
self,
dim: int,
@@ -360,13 +293,6 @@ class WanModel(torch.nn.Module):
require_vae_embedding: bool = True,
require_clip_embedding: bool = True,
fuse_vae_embedding_in_latents: bool = False,
wantodance_enable_music_inject: bool = False,
wantodance_music_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27],
wantodance_enable_refimage: bool = False,
wantodance_enable_refface: bool = False,
wantodance_enable_global: bool = False,
wantodance_enable_dynamicfps: bool = False,
wantodance_enable_unimodel: bool = False,
):
super().__init__()
self.dim = dim
@@ -399,12 +325,7 @@ class WanModel(torch.nn.Module):
])
self.head = Head(dim, out_dim, patch_size, eps)
head_dim = dim // num_heads
if wantodance_enable_dynamicfps or wantodance_enable_unimodel:
end = int(22350 / 8 + 0.5) # 149f * 30fps * 5s = 22350
self.freqs = precompute_freqs_cis_3d(head_dim, end=end)
else:
self.freqs = precompute_freqs_cis_3d(head_dim)
self.freqs = precompute_freqs_cis_3d(head_dim)
if has_image_input:
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
@@ -417,83 +338,8 @@ class WanModel(torch.nn.Module):
else:
self.control_adapter = None
self.prepare_wantodance(in_dim, dim, num_heads, has_image_pos_emb, out_dim, patch_size, eps,
wantodance_enable_music_inject, wantodance_music_inject_layers, wantodance_enable_refimage, wantodance_enable_refface,
wantodance_enable_global, wantodance_enable_dynamicfps, wantodance_enable_unimodel)
def prepare_wantodance(
self,
in_dim, dim, num_heads, has_image_pos_emb, out_dim, patch_size, eps,
wantodance_enable_music_inject: bool = False,
wantodance_music_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27],
wantodance_enable_refimage: bool = False,
wantodance_enable_refface: bool = False,
wantodance_enable_global: bool = False,
wantodance_enable_dynamicfps: bool = False,
wantodance_enable_unimodel: bool = False,
):
if wantodance_enable_music_inject:
all_modules, all_modules_names = wantodance_torch_dfs(self.blocks, parent_name="root.transformer_blocks")
self.music_injector = WanToDanceInjector(all_modules, all_modules_names, dim=dim, num_heads=num_heads, inject_layer=wantodance_music_inject_layers)
if wantodance_enable_refimage:
self.img_emb_refimage = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
if wantodance_enable_refface:
self.img_emb_refface = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
if wantodance_enable_global or wantodance_enable_dynamicfps or wantodance_enable_unimodel:
music_feature_dim = 35
ff_size = 1024
dropout = 0.1
latent_dim = 256
nhead = 4
activation = F.gelu
rotary = WanToDanceRotaryEmbedding(dim=latent_dim)
self.music_projection = nn.Linear(music_feature_dim, latent_dim)
self.music_encoder = nn.Sequential()
for _ in range(2):
self.music_encoder.append(
WanToDanceMusicEncoderLayer(
d_model=latent_dim,
nhead=nhead,
dim_feedforward=ff_size,
dropout=dropout,
activation=activation,
batch_first=True,
rotary=rotary,
device='cuda',
)
)
if wantodance_enable_unimodel:
self.patch_embedding_global = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
if wantodance_enable_unimodel:
self.head_global = Head(dim, out_dim, patch_size, eps)
self.wantodance_enable_music_inject = wantodance_enable_music_inject
self.wantodance_enable_refimage = wantodance_enable_refimage
self.wantodance_enable_refface = wantodance_enable_refface
self.wantodance_enable_global = wantodance_enable_global
self.wantodance_enable_dynamicfps = wantodance_enable_dynamicfps
self.wantodance_enable_unimodel = wantodance_enable_unimodel
def wantodance_after_transformer_block(self, block_idx, hidden_states):
if self.wantodance_enable_music_inject:
if block_idx in self.music_injector.injected_block_id.keys():
audio_attn_id = self.music_injector.injected_block_id[block_idx]
audio_emb = self.merged_audio_emb # b f n c
num_frames = audio_emb.shape[1]
input_hidden_states = hidden_states.clone() # b (f h w) c
input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)
attn_hidden_states = self.music_injector.injector_pre_norm_feat[audio_attn_id](input_hidden_states)
audio_emb = rearrange(audio_emb, "b t c -> (b t) 1 c", t=num_frames)
attn_audio_emb = audio_emb
residual_out = self.music_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb)
residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
hidden_states = hidden_states + residual_out
return hidden_states
def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None, enable_wantodance_global=False):
if enable_wantodance_global:
x = self.patch_embedding_global(x)
else:
x = self.patch_embedding(x)
def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
x = self.patch_embedding(x)
if self.control_adapter is not None and control_camera_latents_input is not None:
y_camera = self.control_adapter(control_camera_latents_input)
x = [u + v for u, v in zip(x, y_camera)]

View File

@@ -469,7 +469,7 @@ class Down_ResidualBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for module in self.downsamples:
x, feat_cache, feat_idx = module(x, feat_cache, feat_idx)
x = module(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy), feat_cache, feat_idx
@@ -506,10 +506,10 @@ class Up_ResidualBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
x_main = x.clone()
for module in self.upsamples:
x_main, feat_cache, feat_idx = module(x_main, feat_cache, feat_idx)
x_main = module(x_main, feat_cache, feat_idx)
if self.avg_shortcut is not None:
x_shortcut = self.avg_shortcut(x, first_chunk)
return x_main + x_shortcut, feat_cache, feat_idx
return x_main + x_shortcut
else:
return x_main, feat_cache, feat_idx
@@ -1023,11 +1023,11 @@ class VideoVAE_(nn.Module):
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
out = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
out_ = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2) # may add tensor offload
@@ -1247,22 +1247,6 @@ class WanVideoVAE(nn.Module):
return videos
def encode_framewise(self, videos, device):
hidden_states = []
for i in range(videos.shape[2]):
hidden_states.append(self.single_encode(videos[:, :, i:i+1], device))
hidden_states = torch.concat(hidden_states, dim=2)
return hidden_states
def decode_framewise(self, hidden_states, device):
video = []
for i in range(hidden_states.shape[2]):
video.append(self.single_decode(hidden_states[:, :, i:i+1], device))
video = torch.concat(video, dim=2)
return video
@staticmethod
def state_dict_converter():
return WanVideoVAEStateDictConverter()
@@ -1319,11 +1303,11 @@ class VideoVAE38_(VideoVAE_):
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
out = self.encoder(x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
@@ -1353,12 +1337,12 @@ class VideoVAE38_(VideoVAE_):
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
out = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
first_chunk=True)
else:
out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
out_ = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)

View File

@@ -1,209 +0,0 @@
from inspect import isfunction
from math import log, pi
import torch
from einops import rearrange, repeat
from torch import einsum, nn
from typing import Any, Callable, List, Optional, Union
from torch import Tensor
import torch.nn.functional as F
# helper functions
def exists(val):
return val is not None
def broadcat(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
), "invalid dimensions for broadcastable concatentation"
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim=dim)
# rotary embedding helper functions
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
def apply_rotary_emb(freqs, t, start_index=0):
freqs = freqs.to(t)
rot_dim = freqs.shape[-1]
end_index = start_index + rot_dim
assert (
rot_dim <= t.shape[-1]
), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
t_left, t, t_right = (
t[..., :start_index],
t[..., start_index:end_index],
t[..., end_index:],
)
t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
return torch.cat((t_left, t, t_right), dim=-1)
# learned rotation helpers
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
if exists(freq_ranges):
rotations = einsum("..., f -> ... f", rotations, freq_ranges)
rotations = rearrange(rotations, "... r f -> ... (r f)")
rotations = repeat(rotations, "... n -> ... (n r)", r=2)
return apply_rotary_emb(rotations, t, start_index=start_index)
# classes
class WanToDanceRotaryEmbedding(nn.Module):
def __init__(
self,
dim,
custom_freqs=None,
freqs_for="lang",
theta=10000,
max_freq=10,
num_freqs=1,
learned_freq=False,
):
super().__init__()
if exists(custom_freqs):
freqs = custom_freqs
elif freqs_for == "lang":
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
)
elif freqs_for == "pixel":
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
elif freqs_for == "constant":
freqs = torch.ones(num_freqs).float()
else:
raise ValueError(f"unknown modality {freqs_for}")
self.cache = dict()
if learned_freq:
self.freqs = nn.Parameter(freqs)
else:
self.register_buffer("freqs", freqs, persistent=False)
def rotate_queries_or_keys(self, t, seq_dim=-2):
device = t.device
seq_len = t.shape[seq_dim]
freqs = self.forward(
lambda: torch.arange(seq_len, device=device), cache_key=seq_len
)
return apply_rotary_emb(freqs, t)
def forward(self, t, cache_key=None):
if exists(cache_key) and cache_key in self.cache:
return self.cache[cache_key]
if isfunction(t):
t = t()
# freqs = self.freqs
freqs = self.freqs.to(t.device)
freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
if exists(cache_key):
self.cache[cache_key] = freqs
return freqs
class WanToDanceMusicEncoderLayer(nn.Module):
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
layer_norm_eps: float = 1e-5,
batch_first: bool = False,
norm_first: bool = True,
device=None,
dtype=None,
rotary=None,
) -> None:
super().__init__()
self.self_attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=batch_first, device=device, dtype=dtype
)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm_first = norm_first
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = activation
self.rotary = rotary
self.use_rotary = rotary is not None
# self-attention block
def _sa_block(
self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
) -> Tensor:
qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
x = self.self_attn(
qk,
qk,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)[0]
return self.dropout1(x)
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
def forward(
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
x = src
if self.norm_first:
self.norm1.to(device=x.device)
self.norm2.to(device=x.device)
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
x = x + self._ff_block(self.norm2(x))
else:
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
x = self.norm2(x + self._ff_block(x))
return x

View File

@@ -88,14 +88,6 @@ class Attention(torch.nn.Module):
self.norm_q = RMSNorm(head_dim, eps=1e-5)
self.norm_k = RMSNorm(head_dim, eps=1e-5)
# Apply RoPE
def apply_rotary_emb(self, x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast(get_device_type(), enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in) # todo
def forward(self, hidden_states, freqs_cis, attention_mask):
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
@@ -111,9 +103,17 @@ class Attention(torch.nn.Module):
if self.norm_k is not None:
key = self.norm_k(key)
# Apply RoPE
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast(get_device_type(), enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in) # todo
if freqs_cis is not None:
query = self.apply_rotary_emb(query, freqs_cis)
key = self.apply_rotary_emb(key, freqs_cis)
query = apply_rotary_emb(query, freqs_cis)
key = apply_rotary_emb(key, freqs_cis)
# Cast to correct dtype
dtype = query.dtype
@@ -326,7 +326,6 @@ class RopeEmbedder:
class ZImageDiT(nn.Module):
_supports_gradient_checkpointing = True
_no_split_modules = ["ZImageTransformerBlock"]
_repeated_blocks = ["ZImageTransformerBlock"]
def __init__(
self,

View File

@@ -6,36 +6,6 @@ class ZImageTextEncoder(torch.nn.Module):
def __init__(self, model_size="4B"):
super().__init__()
config_dict = {
"0.6B": Qwen3Config(**{
"architectures": [
"Qwen3ForCausalLM"
],
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 40960,
"max_window_layers": 28,
"model_type": "qwen3",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 1000000,
"sliding_window": None,
"tie_word_embeddings": True,
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0",
"use_cache": True,
"use_sliding_window": False,
"vocab_size": 151936
}),
"4B": Qwen3Config(**{
"architectures": [
"Qwen3ForCausalLM"

View File

@@ -1,264 +0,0 @@
import torch, math
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange
import numpy as np
from math import prod
from transformers import AutoTokenizer
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from ..utils.lora.merge import merge_lora
from ..models.anima_dit import AnimaDiT
from ..models.z_image_text_encoder import ZImageTextEncoder
from ..models.wan_video_vae import WanVideoVAE
class AnimaImagePipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("Z-Image")
self.text_encoder: ZImageTextEncoder = None
self.dit: AnimaDiT = None
self.vae: WanVideoVAE = None
self.tokenizer: AutoTokenizer = None
self.tokenizer_t5xxl: AutoTokenizer = None
self.in_iteration_models = ("dit",)
self.units = [
AnimaUnit_ShapeChecker(),
AnimaUnit_NoiseInitializer(),
AnimaUnit_InputImageEmbedder(),
AnimaUnit_PromptEmbedder(),
]
self.model_fn = model_fn_anima
self.compilable_models = ["dit"]
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
tokenizer_t5xxl_config: ModelConfig = ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
vram_limit: float = None,
):
# Initialize pipeline
pipe = AnimaImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
pipe.dit = model_pool.fetch_model("anima_dit")
pipe.vae = model_pool.fetch_model("wan_video_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
if tokenizer_t5xxl_config is not None:
tokenizer_t5xxl_config.download_if_necessary()
pipe.tokenizer_t5xxl = AutoTokenizer.from_pretrained(tokenizer_t5xxl_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 4.0,
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 30,
sigma_shift: float = None,
# Progress bar
progress_bar_cmd = tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
# Parameters
inputs_posi = {
"prompt": prompt,
}
inputs_nega = {
"negative_prompt": negative_prompt,
}
inputs_shared = {
"cfg_scale": cfg_scale,
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"].unsqueeze(2), device=self.device).squeeze(2)
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class AnimaUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: AnimaImagePipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class AnimaUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: AnimaImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class AnimaUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: AnimaImagePipeline, input_image, noise):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
if isinstance(input_image, list):
input_latents = []
for image in input_image:
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents.append(pipe.vae.encode(image))
input_latents = torch.concat(input_latents, dim=0)
else:
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents = pipe.vae.encode(image.unsqueeze(2), device=pipe.device).squeeze(2)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": input_latents}
class AnimaUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_emb",),
onload_model_names=("text_encoder",)
)
def encode_prompt(
self,
pipe: AnimaImagePipeline,
prompt,
device = None,
max_sequence_length: int = 512,
):
if isinstance(prompt, str):
prompt = [prompt]
text_inputs = pipe.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_masks = text_inputs.attention_mask.to(device).bool()
prompt_embeds = pipe.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-1]
t5xxl_text_inputs = pipe.tokenizer_t5xxl(
prompt,
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
t5xxl_ids = t5xxl_text_inputs.input_ids.to(device)
return prompt_embeds.to(pipe.torch_dtype), t5xxl_ids
def process(self, pipe: AnimaImagePipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
prompt_embeds, t5xxl_ids = self.encode_prompt(pipe, prompt, pipe.device)
return {"prompt_emb": prompt_embeds, "t5xxl_ids": t5xxl_ids}
def model_fn_anima(
dit: AnimaDiT = None,
latents=None,
timestep=None,
prompt_emb=None,
t5xxl_ids=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs
):
latents = latents.unsqueeze(2)
timestep = timestep / 1000
model_output = dit(
x=latents,
timesteps=timestep,
context=prompt_emb,
t5xxl_ids=t5xxl_ids,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
model_output = model_output.squeeze(2)
return model_output

View File

@@ -1,266 +0,0 @@
"""
ERNIE-Image Text-to-Image Pipeline for DiffSynth-Studio.
Architecture: SharedAdaLN DiT + RoPE 3D + Joint Image-Text Attention.
"""
import torch
from typing import Union, Optional
from tqdm import tqdm
from transformers import AutoTokenizer
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.ernie_image_text_encoder import ErnieImageTextEncoder
from ..models.ernie_image_dit import ErnieImageDiT
from ..models.flux2_vae import Flux2VAE
class ErnieImagePipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("ERNIE-Image")
self.text_encoder: ErnieImageTextEncoder = None
self.dit: ErnieImageDiT = None
self.vae: Flux2VAE = None
self.tokenizer: AutoTokenizer = None
self.in_iteration_models = ("dit",)
self.units = [
ErnieImageUnit_ShapeChecker(),
ErnieImageUnit_PromptEmbedder(),
ErnieImageUnit_NoiseInitializer(),
ErnieImageUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_ernie_image
self.compilable_models = ["dit"]
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
vram_limit: float = None,
):
pipe = ErnieImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
pipe.text_encoder = model_pool.fetch_model("ernie_image_text_encoder")
pipe.dit = model_pool.fetch_model("ernie_image_dit")
pipe.vae = model_pool.fetch_model("flux2_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 4.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cuda",
# Steps
num_inference_steps: int = 50,
sigma_shift: float = 3.0,
# Progress bar
progress_bar_cmd=tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, shift=sigma_shift)
# Parameters
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"height": height, "width": width, "seed": seed,
"cfg_scale": cfg_scale, "num_inference_steps": num_inference_steps,
"rand_device": rand_device,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
latents = inputs_shared["latents"]
image = self.vae.decode(latents)
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class ErnieImageUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: ErnieImagePipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class ErnieImageUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds", "prompt_embeds_mask"),
onload_model_names=("text_encoder",)
)
def encode_prompt(self, pipe: ErnieImagePipeline, prompt):
if isinstance(prompt, str):
prompt = [prompt]
text_hiddens = []
text_lens_list = []
for p in prompt:
ids = pipe.tokenizer(
p,
add_special_tokens=True,
truncation=True,
padding=False,
)["input_ids"]
if len(ids) == 0:
if pipe.tokenizer.bos_token_id is not None:
ids = [pipe.tokenizer.bos_token_id]
else:
ids = [0]
input_ids = torch.tensor([ids], device=pipe.device)
outputs = pipe.text_encoder(
input_ids=input_ids,
)
# Text encoder returns tuple of (hidden_states_tuple,) where each layer's hidden state is included
all_hidden_states = outputs[0]
hidden = all_hidden_states[-2][0] # [T, H] - second to last layer
text_hiddens.append(hidden)
text_lens_list.append(hidden.shape[0])
# Pad to uniform length
if len(text_hiddens) == 0:
text_in_dim = pipe.text_encoder.config.hidden_size if hasattr(pipe.text_encoder, 'config') else 3072
return {
"prompt_embeds": torch.zeros((0, 0, text_in_dim), device=pipe.device, dtype=pipe.torch_dtype),
"prompt_embeds_mask": torch.zeros((0,), device=pipe.device, dtype=torch.long),
}
normalized = [th.to(pipe.device).to(pipe.torch_dtype) for th in text_hiddens]
text_lens = torch.tensor([t.shape[0] for t in normalized], device=pipe.device, dtype=torch.long)
Tmax = int(text_lens.max().item())
text_in_dim = normalized[0].shape[1]
text_bth = torch.zeros((len(normalized), Tmax, text_in_dim), device=pipe.device, dtype=pipe.torch_dtype)
for i, t in enumerate(normalized):
text_bth[i, :t.shape[0], :] = t
return {"prompt_embeds": text_bth, "prompt_embeds_mask": text_lens}
def process(self, pipe: ErnieImagePipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
if pipe.text_encoder is not None:
return self.encode_prompt(pipe, prompt)
return {}
class ErnieImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: ErnieImagePipeline, height, width, seed, rand_device):
latent_h = height // pipe.height_division_factor
latent_w = width // pipe.width_division_factor
latent_channels = pipe.dit.in_channels
# Use pipeline device if rand_device is not specified
if rand_device is None:
rand_device = str(pipe.device)
noise = pipe.generate_noise(
(1, latent_channels, latent_h, latent_w),
seed=seed,
rand_device=rand_device,
rand_torch_dtype=pipe.torch_dtype,
)
return {"noise": noise}
class ErnieImageUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: ErnieImagePipeline, input_image, noise):
if input_image is None:
# T2I path: use noise directly as initial latents
return {"latents": noise, "input_latents": None}
# I2I path: VAE encode input image
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents = pipe.vae.encode(image)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
# In inference mode, add noise to encoded latents
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents}
def model_fn_ernie_image(
dit: ErnieImageDiT,
latents=None,
timestep=None,
prompt_embeds=None,
prompt_embeds_mask=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
output = dit(
hidden_states=latents,
timestep=timestep,
text_bth=prompt_embeds,
text_lens=prompt_embeds_mask,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
return output

View File

@@ -42,7 +42,6 @@ class Flux2ImagePipeline(BasePipeline):
Flux2Unit_ImageIDs(),
]
self.model_fn = model_fn_flux2
self.compilable_models = ["dit"]
@staticmethod
@@ -91,7 +90,6 @@ class Flux2ImagePipeline(BasePipeline):
# Randomness
seed: int = None,
rand_device: str = "cpu",
initial_noise: torch.Tensor = None,
# Steps
num_inference_steps: int = 30,
# Progress bar
@@ -111,7 +109,7 @@ class Flux2ImagePipeline(BasePipeline):
"input_image": input_image, "denoising_strength": denoising_strength,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device, "initial_noise": initial_noise,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
}
for unit in self.units:
@@ -431,15 +429,12 @@ class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
class Flux2Unit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device", "initial_noise"),
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: Flux2ImagePipeline, height, width, seed, rand_device, initial_noise):
if initial_noise is not None:
noise = initial_noise.clone()
else:
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
def process(self, pipe: Flux2ImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
noise = noise.reshape(1, 128, height//16 * width//16).permute(0, 2, 1)
return {"noise": noise}

View File

@@ -103,7 +103,6 @@ class FluxImagePipeline(BasePipeline):
FluxImageUnit_LoRAEncode(),
]
self.model_fn = model_fn_flux_image
self.compilable_models = ["dit"]
self.lora_loader = FluxLoRALoader
def enable_lora_merger(self):

View File

@@ -1,282 +0,0 @@
import torch
from PIL import Image
from typing import Union, Optional
from tqdm import tqdm
from einops import rearrange
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.joyai_image_dit import JoyAIImageDiT
from ..models.joyai_image_text_encoder import JoyAIImageTextEncoder
from ..models.wan_video_vae import WanVideoVAE
class JoyAIImagePipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("Wan")
self.text_encoder: JoyAIImageTextEncoder = None
self.dit: JoyAIImageDiT = None
self.vae: WanVideoVAE = None
self.processor = None
self.in_iteration_models = ("dit",)
self.units = [
JoyAIImageUnit_ShapeChecker(),
JoyAIImageUnit_EditImageEmbedder(),
JoyAIImageUnit_PromptEmbedder(),
JoyAIImageUnit_NoiseInitializer(),
JoyAIImageUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_joyai_image
self.compilable_models = ["dit"]
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
# Processor
processor_config: ModelConfig = None,
# Optional
vram_limit: float = None,
):
pipe = JoyAIImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
pipe.text_encoder = model_pool.fetch_model("joyai_image_text_encoder")
pipe.dit = model_pool.fetch_model("joyai_image_dit")
pipe.vae = model_pool.fetch_model("wan_video_vae")
if processor_config is not None:
processor_config.download_if_necessary()
from transformers import AutoProcessor
pipe.processor = AutoProcessor.from_pretrained(processor_config.path)
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 5.0,
# Image
edit_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
# Steps
max_sequence_length: int = 4096,
num_inference_steps: int = 30,
# Tiling
tiled: Optional[bool] = False,
tile_size: Optional[tuple[int, int]] = (30, 52),
tile_stride: Optional[tuple[int, int]] = (15, 26),
# Scheduler
shift: Optional[float] = 4.0,
# Progress bar
progress_bar_cmd=tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=shift)
# Parameters
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"cfg_scale": cfg_scale,
"edit_image": edit_image,
"denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "max_sequence_length": max_sequence_length,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
}
# Unit chain
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
latents = rearrange(inputs_shared["latents"], "b n c f h w -> (b n) c f h w")
image = self.vae.decode(latents, device=self.device)[0]
image = self.vae_output_to_image(image, pattern="C 1 H W")
self.load_models_to_device([])
return image
class JoyAIImageUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: "JoyAIImagePipeline", height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class JoyAIImageUnit_PromptEmbedder(PipelineUnit):
prompt_template_encode = {
'image':
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
'multiple_images':
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n{}<|im_start|>assistant\n",
'video':
"<|im_start|>system\n \\nDescribe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
}
prompt_template_encode_start_idx = {'image': 34, 'multiple_images': 34, 'video': 91}
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt", "positive": "positive"},
input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
input_params=("edit_image", "max_sequence_length"),
output_params=("prompt_embeds", "prompt_embeds_mask"),
onload_model_names=("joyai_image_text_encoder",),
)
def process(self, pipe: "JoyAIImagePipeline", prompt, positive, edit_image, max_sequence_length):
pipe.load_models_to_device(self.onload_model_names)
has_image = edit_image is not None
if has_image:
prompt_embeds, prompt_embeds_mask = self._encode_with_image(pipe, prompt, edit_image, max_sequence_length)
else:
prompt_embeds, prompt_embeds_mask = self._encode_text_only(pipe, prompt, max_sequence_length)
return {"prompt_embeds": prompt_embeds, "prompt_embeds_mask": prompt_embeds_mask}
def _encode_with_image(self, pipe, prompt, edit_image, max_sequence_length):
template = self.prompt_template_encode['multiple_images']
drop_idx = self.prompt_template_encode_start_idx['multiple_images']
image_tokens = '<image>\n'
prompt = f"<|im_start|>user\n{image_tokens}{prompt}<|im_end|>\n"
prompt = prompt.replace('<image>\n', '<|vision_start|><|image_pad|><|vision_end|>')
prompt = template.format(prompt)
inputs = pipe.processor(text=[prompt], images=[edit_image], padding=True, return_tensors="pt").to(pipe.device)
last_hidden_states = pipe.text_encoder(**inputs)
prompt_embeds = last_hidden_states[:, drop_idx:]
prompt_embeds_mask = inputs['attention_mask'][:, drop_idx:]
if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length:
prompt_embeds = prompt_embeds[:, -max_sequence_length:, :]
prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:]
return prompt_embeds, prompt_embeds_mask
def _encode_text_only(self, pipe, prompt, max_sequence_length):
# TODO: may support for text-only encoding in the future.
raise NotImplementedError("Text-only encoding is not implemented yet. Please provide edit_image for now.")
return prompt_embeds, encoder_attention_mask
class JoyAIImageUnit_EditImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "tiled", "tile_size", "tile_stride", "height", "width"),
output_params=("ref_latents", "num_items", "is_multi_item"),
onload_model_names=("wan_video_vae",),
)
def process(self, pipe: "JoyAIImagePipeline", edit_image, tiled, tile_size, tile_stride, height, width):
if edit_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
# Resize edit image to match target dimensions (from ShapeChecker) to ensure ref_latents matches latents
edit_image = edit_image.resize((width, height), Image.LANCZOS)
images = [pipe.preprocess_image(edit_image).transpose(0, 1)]
latents = pipe.vae.encode(images, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
ref_vae = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=1).to(device=pipe.device, dtype=pipe.torch_dtype)
return {"ref_latents": ref_vae, "edit_image": edit_image}
class JoyAIImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("seed", "height", "width", "rand_device"),
output_params=("noise"),
)
def process(self, pipe: "JoyAIImagePipeline", seed, height, width, rand_device):
latent_h = height // pipe.vae.upsampling_factor
latent_w = width // pipe.vae.upsampling_factor
shape = (1, 1, pipe.vae.z_dim, 1, latent_h, latent_w)
noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class JoyAIImageUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",),
)
def process(self, pipe: JoyAIImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
if input_image is None:
return {"latents": noise}
pipe.load_models_to_device(self.onload_model_names)
if isinstance(input_image, Image.Image):
input_image = [input_image]
input_image = [pipe.preprocess_image(img).transpose(0, 1) for img in input_image]
latents = pipe.vae.encode(input_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
input_latents = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=(len(input_image)))
return {"latents": noise, "input_latents": input_latents}
def model_fn_joyai_image(
dit,
latents,
timestep,
prompt_embeds,
prompt_embeds_mask,
ref_latents=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
img = torch.cat([ref_latents, latents], dim=1) if ref_latents is not None else latents
img = dit(
hidden_states=img,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
img = img[:, -latents.size(1):]
return img

View File

@@ -1,731 +0,0 @@
import torch, types
import numpy as np
from PIL import Image
from einops import repeat
from typing import Optional, Union
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional
from transformers import AutoImageProcessor, Gemma3Processor
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
from ..models.ltx2_dit import LTXModel
from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier, AudioProcessor
from ..models.ltx2_upsampler import LTX2LatentUpsampler
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
from ..utils.data.media_io_ltx2 import ltx2_preprocess
from ..utils.data.audio import convert_to_stereo
class LTX2AudioVideoPipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device,
torch_dtype=torch_dtype,
height_division_factor=32,
width_division_factor=32,
time_division_factor=8,
time_division_remainder=1,
)
self.scheduler = FlowMatchScheduler("LTX-2")
self.text_encoder: LTX2TextEncoder = None
self.tokenizer: LTXVGemmaTokenizer = None
self.processor: Gemma3Processor = None
self.text_encoder_post_modules: LTX2TextEncoderPostModules = None
self.dit: LTXModel = None
self.video_vae_encoder: LTX2VideoEncoder = None
self.video_vae_decoder: LTX2VideoDecoder = None
self.audio_vae_encoder: LTX2AudioEncoder = None
self.audio_vae_decoder: LTX2AudioDecoder = None
self.audio_vocoder: LTX2Vocoder = None
self.upsampler: LTX2LatentUpsampler = None
self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)
self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)
self.audio_processor: AudioProcessor = AudioProcessor()
self.in_iteration_models = ("dit",)
self.units = [
LTX2AudioVideoUnit_PipelineChecker(),
LTX2AudioVideoUnit_ShapeChecker(),
LTX2AudioVideoUnit_PromptEmbedder(),
LTX2AudioVideoUnit_NoiseInitializer(),
LTX2AudioVideoUnit_VideoRetakeEmbedder(),
LTX2AudioVideoUnit_AudioRetakeEmbedder(),
LTX2AudioVideoUnit_InputAudioEmbedder(),
LTX2AudioVideoUnit_InputVideoEmbedder(),
LTX2AudioVideoUnit_InputImagesEmbedder(),
LTX2AudioVideoUnit_InContextVideoEmbedder(),
]
self.stage2_units = [
LTX2AudioVideoUnit_SwitchStage2(),
LTX2AudioVideoUnit_NoiseInitializer(),
LTX2AudioVideoUnit_LatentsUpsampler(),
LTX2AudioVideoUnit_VideoRetakeEmbedder(),
LTX2AudioVideoUnit_AudioRetakeEmbedder(),
LTX2AudioVideoUnit_InputImagesEmbedder(),
LTX2AudioVideoUnit_SetScheduleStage2(),
]
self.model_fn = model_fn_ltx2
self.compilable_models = ["dit"]
self.default_negative_prompt = {
"LTX-2": (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
),
"LTX-2.3": (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
),
}
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
stage2_lora_config: Optional[ModelConfig] = None,
stage2_lora_strength: float = 0.8,
vram_limit: float = None,
):
# Initialize pipeline
pipe = LTX2AudioVideoPipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("ltx2_text_encoder")
tokenizer_config.download_if_necessary()
pipe.tokenizer = LTXVGemmaTokenizer(tokenizer_path=tokenizer_config.path)
image_processor = AutoImageProcessor.from_pretrained(tokenizer_config.path, local_files_only=True)
pipe.processor = Gemma3Processor(image_processor=image_processor, tokenizer=pipe.tokenizer.tokenizer)
pipe.text_encoder_post_modules = model_pool.fetch_model("ltx2_text_encoder_post_modules")
pipe.dit = model_pool.fetch_model("ltx2_dit")
pipe.video_vae_encoder = model_pool.fetch_model("ltx2_video_vae_encoder")
pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder")
pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
# Stage 2
if stage2_lora_config is not None:
pipe.stage2_lora_config = stage2_lora_config
pipe.stage2_lora_strength = stage2_lora_strength
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
def denoise_stage(self, inputs_shared, inputs_posi, inputs_nega, units, cfg_scale=1.0, progress_bar_cmd=tqdm, skip_stage=False):
if skip_stage:
return inputs_shared, inputs_posi, inputs_nega
for unit in units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video,
inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio,
inpaint_mask=inputs_shared.get("denoise_mask_audio", None), input_latents=inputs_shared.get("input_latents_audio", None), **inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: Optional[str] = "",
denoising_strength: float = 1.0,
# Image-to-video
input_images: Optional[list[Image.Image]] = None,
input_images_indexes: Optional[list[int]] = [0],
input_images_strength: Optional[float] = 1.0,
# In-Context Video Control
in_context_videos: Optional[list[list[Image.Image]]] = None,
in_context_downsample_factor: Optional[int] = 2,
# Video-to-video
retake_video: Optional[list[Image.Image]] = None,
retake_video_regions: Optional[list[tuple[float, float]]] = None,
# Audio-to-video
retake_audio: Optional[torch.Tensor] = None,
audio_sample_rate: Optional[int] = 48000,
retake_audio_regions: Optional[list[tuple[float, float]]] = None,
# Randomness
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
# Shape
height: Optional[int] = 512,
width: Optional[int] = 768,
num_frames: Optional[int] = 121,
frame_rate: Optional[int] = 24,
# Classifier-free guidance
cfg_scale: Optional[float] = 3.0,
# Scheduler
num_inference_steps: Optional[int] = 30,
# VAE tiling
tiled: Optional[bool] = True,
tile_size_in_pixels: Optional[int] = 512,
tile_overlap_in_pixels: Optional[int] = 128,
tile_size_in_frames: Optional[int] = 128,
tile_overlap_in_frames: Optional[int] = 24,
# Special Pipelines
use_two_stage_pipeline: Optional[bool] = False,
stage2_spatial_upsample_factor: Optional[int] = 2,
clear_lora_before_state_two: Optional[bool] = False,
use_distilled_pipeline: Optional[bool] = False,
# progress_bar
progress_bar_cmd=tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, special_case="ditilled_stage1" if use_distilled_pipeline else None)
# Inputs
inputs_posi = {
"prompt": prompt,
}
inputs_nega = {
"negative_prompt": negative_prompt,
}
inputs_shared = {
"input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
"retake_video": retake_video, "retake_video_regions": retake_video_regions,
"retake_audio": (retake_audio, audio_sample_rate) if retake_audio is not None else None, "retake_audio_regions": retake_audio_regions,
"in_context_videos": in_context_videos, "in_context_downsample_factor": in_context_downsample_factor,
"seed": seed, "rand_device": rand_device,
"height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate,
"cfg_scale": cfg_scale,
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, "clear_lora_before_state_two": clear_lora_before_state_two, "stage2_spatial_upsample_factor": stage2_spatial_upsample_factor,
"video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
}
# Stage 1
inputs_shared, inputs_posi, inputs_nega = self.denoise_stage(inputs_shared, inputs_posi, inputs_nega, self.units, cfg_scale, progress_bar_cmd)
# Stage 2
inputs_shared, inputs_posi, inputs_nega = self.denoise_stage(inputs_shared, inputs_posi, inputs_nega, self.stage2_units, 1.0, progress_bar_cmd, not inputs_shared["use_two_stage_pipeline"])
# Decode
self.load_models_to_device(['video_vae_decoder'])
video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels, tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
video = self.vae_output_to_video(video)
self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
decoded_audio = self.audio_vocoder(decoded_audio)
decoded_audio = self.output_audio_format_check(decoded_audio)
return video, decoded_audio
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
input_params=("use_distilled_pipeline", "use_two_stage_pipeline"),
output_params=("use_two_stage_pipeline", "cfg_scale")
)
def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
if inputs_shared.get("use_distilled_pipeline", False):
inputs_shared["use_two_stage_pipeline"] = True
inputs_shared["cfg_scale"] = 1.0
print(f"Distilled pipeline requested, setting use_two_stage_pipeline to True, disable CFG by setting cfg_scale to 1.0.")
if inputs_shared.get("use_two_stage_pipeline", False):
# distill pipeline also uses two-stage, but it does not needs lora
if not inputs_shared.get("use_distilled_pipeline", False):
if not (hasattr(pipe, "stage2_lora_config") and pipe.stage2_lora_config is not None):
raise ValueError("Two-stage pipeline requested, but stage2_lora_config is not set in the pipeline.")
if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
return inputs_shared, inputs_posi, inputs_nega
class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
"""
For two-stage pipelines, the resolution must be divisible by 64.
For one-stage pipelines, the resolution must be divisible by 32.
This unit set height and width to stage 1 resolution, and stage_2_width and stage_2_height.
"""
def __init__(self):
super().__init__(
input_params=("height", "width", "num_frames", "use_two_stage_pipeline", "stage2_spatial_upsample_factor"),
output_params=("height", "width", "num_frames", "stage_2_height", "stage_2_width"),
)
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False, stage2_spatial_upsample_factor=2):
if use_two_stage_pipeline:
height, width = height // stage2_spatial_upsample_factor, width // stage2_spatial_upsample_factor
height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
stage_2_height, stage_2_width = int(height * stage2_spatial_upsample_factor), int(width * stage2_spatial_upsample_factor)
else:
stage_2_height, stage_2_width = None, None
height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
return {"height": height, "width": width, "num_frames": num_frames, "stage_2_height": stage_2_height, "stage_2_width": stage_2_width}
class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("video_context", "audio_context"),
onload_model_names=("text_encoder", "text_encoder_post_modules"),
)
def _preprocess_text(
self,
pipe,
text: str,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device)
attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device)
outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
return outputs.hidden_states, attention_mask
def encode_prompt(self, pipe, text, padding_side="left"):
hidden_states, attention_mask = self._preprocess_text(pipe, text)
video_encoding, audio_encoding, attention_mask = pipe.text_encoder_post_modules.process_hidden_states(
hidden_states, attention_mask, padding_side)
return video_encoding, audio_encoding, attention_mask
def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
pipe.load_models_to_device(self.onload_model_names)
video_context, audio_context, _ = self.encode_prompt(pipe, prompt)
return {"video_context": video_context, "audio_context": audio_context}
class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate"),
output_params=("video_noise", "audio_noise", "video_positions", "audio_positions", "video_latent_shape", "audio_latent_shape")
)
def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=128)
video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
video_positions = video_positions.to(pipe.torch_dtype)
audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape)
audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
return {
"video_noise": video_noise,
"audio_noise": audio_noise,
"video_positions": video_positions,
"audio_positions": audio_positions,
"video_latent_shape": video_latent_shape,
"audio_latent_shape": audio_latent_shape
}
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_video", "video_noise", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels"),
output_params=("video_latents", "input_latents"),
onload_model_names=("video_vae_encoder")
)
def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
if input_video is None or not pipe.scheduler.training:
return {"video_latents": video_noise}
else:
pipe.load_models_to_device(self.onload_model_names)
input_video = pipe.preprocess_video(input_video)
input_latents = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"video_latents": input_latents, "input_latents": input_latents}
class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_audio", "audio_noise"),
output_params=("audio_latents", "audio_input_latents", "audio_positions", "audio_latent_shape"),
onload_model_names=("audio_vae_encoder",)
)
def process(self, pipe: LTX2AudioVideoPipeline, input_audio, audio_noise):
if input_audio is None or not pipe.scheduler.training:
return {"audio_latents": audio_noise}
else:
input_audio, sample_rate = input_audio
input_audio = convert_to_stereo(input_audio)
pipe.load_models_to_device(self.onload_model_names)
input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype)
audio_input_latents = pipe.audio_vae_encoder(input_audio)
audio_latent_shape = AudioLatentShape.from_torch_shape(audio_input_latents.shape)
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
return {"audio_latents": audio_input_latents, "audio_input_latents": audio_input_latents, "audio_positions": audio_positions, "audio_latent_shape": audio_latent_shape}
class LTX2AudioVideoUnit_VideoRetakeEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("retake_video", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "video_positions", "retake_video_regions"),
output_params=("input_latents_video", "denoise_mask_video"),
onload_model_names=("video_vae_encoder")
)
def process(self, pipe: LTX2AudioVideoPipeline, retake_video, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_positions, retake_video_regions=None):
if retake_video is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
resized_video = [frame.resize((width, height)) for frame in retake_video]
input_video = pipe.preprocess_video(resized_video)
input_latents_video = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
b, c, f, h, w = input_latents_video.shape
denoise_mask_video = torch.zeros((b, 1, f, h, w), device=input_latents_video.device, dtype=input_latents_video.dtype)
if retake_video_regions is not None and len(retake_video_regions) > 0:
for start_time, end_time in retake_video_regions:
t_start, t_end = video_positions[0, 0].unbind(dim=-1)
in_region = (t_end >= start_time) & (t_start <= end_time)
in_region = pipe.video_patchifier.unpatchify_video(in_region.unsqueeze(0).unsqueeze(-1), f, h, w)
denoise_mask_video = torch.where(in_region, torch.ones_like(denoise_mask_video), denoise_mask_video)
return {"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video}
class LTX2AudioVideoUnit_AudioRetakeEmbedder(PipelineUnit):
"""
Functionality of audio2video, audio retaking.
"""
def __init__(self):
super().__init__(
input_params=("retake_audio", "seed", "rand_device", "retake_audio_regions"),
output_params=("input_latents_audio", "audio_noise", "audio_positions", "audio_latent_shape", "denoise_mask_audio"),
onload_model_names=("audio_vae_encoder",)
)
def process(self, pipe: LTX2AudioVideoPipeline, retake_audio, seed, rand_device, retake_audio_regions=None):
if retake_audio is None:
return {}
else:
input_audio, sample_rate = retake_audio
input_audio = convert_to_stereo(input_audio)
pipe.load_models_to_device(self.onload_model_names)
input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype, device=pipe.device)
input_latents_audio = pipe.audio_vae_encoder(input_audio)
audio_latent_shape = AudioLatentShape.from_torch_shape(input_latents_audio.shape)
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
# Regenerate noise for the new shape if retake_audio is provided, to avoid shape mismatch.
audio_noise = pipe.generate_noise(input_latents_audio.shape, seed=seed, rand_device=rand_device)
b, c, t, f = input_latents_audio.shape
denoise_mask_audio = torch.zeros((b, 1, t, 1), device=input_latents_audio.device, dtype=input_latents_audio.dtype)
if retake_audio_regions is not None and len(retake_audio_regions) > 0:
for start_time, end_time in retake_audio_regions:
t_start, t_end = audio_positions[:, 0, :, 0], audio_positions[:, 0, :, 1]
in_region = (t_end >= start_time) & (t_start <= end_time)
in_region = pipe.audio_patchifier.unpatchify_audio(in_region.unsqueeze(-1), 1, 1)
denoise_mask_audio = torch.where(in_region, torch.ones_like(denoise_mask_audio), denoise_mask_audio)
return {
"input_latents_audio": input_latents_audio,
"denoise_mask_audio": denoise_mask_audio,
"audio_noise": audio_noise,
"audio_positions": audio_positions,
"audio_latent_shape": audio_latent_shape,
}
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "frame_rate", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "input_latents_video", "denoise_mask_video"),
output_params=("denoise_mask_video", "input_latents_video", "ref_frames_latents", "ref_frames_positions"),
onload_model_names=("video_vae_encoder")
)
def get_image_latent(self, pipe, input_image, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
image = ltx2_preprocess(np.array(input_image.resize((width, height))))
image = torch.Tensor(np.array(image, dtype=np.float32)).to(dtype=pipe.torch_dtype, device=pipe.device)
image = image / 127.5 - 1.0
image = repeat(image, f"H W C -> B C F H W", B=1, F=1)
latents = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
return latents
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, input_latents_video=None, denoise_mask_video=None):
b, _, f, h, w = latents.shape
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if denoise_mask_video is None else denoise_mask_video
input_latents_video = torch.zeros_like(latents) if input_latents_video is None else input_latents_video
for idx, input_latent in zip(input_indexes, input_latents):
idx = min(max(1 + (idx-1) // 8, 0), f - 1)
input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
input_latents_video[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
return input_latents_video, denoise_mask
def process(
self,
pipe: LTX2AudioVideoPipeline,
video_latents,
input_images,
height,
width,
frame_rate,
tiled,
tile_size_in_pixels,
tile_overlap_in_pixels,
input_images_indexes=[0],
input_images_strength=1.0,
input_latents_video=None,
denoise_mask_video=None,
):
if input_images is None or len(input_images) == 0:
return {}
else:
if len(input_images_indexes) != len(set(input_images_indexes)):
raise ValueError("Input images must have unique indexes.")
pipe.load_models_to_device(self.onload_model_names)
frame_conditions = {"input_latents_video": None, "denoise_mask_video": None, "ref_frames_latents": [], "ref_frames_positions": []}
for img, index in zip(input_images, input_images_indexes):
latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels)
# first_frame by replacing latents
if index == 0:
input_latents_video, denoise_mask_video = self.apply_input_images_to_latents(
video_latents, [latents], [0], input_images_strength, input_latents_video, denoise_mask_video)
frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video})
# other frames by adding reference latents
else:
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(latents.shape), device=pipe.device)
video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, False).float()
video_positions[:, 0, ...] = (video_positions[:, 0, ...] + index) / frame_rate
video_positions = video_positions.to(pipe.torch_dtype)
frame_conditions["ref_frames_latents"].append(latents)
frame_conditions["ref_frames_positions"].append(video_positions)
if len(frame_conditions["ref_frames_latents"]) == 0:
frame_conditions.update({"ref_frames_latents": None, "ref_frames_positions": None})
return frame_conditions
class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("in_context_videos", "height", "width", "num_frames", "frame_rate", "in_context_downsample_factor", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels"),
output_params=("in_context_video_latents", "in_context_video_positions"),
onload_model_names=("video_vae_encoder")
)
def check_in_context_video(self, pipe, in_context_video, height, width, num_frames, in_context_downsample_factor):
if in_context_video is None or len(in_context_video) == 0:
raise ValueError("In-context video is None or empty.")
in_context_video = in_context_video[:num_frames]
expected_height = height // in_context_downsample_factor
expected_width = width // in_context_downsample_factor
current_h, current_w, current_f = in_context_video[0].size[1], in_context_video[0].size[0], len(in_context_video)
h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f, verbose=0)
if current_h != h or current_w != w:
in_context_video = [img.resize((w, h)) for img in in_context_video]
if current_f != f:
# pad black frames at the end
in_context_video = in_context_video + [Image.new("RGB", (w, h), (0, 0, 0))] * (f - current_f)
return in_context_video
def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height, width, num_frames, frame_rate, in_context_downsample_factor, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
if in_context_videos is None or len(in_context_videos) == 0:
return {}
else:
pipe.load_models_to_device(self.onload_model_names)
latents, positions = [], []
for in_context_video in in_context_videos:
in_context_video = self.check_in_context_video(pipe, in_context_video, height, width, num_frames, in_context_downsample_factor)
in_context_video = pipe.preprocess_video(in_context_video)
in_context_latents = pipe.video_vae_encoder.encode(in_context_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(in_context_latents.shape), device=pipe.device)
video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
video_positions[:, 1, ...] *= in_context_downsample_factor # height axis
video_positions[:, 2, ...] *= in_context_downsample_factor # width axis
video_positions = video_positions.to(pipe.torch_dtype)
latents.append(in_context_latents)
positions.append(video_positions)
latents = torch.cat(latents, dim=1)
positions = torch.cat(positions, dim=1)
return {"in_context_video_latents": latents, "in_context_video_positions": positions}
class LTX2AudioVideoUnit_SwitchStage2(PipelineUnit):
"""
1. switch height and width to stage 2 resolution
2. clear in_context_video_latents and in_context_video_positions
3. switch stage 2 lora model
"""
def __init__(self):
super().__init__(
input_params=("stage_2_height", "stage_2_width", "clear_lora_before_state_two", "use_distilled_pipeline"),
output_params=("height", "width", "in_context_video_latents", "in_context_video_positions"),
)
def process(self, pipe: LTX2AudioVideoPipeline, stage_2_height, stage_2_width, clear_lora_before_state_two, use_distilled_pipeline):
stage2_params = {}
stage2_params.update({"height": stage_2_height, "width": stage_2_width})
stage2_params.update({"in_context_video_latents": None, "in_context_video_positions": None})
stage2_params.update({"input_latents_video": None, "denoise_mask_video": None})
if clear_lora_before_state_two:
pipe.clear_lora()
if not use_distilled_pipeline:
pipe.load_lora(pipe.dit, pipe.stage2_lora_config, alpha=pipe.stage2_lora_strength, state_dict=pipe.stage2_lora_config.state_dict)
return stage2_params
class LTX2AudioVideoUnit_SetScheduleStage2(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("video_latents", "video_noise", "audio_latents", "audio_noise"),
output_params=("video_latents", "audio_latents"),
)
def process(self, pipe: LTX2AudioVideoPipeline, video_latents, video_noise, audio_latents, audio_noise):
pipe.scheduler.set_timesteps(special_case="stage2")
video_latents = pipe.scheduler.add_noise(video_latents, video_noise, pipe.scheduler.timesteps[0])
audio_latents = pipe.scheduler.add_noise(audio_latents, audio_noise, pipe.scheduler.timesteps[0])
return {"video_latents": video_latents, "audio_latents": audio_latents}
class LTX2AudioVideoUnit_LatentsUpsampler(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("video_latents",),
output_params=("video_latents",),
onload_model_names=("upsampler",),
)
def process(self, pipe: LTX2AudioVideoPipeline, video_latents):
if video_latents is None or pipe.upsampler is None:
raise ValueError("No upsampler or no video latents before stage 2.")
else:
pipe.load_models_to_device(self.onload_model_names)
video_latents = pipe.video_vae_encoder.per_channel_statistics.un_normalize(video_latents)
video_latents = pipe.upsampler(video_latents)
video_latents = pipe.video_vae_encoder.per_channel_statistics.normalize(video_latents)
return {"video_latents": video_latents}
def model_fn_ltx2(
dit: LTXModel,
video_latents=None,
video_context=None,
video_positions=None,
video_patchifier=None,
audio_latents=None,
audio_context=None,
audio_positions=None,
audio_patchifier=None,
timestep=None,
# First Frame Conditioning
input_latents_video=None,
denoise_mask_video=None,
# Other Frames Conditioning
ref_frames_latents=None,
ref_frames_positions=None,
# In-Context Conditioning
in_context_video_latents=None,
in_context_video_positions=None,
# Audio Inputs
input_latents_audio=None,
denoise_mask_audio=None,
# Gradient Checkpointing
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
timestep = timestep.float() / 1000.
# patchify
b, c_v, f, h, w = video_latents.shape
video_latents = video_patchifier.patchify(video_latents)
seq_len_video = video_latents.shape[1]
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
# Frist frame conditioning by replacing the video latents
if input_latents_video is not None:
denoise_mask_video = video_patchifier.patchify(denoise_mask_video)
video_latents = video_latents * denoise_mask_video + video_patchifier.patchify(input_latents_video) * (1.0 - denoise_mask_video)
video_timesteps = denoise_mask_video * video_timesteps
# Reference conditioning by appending the reference video or frame latents
total_ref_latents = ref_frames_latents if ref_frames_latents is not None else []
total_ref_positions = ref_frames_positions if ref_frames_positions is not None else []
total_ref_latents += [in_context_video_latents] if in_context_video_latents is not None else []
total_ref_positions += [in_context_video_positions] if in_context_video_positions is not None else []
if len(total_ref_latents) > 0:
for ref_frames_latent, ref_frames_position in zip(total_ref_latents, total_ref_positions):
ref_frames_latent = video_patchifier.patchify(ref_frames_latent)
ref_frames_timestep = timestep.repeat(1, ref_frames_latent.shape[1], 1) * 0.
video_latents = torch.cat([video_latents, ref_frames_latent], dim=1)
video_positions = torch.cat([video_positions, ref_frames_position], dim=2)
video_timesteps = torch.cat([video_timesteps, ref_frames_timestep], dim=1)
if audio_latents is not None:
_, c_a, _, mel_bins = audio_latents.shape
audio_latents = audio_patchifier.patchify(audio_latents)
audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
else:
audio_timesteps = None
if input_latents_audio is not None:
denoise_mask_audio = audio_patchifier.patchify(denoise_mask_audio)
audio_latents = audio_latents * denoise_mask_audio + audio_patchifier.patchify(input_latents_audio) * (1.0 - denoise_mask_audio)
audio_timesteps = denoise_mask_audio * audio_timesteps
vx, ax = dit(
video_latents=video_latents,
video_positions=video_positions,
video_context=video_context,
video_timesteps=video_timesteps,
audio_latents=audio_latents,
audio_positions=audio_positions,
audio_context=audio_context,
audio_timesteps=audio_timesteps,
sigma=timestep,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
vx = vx[:, :seq_len_video, ...]
# unpatchify
vx = video_patchifier.unpatchify_video(vx, f, h, w)
ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins) if ax is not None else None
return vx, ax

View File

@@ -1,461 +0,0 @@
import sys
import torch, types
from PIL import Image
from typing import Optional, Union
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d, set_to_torch_norm
from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer
from ..models.wan_video_vae import WanVideoVAE
from ..models.mova_audio_dit import MovaAudioDit
from ..models.mova_audio_vae import DacVAE
from ..models.mova_dual_tower_bridge import DualTowerConditionalBridge
from ..utils.data.audio import convert_to_mono, resample_waveform
class MovaAudioVideoPipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
)
self.scheduler = FlowMatchScheduler("Wan")
self.tokenizer: HuggingfaceTokenizer = None
self.text_encoder: WanTextEncoder = None
self.video_dit: WanModel = None # high noise model
self.video_dit2: WanModel = None # low noise model
self.audio_dit: MovaAudioDit = None
self.dual_tower_bridge: DualTowerConditionalBridge = None
self.video_vae: WanVideoVAE = None
self.audio_vae: DacVAE = None
self.in_iteration_models = ("video_dit", "audio_dit", "dual_tower_bridge")
self.in_iteration_models_2 = ("video_dit2", "audio_dit", "dual_tower_bridge")
self.units = [
MovaAudioVideoUnit_ShapeChecker(),
MovaAudioVideoUnit_NoiseInitializer(),
MovaAudioVideoUnit_InputVideoEmbedder(),
MovaAudioVideoUnit_InputAudioEmbedder(),
MovaAudioVideoUnit_PromptEmbedder(),
MovaAudioVideoUnit_ImageEmbedderVAE(),
MovaAudioVideoUnit_UnifiedSequenceParallel(),
]
self.model_fn = model_fn_mova_audio_video
self.compilable_models = ["video_dit", "video_dit2", "audio_dit"]
def enable_usp(self):
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward
for block in self.video_dit.blocks + self.audio_dit.blocks + self.video_dit2.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.sp_size = get_sequence_parallel_world_size()
self.use_unified_sequence_parallel = True
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
use_usp: bool = False,
vram_limit: float = None,
):
if use_usp:
from ..utils.xfuser import initialize_usp
initialize_usp(device)
import torch.distributed as dist
from ..core.device.npu_compatible_device import get_device_name
if dist.is_available() and dist.is_initialized():
device = get_device_name()
# Initialize pipeline
pipe = MovaAudioVideoPipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder")
dit = model_pool.fetch_model("wan_video_dit", index=2)
if isinstance(dit, list):
pipe.video_dit, pipe.video_dit2 = dit
else:
pipe.video_dit = dit
pipe.audio_dit = model_pool.fetch_model("mova_audio_dit")
pipe.dual_tower_bridge = model_pool.fetch_model("mova_dual_tower_bridge")
pipe.video_vae = model_pool.fetch_model("wan_video_vae")
pipe.audio_vae = model_pool.fetch_model("mova_audio_vae")
set_to_torch_norm([pipe.video_dit, pipe.audio_dit, pipe.dual_tower_bridge] + ([pipe.video_dit2] if pipe.video_dit2 is not None else []))
# Size division factor
if pipe.video_vae is not None:
pipe.height_division_factor = pipe.video_vae.upsampling_factor * 2
pipe.width_division_factor = pipe.video_vae.upsampling_factor * 2
# Initialize tokenizer and processor
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace')
# Unified Sequence Parallel
if use_usp: pipe.enable_usp()
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: Optional[str] = "",
# Image-to-video
input_image: Optional[Image.Image] = None,
# First-last-frame-to-video
end_image: Optional[Image.Image] = None,
# Video-to-video
denoising_strength: Optional[float] = 1.0,
# Randomness
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
# Shape
height: Optional[int] = 352,
width: Optional[int] = 640,
num_frames: Optional[int] = 81,
frame_rate: Optional[int] = 24,
# Classifier-free guidance
cfg_scale: Optional[float] = 5.0,
# Boundary
switch_DiT_boundary: Optional[float] = 0.9,
# Scheduler
num_inference_steps: Optional[int] = 50,
sigma_shift: Optional[float] = 5.0,
# VAE tiling
tiled: Optional[bool] = True,
tile_size: Optional[tuple[int, int]] = (30, 52),
tile_stride: Optional[tuple[int, int]] = (15, 26),
# progress_bar
progress_bar_cmd=tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
# Inputs
inputs_posi = {
"prompt": prompt,
}
inputs_nega = {
"negative_prompt": negative_prompt,
}
inputs_shared = {
"input_image": input_image,
"end_image": end_image,
"denoising_strength": denoising_strength,
"seed": seed, "rand_device": rand_device,
"height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate,
"cfg_scale": cfg_scale,
"sigma_shift": sigma_shift,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
# Switch DiT if necessary
if timestep.item() < switch_DiT_boundary * 1000 and self.video_dit2 is not None and not models["video_dit"] is self.video_dit2:
self.load_models_to_device(self.in_iteration_models_2)
models["video_dit"] = self.video_dit2
# Timestep
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
# Scheduler
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)
# Decode
self.load_models_to_device(['video_vae'])
video = self.video_vae.decode(inputs_shared["video_latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
video = self.vae_output_to_video(video)
self.load_models_to_device(["audio_vae"])
audio = self.audio_vae.decode(inputs_shared["audio_latents"])
audio = self.output_audio_format_check(audio)
self.load_models_to_device([])
return video, audio
class MovaAudioVideoUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "num_frames"),
output_params=("height", "width", "num_frames"),
)
def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames):
height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
return {"height": height, "width": width, "num_frames": num_frames}
class MovaAudioVideoUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate"),
output_params=("video_noise", "audio_noise")
)
def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate):
length = (num_frames - 1) // 4 + 1
video_shape = (1, pipe.video_vae.model.z_dim, length, height // pipe.video_vae.upsampling_factor, width // pipe.video_vae.upsampling_factor)
video_noise = pipe.generate_noise(video_shape, seed=seed, rand_device=rand_device)
audio_num_samples = (int(pipe.audio_vae.sample_rate * num_frames / frame_rate) - 1) // int(pipe.audio_vae.hop_length) + 1
audio_shape = (1, pipe.audio_vae.latent_dim, audio_num_samples)
audio_noise = pipe.generate_noise(audio_shape, seed=seed, rand_device=rand_device)
return {"video_noise": video_noise, "audio_noise": audio_noise}
class MovaAudioVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_video", "video_noise", "tiled", "tile_size", "tile_stride"),
output_params=("video_latents", "input_latents"),
onload_model_names=("video_vae",)
)
def process(self, pipe: MovaAudioVideoPipeline, input_video, video_noise, tiled, tile_size, tile_stride):
if input_video is None or not pipe.scheduler.training:
return {"video_latents": video_noise}
else:
pipe.load_models_to_device(self.onload_model_names)
input_video = pipe.preprocess_video(input_video)
input_latents = pipe.video_vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"input_latents": input_latents}
class MovaAudioVideoUnit_InputAudioEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_audio", "audio_noise"),
output_params=("audio_latents", "audio_input_latents"),
onload_model_names=("audio_vae",)
)
def process(self, pipe: MovaAudioVideoPipeline, input_audio, audio_noise):
if input_audio is None or not pipe.scheduler.training:
return {"audio_latents": audio_noise}
else:
pipe.load_models_to_device(self.onload_model_names)
input_audio, sample_rate = input_audio
input_audio = convert_to_mono(input_audio)
input_audio = resample_waveform(input_audio, sample_rate, pipe.audio_vae.sample_rate)
input_audio = pipe.audio_vae.preprocess(input_audio.unsqueeze(0), pipe.audio_vae.sample_rate)
z, _, _, _, _ = pipe.audio_vae.encode(input_audio)
return {"audio_input_latents": z.mode()}
class MovaAudioVideoUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("context",),
onload_model_names=("text_encoder",)
)
def encode_prompt(self, pipe: MovaAudioVideoPipeline, prompt):
ids, mask = pipe.tokenizer(
prompt,
padding="max_length",
max_length=512,
truncation=True,
add_special_tokens=True,
return_mask=True,
return_tensors="pt",
)
ids = ids.to(pipe.device)
mask = mask.to(pipe.device)
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_emb = pipe.text_encoder(ids, mask)
for i, v in enumerate(seq_lens):
prompt_emb[:, v:] = 0
return prompt_emb
def process(self, pipe: MovaAudioVideoPipeline, prompt) -> dict:
pipe.load_models_to_device(self.onload_model_names)
prompt_emb = self.encode_prompt(pipe, prompt)
return {"context": prompt_emb}
class MovaAudioVideoUnit_ImageEmbedderVAE(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
output_params=("y",),
onload_model_names=("video_vae",)
)
def process(self, pipe: MovaAudioVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
if input_image is None or not pipe.video_dit.require_vae_embedding:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
msk[:, 1:] = 0
if end_image is not None:
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
msk[:, -1:] = 1
else:
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
y = pipe.video_vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
y = torch.concat([msk, y])
y = y.unsqueeze(0)
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"y": y}
class MovaAudioVideoUnit_UnifiedSequenceParallel(PipelineUnit):
def __init__(self):
super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",))
def process(self, pipe: MovaAudioVideoPipeline):
if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel:
return {"use_unified_sequence_parallel": True}
return {"use_unified_sequence_parallel": False}
def model_fn_mova_audio_video(
video_dit: WanModel,
audio_dit: MovaAudioDit,
dual_tower_bridge: DualTowerConditionalBridge,
video_latents: torch.Tensor = None,
audio_latents: torch.Tensor = None,
timestep: torch.Tensor = None,
context: torch.Tensor = None,
y: Optional[torch.Tensor] = None,
frame_rate: Optional[int] = 24,
use_unified_sequence_parallel: bool = False,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
):
video_x, audio_x = video_latents, audio_latents
# First-Last Frame
if y is not None:
video_x = torch.cat([video_x, y], dim=1)
# Timestep
video_t = video_dit.time_embedding(sinusoidal_embedding_1d(video_dit.freq_dim, timestep))
video_t_mod = video_dit.time_projection(video_t).unflatten(1, (6, video_dit.dim))
audio_t = audio_dit.time_embedding(sinusoidal_embedding_1d(audio_dit.freq_dim, timestep))
audio_t_mod = audio_dit.time_projection(audio_t).unflatten(1, (6, audio_dit.dim))
# Context
video_context = video_dit.text_embedding(context)
audio_context = audio_dit.text_embedding(context)
# Patchify
video_x = video_dit.patch_embedding(video_x)
f_v, h, w = video_x.shape[2:]
video_x = rearrange(video_x, 'b c f h w -> b (f h w) c').contiguous()
seq_len_video = video_x.shape[1]
audio_x = audio_dit.patch_embedding(audio_x)
f_a = audio_x.shape[2]
audio_x = rearrange(audio_x, 'b c f -> b f c').contiguous()
seq_len_audio = audio_x.shape[1]
# Freqs
video_freqs = torch.cat([
video_dit.freqs[0][:f_v].view(f_v, 1, 1, -1).expand(f_v, h, w, -1),
video_dit.freqs[1][:h].view(1, h, 1, -1).expand(f_v, h, w, -1),
video_dit.freqs[2][:w].view(1, 1, w, -1).expand(f_v, h, w, -1)
], dim=-1).reshape(f_v * h * w, 1, -1).to(video_x.device)
audio_freqs = torch.cat([
audio_dit.freqs[0][:f_a].view(f_a, -1).expand(f_a, -1),
audio_dit.freqs[1][:f_a].view(f_a, -1).expand(f_a, -1),
audio_dit.freqs[2][:f_a].view(f_a, -1).expand(f_a, -1),
], dim=-1).reshape(f_a, 1, -1).to(audio_x.device)
video_rope, audio_rope = dual_tower_bridge.build_aligned_freqs(
video_fps=frame_rate,
grid_size=(f_v, h, w),
audio_steps=audio_x.shape[1],
device=video_x.device,
dtype=video_x.dtype,
)
# usp func
if use_unified_sequence_parallel:
from ..utils.xfuser import get_current_chunk, gather_all_chunks
else:
get_current_chunk = lambda x, dim=1: x
gather_all_chunks = lambda x, seq_len, dim=1: x
# Forward blocks
for block_id in range(len(audio_dit.blocks)):
if dual_tower_bridge.should_interact(block_id, "a2v"):
video_x, audio_x = dual_tower_bridge(
block_id,
video_x,
audio_x,
x_freqs=video_rope,
y_freqs=audio_rope,
condition_scale=1.0,
video_grid_size=(f_v, h, w),
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
video_x = get_current_chunk(video_x, dim=1)
video_x = gradient_checkpoint_forward(
video_dit.blocks[block_id],
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
video_x, video_context, video_t_mod, video_freqs
)
video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1)
audio_x = get_current_chunk(audio_x, dim=1)
audio_x = gradient_checkpoint_forward(
audio_dit.blocks[block_id],
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
audio_x, audio_context, audio_t_mod, audio_freqs
)
audio_x = gather_all_chunks(audio_x, seq_len=seq_len_audio, dim=1)
video_x = get_current_chunk(video_x, dim=1)
for block_id in range(len(audio_dit.blocks), len(video_dit.blocks)):
video_x = gradient_checkpoint_forward(
video_dit.blocks[block_id],
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
video_x, video_context, video_t_mod, video_freqs
)
video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1)
# Head
video_x = video_dit.head(video_x, video_t)
video_x = video_dit.unpatchify(video_x, (f_v, h, w))
audio_x = audio_dit.head(audio_x, audio_t)
audio_x = audio_dit.unpatchify(audio_x, (f_a,))
return video_x, audio_x

View File

@@ -56,7 +56,6 @@ class QwenImagePipeline(BasePipeline):
QwenImageUnit_BlockwiseControlNet(),
]
self.model_fn = model_fn_qwen_image
self.compilable_models = ["dit"]
@staticmethod
@@ -683,16 +682,14 @@ class QwenImageUnit_Image2LoRADecode(PipelineUnit):
class QwenImageUnit_ContextImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride", "layer_input_image"),
input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride"),
output_params=("context_latents",),
onload_model_names=("vae",)
)
def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride, layer_input_image=None):
def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride):
if context_image is None:
return {}
if layer_input_image is not None:
context_image = context_image.convert("RGBA")
pipe.load_models_to_device(self.onload_model_names)
context_image = pipe.preprocess_image(context_image.resize((width, height))).to(device=pipe.device, dtype=pipe.torch_dtype)
context_latents = pipe.vae.encode(context_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

View File

@@ -1,230 +0,0 @@
import torch
from PIL import Image
from tqdm import tqdm
from typing import Union
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion.ddim_scheduler import DDIMScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from transformers import AutoTokenizer, CLIPTextModel
from ..models.stable_diffusion_text_encoder import SDTextEncoder
from ..models.stable_diffusion_unet import UNet2DConditionModel
from ..models.stable_diffusion_vae import StableDiffusionVAE
class StableDiffusionPipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.float16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=8, width_division_factor=8,
)
self.scheduler = DDIMScheduler()
self.text_encoder: SDTextEncoder = None
self.unet: UNet2DConditionModel = None
self.vae: StableDiffusionVAE = None
self.tokenizer: AutoTokenizer = None
self.in_iteration_models = ("unet",)
self.units = [
SDUnit_ShapeChecker(),
SDUnit_PromptEmbedder(),
SDUnit_NoiseInitializer(),
SDUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_stable_diffusion
self.compilable_models = ["unet"]
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.float16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = None,
vram_limit: float = None,
):
pipe = StableDiffusionPipeline(device=device, torch_dtype=torch_dtype)
# Override vram_config to use the specified torch_dtype for all models
for mc in model_configs:
mc._vram_config_override = {
'onload_dtype': torch_dtype,
'computation_dtype': torch_dtype,
}
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
pipe.text_encoder = model_pool.fetch_model("stable_diffusion_text_encoder")
pipe.unet = model_pool.fetch_model("stable_diffusion_unet")
pipe.vae = model_pool.fetch_model("stable_diffusion_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 7.5,
height: int = 512,
width: int = 512,
seed: int = None,
rand_device: str = "cpu",
num_inference_steps: int = 50,
eta: float = 0.0,
guidance_rescale: float = 0.0,
progress_bar_cmd=tqdm,
):
# 1. Scheduler
self.scheduler.set_timesteps(
num_inference_steps, eta=eta,
)
# 2. Three-dict input preparation
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"cfg_scale": cfg_scale,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"guidance_rescale": guidance_rescale,
}
# 3. Unit chain execution
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
# 4. Denoise loop
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(
self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
)
# 5. VAE decode
self.load_models_to_device(['vae'])
latents = inputs_shared["latents"] / self.vae.scaling_factor
image = self.vae.decode(latents)
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class SDUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: StableDiffusionPipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class SDUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds",),
onload_model_names=("text_encoder",)
)
def encode_prompt(
self,
pipe: StableDiffusionPipeline,
prompt: str,
device: torch.device,
) -> torch.Tensor:
text_inputs = pipe.tokenizer(
prompt,
padding="max_length",
max_length=pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_embeds = pipe.text_encoder(text_input_ids)
# TextEncoder returns (last_hidden_state, hidden_states) or just last_hidden_state.
# last_hidden_state is the post-final-layer-norm output, matching diffusers encode_prompt.
if isinstance(prompt_embeds, tuple):
prompt_embeds = prompt_embeds[0]
return prompt_embeds
def process(self, pipe: StableDiffusionPipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
return {"prompt_embeds": prompt_embeds}
class SDUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: StableDiffusionPipeline, height, width, seed, rand_device):
noise = pipe.generate_noise(
(1, pipe.unet.in_channels, height // 8, width // 8),
seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
)
return {"noise": noise}
class SDUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",),
)
def process(self, pipe: StableDiffusionPipeline, input_image, noise):
if input_image is None:
return {"latents": noise}
pipe.load_models_to_device(self.onload_model_names)
input_tensor = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(input_tensor).sample() * pipe.vae.scaling_factor
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
if pipe.scheduler.training:
return {"latents": latents, "input_latents": input_latents}
else:
return {"latents": latents}
def model_fn_stable_diffusion(
unet: UNet2DConditionModel,
latents=None,
timestep=None,
prompt_embeds=None,
cross_attention_kwargs=None,
timestep_cond=None,
added_cond_kwargs=None,
**kwargs,
):
# SD timestep is already in 0-999 range, no scaling needed
noise_pred = unet(
latents,
timestep,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
timestep_cond=timestep_cond,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
return noise_pred

View File

@@ -1,331 +0,0 @@
import torch
from PIL import Image
from tqdm import tqdm
from typing import Union
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion.ddim_scheduler import DDIMScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from transformers import AutoTokenizer, CLIPTextModel
from ..models.stable_diffusion_text_encoder import SDTextEncoder
from ..models.stable_diffusion_xl_unet import SDXLUNet2DConditionModel
from ..models.stable_diffusion_xl_text_encoder import SDXLTextEncoder2
from ..models.stable_diffusion_vae import StableDiffusionVAE
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""Rescale noise_cfg based on guidance_rescale to prevent overexposure.
Based on Section 3.4 from "Common Diffusion Noise Schedules and Sample Steps are Flawed"
https://huggingface.co/papers/2305.08891
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
class StableDiffusionXLPipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=8, width_division_factor=8,
)
self.scheduler = DDIMScheduler()
self.text_encoder: SDTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
self.unet: SDXLUNet2DConditionModel = None
self.vae: StableDiffusionVAE = None
self.tokenizer: AutoTokenizer = None
self.tokenizer_2: AutoTokenizer = None
self.in_iteration_models = ("unet",)
self.units = [
SDXLUnit_ShapeChecker(),
SDXLUnit_PromptEmbedder(),
SDXLUnit_NoiseInitializer(),
SDXLUnit_InputImageEmbedder(),
SDXLUnit_AddTimeIdsComputer(),
]
self.model_fn = model_fn_stable_diffusion_xl
self.compilable_models = ["unet"]
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = None,
tokenizer_2_config: ModelConfig = None,
vram_limit: float = None,
):
pipe = StableDiffusionXLPipeline(device=device, torch_dtype=torch_dtype)
# Override vram_config to use the specified torch_dtype for all models
for mc in model_configs:
mc._vram_config_override = {
'onload_dtype': torch_dtype,
'computation_dtype': torch_dtype,
}
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
pipe.text_encoder = model_pool.fetch_model("stable_diffusion_text_encoder")
pipe.text_encoder_2 = model_pool.fetch_model("stable_diffusion_xl_text_encoder")
pipe.unet = model_pool.fetch_model("stable_diffusion_xl_unet")
pipe.vae = model_pool.fetch_model("stable_diffusion_xl_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
if tokenizer_2_config is not None:
tokenizer_2_config.download_if_necessary()
pipe.tokenizer_2 = AutoTokenizer.from_pretrained(tokenizer_2_config.path)
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 5.0,
height: int = 1024,
width: int = 1024,
seed: int = None,
rand_device: str = "cpu",
num_inference_steps: int = 50,
guidance_rescale: float = 0.0,
progress_bar_cmd=tqdm,
):
# 1. Scheduler
self.scheduler.set_timesteps(num_inference_steps)
# 2. Three-dict input preparation
inputs_posi = {
"prompt": prompt,
}
inputs_nega = {
"prompt": negative_prompt,
}
inputs_shared = {
"cfg_scale": cfg_scale,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"guidance_rescale": guidance_rescale,
"crops_coords_top_left": (0, 0),
}
# 3. Unit chain execution
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
# 4. Denoise loop
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
# Apply guidance_rescale
if guidance_rescale > 0.0:
# cfg_guided_model_fn already applied CFG, now apply rescale
# We need the text-only prediction for rescale
noise_pred_text = self.model_fn(
self.unet,
inputs_shared["latents"],
timestep,
inputs_posi["prompt_embeds"],
pooled_prompt_embeds=inputs_posi["pooled_prompt_embeds"],
add_time_ids=inputs_posi["add_time_ids"],
)
noise_pred = rescale_noise_cfg(
noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
)
inputs_shared["latents"] = self.step(
self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
)
# 6. VAE decode
self.load_models_to_device(['vae'])
latents = inputs_shared["latents"] / self.vae.scaling_factor
image = self.vae.decode(latents)
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class SDXLUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: StableDiffusionXLPipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class SDXLUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "prompt"},
output_params=("prompt_embeds", "pooled_prompt_embeds"),
onload_model_names=("text_encoder", "text_encoder_2")
)
def encode_prompt(
self,
pipe: StableDiffusionXLPipeline,
prompt: str,
device: torch.device,
) -> tuple:
"""Encode prompt using both text encoders (same prompt for both).
Returns (prompt_embeds, pooled_prompt_embeds):
- prompt_embeds: concat(encoder1_output, encoder2_output) -> (B, 77, 2048)
- pooled_prompt_embeds: encoder2 pooled output -> (B, 1280)
"""
# Text Encoder 1 (CLIP-L, 768-dim)
text_input_ids_1 = pipe.tokenizer(
prompt,
padding="max_length",
max_length=pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
).input_ids.to(device)
prompt_embeds_1 = pipe.text_encoder(text_input_ids_1)
if isinstance(prompt_embeds_1, tuple):
prompt_embeds_1 = prompt_embeds_1[0]
# Text Encoder 2 (CLIP-bigG, 1280-dim) — uses penultimate hidden states + pooled
text_input_ids_2 = pipe.tokenizer_2(
prompt,
padding="max_length",
max_length=pipe.tokenizer_2.model_max_length,
truncation=True,
return_tensors="pt",
).input_ids.to(device)
# SDXLTextEncoder2 forward returns (text_embeds/pooled, hidden_states_tuple)
pooled_prompt_embeds, hidden_states = pipe.text_encoder_2(text_input_ids_2, output_hidden_states=True)
# Use penultimate hidden state (same as diffusers: hidden_states[-2])
prompt_embeds_2 = hidden_states[-2]
# Concatenate both encoder outputs along feature dimension
prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)
return prompt_embeds, pooled_prompt_embeds
def process(self, pipe: StableDiffusionXLPipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
prompt_embeds, pooled_prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}
class SDXLUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: StableDiffusionXLPipeline, height, width, seed, rand_device):
noise = pipe.generate_noise(
(1, pipe.unet.in_channels, height // 8, width // 8),
seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
)
return {"noise": noise}
class SDXLUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",),
)
def process(self, pipe: StableDiffusionXLPipeline, input_image, noise):
if input_image is None:
return {"latents": noise}
pipe.load_models_to_device(self.onload_model_names)
input_tensor = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(input_tensor).sample() * pipe.vae.scaling_factor
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
if pipe.scheduler.training:
return {"latents": latents, "input_latents": input_latents}
else:
return {"latents": latents}
class SDXLUnit_AddTimeIdsComputer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("add_time_ids",),
)
def _get_add_time_ids(self, pipe, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim):
add_time_ids = list(original_size + crops_coords_top_left + target_size)
expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
addition_time_embed_dim = pipe.unet.add_time_proj.num_channels
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
f"but a vector of {passed_add_embed_dim} was created."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=pipe.device)
return add_time_ids
def process(self, pipe: StableDiffusionXLPipeline, height, width):
original_size = (height, width)
target_size = (height, width)
crops_coords_top_left = (0, 0)
text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
add_time_ids = self._get_add_time_ids(
pipe, original_size, crops_coords_top_left, target_size,
dtype=pipe.torch_dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
return {"add_time_ids": add_time_ids}
def model_fn_stable_diffusion_xl(
unet: SDXLUNet2DConditionModel,
latents=None,
timestep=None,
prompt_embeds=None,
pooled_prompt_embeds=None,
add_time_ids=None,
cross_attention_kwargs=None,
timestep_cond=None,
**kwargs,
):
"""SDXL model forward with added_cond_kwargs for micro-conditioning."""
added_cond_kwargs = {
"text_embeds": pooled_prompt_embeds,
"time_ids": add_time_ids,
}
noise_pred = unet(
latents,
timestep,
encoder_hidden_states=prompt_embeds,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
timestep_cond=timestep_cond,
return_dict=False,
)[0]
return noise_pred

View File

@@ -75,19 +75,15 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_TeaCache(),
WanVideoUnit_CfgMerger(),
WanVideoUnit_LongCatVideo(),
WanVideoUnit_WanToDance_ProcessInputs(),
WanVideoUnit_WanToDance_RefImageEmbedder(),
WanVideoUnit_WanToDance_ImageKeyframesEmbedder(),
]
self.post_units = [
WanVideoPostUnit_S2V(),
]
self.model_fn = model_fn_wan_video
self.compilable_models = ["dit", "dit2"]
def enable_usp(self):
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward, usp_vace_forward
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward
for block in self.dit.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
@@ -96,14 +92,6 @@ class WanVideoPipeline(BasePipeline):
for block in self.dit2.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
if self.vace is not None:
for block in self.vace.vace_blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.vace.forward = types.MethodType(usp_vace_forward, self.vace)
if self.vace2 is not None:
for block in self.vace2.vace_blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.vace2.forward = types.MethodType(usp_vace_forward, self.vace2)
self.sp_size = get_sequence_parallel_world_size()
self.use_unified_sequence_parallel = True
@@ -256,13 +244,6 @@ class WanVideoPipeline(BasePipeline):
# Teacache
tea_cache_l1_thresh: Optional[float] = None,
tea_cache_model_id: Optional[str] = "",
# WanToDance
wantodance_music_path: Optional[str] = None,
wantodance_reference_image: Optional[Image.Image] = None,
wantodance_fps: Optional[float] = 30,
wantodance_keyframes: Optional[list[Image.Image]] = None,
wantodance_keyframes_mask: Optional[list[int]] = None,
framewise_decoding: bool = False,
# progress_bar
progress_bar_cmd=tqdm,
output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
@@ -299,9 +280,6 @@ class WanVideoPipeline(BasePipeline):
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
"vap_video": vap_video,
"wantodance_music_path": wantodance_music_path, "wantodance_reference_image": wantodance_reference_image, "wantodance_fps": wantodance_fps,
"wantodance_keyframes": wantodance_keyframes, "wantodance_keyframes_mask": wantodance_keyframes_mask,
"framewise_decoding": framewise_decoding,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -347,10 +325,7 @@ class WanVideoPipeline(BasePipeline):
inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Decode
self.load_models_to_device(['vae'])
if framewise_decoding:
video = self.vae.decode_framewise(inputs_shared["latents"], device=self.device)
else:
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if output_type == "quantized":
video = self.vae_output_to_video(video)
elif output_type == "floatpoint":
@@ -396,20 +371,17 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image", "framewise_decoding"),
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image, framewise_decoding):
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image):
if input_video is None:
return {"latents": noise}
pipe.load_models_to_device(self.onload_model_names)
input_video = pipe.preprocess_video(input_video)
if framewise_decoding:
input_latents = pipe.vae.encode_framewise(input_video, device=pipe.device)
else:
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
if vace_reference_image is not None:
if not isinstance(vace_reference_image, list):
vace_reference_image = [vace_reference_image]
@@ -1046,111 +1018,6 @@ class WanVideoUnit_LongCatVideo(PipelineUnit):
return {"longcat_latents": longcat_latents}
class WanVideoUnit_WanToDance_ProcessInputs(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
)
def get_music_base_feature(self, music_path, fps=30):
import librosa
hop_length = 512
sr = fps * hop_length
data, sr = librosa.load(music_path, sr=sr)
sr = 22050
envelope = librosa.onset.onset_strength(y=data, sr=sr)
mfcc = librosa.feature.mfcc(y=data, sr=sr, n_mfcc=20).T
chroma = librosa.feature.chroma_cens(
y=data, sr=sr, hop_length=hop_length, n_chroma=12
).T
peak_idxs = librosa.onset.onset_detect(
onset_envelope=envelope.flatten(), sr=sr, hop_length=hop_length
)
peak_onehot = np.zeros_like(envelope, dtype=np.float32)
peak_onehot[peak_idxs] = 1.0
start_bpm = librosa.beat.tempo(y=librosa.load(music_path)[0])[0]
_, beat_idxs = librosa.beat.beat_track(
onset_envelope=envelope,
sr=sr,
hop_length=hop_length,
start_bpm=start_bpm,
tightness=100,
)
beat_onehot = np.zeros_like(envelope, dtype=np.float32)
beat_onehot[beat_idxs] = 1.0
audio_feature = np.concatenate(
[envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],
axis=-1,
)
return torch.from_numpy(audio_feature)
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
if pipe.dit.wantodance_enable_global:
inputs_nega["skip_9th_layer"] = True
if inputs_shared.get("wantodance_music_path", None) is not None:
inputs_shared["music_feature"] = self.get_music_base_feature(inputs_shared["wantodance_music_path"]).to(dtype=pipe.torch_dtype, device=pipe.device)
return inputs_shared, inputs_posi, inputs_nega
class WanVideoUnit_WanToDance_RefImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("wantodance_reference_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
output_params=("wantodance_refimage_feature",),
onload_model_names=("image_encoder", "vae")
)
def process(self, pipe: WanVideoPipeline, wantodance_reference_image, num_frames, height, width, tiled, tile_size, tile_stride):
if wantodance_reference_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
if isinstance(wantodance_reference_image, list):
wantodance_reference_image = wantodance_reference_image[0]
image = pipe.preprocess_image(wantodance_reference_image.resize((width, height))).to(pipe.device) # B,C,H,W;B=1
refimage_feature = pipe.image_encoder.encode_image([image])
refimage_feature = refimage_feature.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"wantodance_refimage_feature": refimage_feature}
class WanVideoUnit_WanToDance_ImageKeyframesEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("wantodance_keyframes", "wantodance_keyframes_mask", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
output_params=("clip_feature", "y"),
onload_model_names=("image_encoder", "vae")
)
def process(self, pipe: WanVideoPipeline, wantodance_keyframes, wantodance_keyframes_mask, num_frames, height, width, tiled, tile_size, tile_stride):
if wantodance_keyframes is None:
return {}
wantodance_keyframes_mask = torch.tensor(wantodance_keyframes_mask)
pipe.load_models_to_device(self.onload_model_names)
images = []
for input_image in wantodance_keyframes:
input_image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
images.append(input_image)
clip_context = pipe.image_encoder.encode_image(images[:1]) # 取第一帧作为clip输入
msk = torch.zeros(1, num_frames, height//8, width//8, device=pipe.device)
msk[:, wantodance_keyframes_mask==1, :, :] = torch.ones(1, height//8, width//8, device=pipe.device) # set keyframes mask to 1
images = [image.transpose(0, 1) for image in images] # 3, num_frames, h, w
images = torch.concat(images, dim=1)
vae_input = images
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) # expand first frame mask, N to N + 3
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
y = torch.concat([msk, y])
y = y.unsqueeze(0)
clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"clip_feature": clip_context, "y": y}
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
@@ -1256,22 +1123,6 @@ class TemporalTiler_BCTHW:
return value
def wantodance_get_single_freqs(freqs, frame_num, fps):
total_frame = int(30.0 / (fps + 1e-6) * frame_num + 0.5)
interval_frame = 30.0 / (fps + 1e-6)
freqs_0 = freqs[:total_frame]
freqs_new = torch.zeros((frame_num, freqs_0.shape[1]), device=freqs_0.device, dtype=freqs_0.dtype)
freqs_new[0] = freqs_0[0]
freqs_new[-1] = freqs_0[total_frame - 1]
for i in range(1, frame_num-1):
pos = i * interval_frame
low_idx = int(pos)
high_idx = min(low_idx + 1, total_frame - 1)
weight_high = pos - low_idx
weight_low = 1.0 - weight_high
freqs_new[i] = freqs_0[low_idx] * weight_low + freqs_0[high_idx] * weight_high
return freqs_new
def model_fn_wan_video(
dit: WanModel,
@@ -1307,10 +1158,6 @@ def model_fn_wan_video(
use_gradient_checkpointing_offload: bool = False,
control_camera_latents_input = None,
fuse_vae_embedding_in_latents: bool = False,
wantodance_refimage_feature = None,
wantodance_fps: float = 30.0,
music_feature = None,
skip_9th_layer: bool = False,
**kwargs,
):
if sliding_window_size is not None and sliding_window_stride is not None:
@@ -1408,10 +1255,7 @@ def model_fn_wan_video(
context = torch.cat([clip_embdding, context], dim=1)
# Camera control
if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global and int(wantodance_fps + 0.5) != 30:
x = dit.patchify(x, control_camera_latents_input, enable_wantodance_global=True)
else:
x = dit.patchify(x, control_camera_latents_input)
x = dit.patchify(x, control_camera_latents_input)
# Animate
if pose_latents is not None and face_pixel_values is not None:
@@ -1459,61 +1303,14 @@ def model_fn_wan_video(
tea_cache_update = tea_cache.check(dit, x, t_mod)
else:
tea_cache_update = False
# WanToDance
if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global:
if wantodance_refimage_feature is not None:
refimage_feature_embedding = dit.img_emb_refimage(wantodance_refimage_feature)
context = torch.cat([refimage_feature_embedding, context], dim=1)
if (dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel) and int(wantodance_fps + 0.5) != 30:
freqs_0 = wantodance_get_single_freqs(dit.freqs[0], f, wantodance_fps)
freqs = torch.cat([
freqs_0.view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
if dit.wantodance_enable_global or dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel:
if use_unified_sequence_parallel:
length = int(float(music_feature.shape[0]) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size()
music_feature = music_feature[:length]
music_feature = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()]
if not dit.training:
dit.music_encoder.to(x.device, dtype=x.dtype) # only evaluation
music_feature = music_feature.to(x.device, dtype=x.dtype)
music_feature = dit.music_projection(music_feature)
music_feature = dit.music_encoder(music_feature)
if music_feature.dim() == 2:
music_feature = music_feature.unsqueeze(0)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
music_feature = get_sp_group().all_gather(music_feature, dim=1)
music_feature = music_feature.unsqueeze(1) # [1, 1, 149, 4800]
N = 149
M = 4800
music_feature = torch.nn.functional.interpolate(music_feature, size=(N, M), mode='bilinear')
music_feature = music_feature.squeeze(1) # shape: [1, 149, 4800]
if music_feature is not None:
if music_feature.dim() == 2:
music_feature = music_feature.unsqueeze(0)
music_feature = music_feature.to(x.device, dtype=x.dtype)
interp_mode = 'bilinear'
if interp_mode == 'bilinear':
frame_num = latents.shape[2] if len(latents.shape) == 5 else latents.shape[1] # 21
context_shape_end = context.shape[2] ## 14B 5120
music_feature = music_feature.unsqueeze(1) # shape: [1, 1, 149, 4800]
if use_unified_sequence_parallel:
N = int(float(frame_num * 8) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size()
else:
N = frame_num * 8
music_feature = torch.nn.functional.interpolate(music_feature, size=(N, context_shape_end), mode='bilinear')
music_feature = music_feature.squeeze(1) # shape: [1, N, context_shape_end]
if use_unified_sequence_parallel:
dit.merged_audio_emb = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
else:
dit.merged_audio_emb = music_feature
else:
dit.merged_audio_emb = music_feature
if vace_context is not None:
vace_hints = vace(
x, vace_context, context, t_mod, freqs,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
)
# blocks
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
@@ -1521,13 +1318,6 @@ def model_fn_wan_video(
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
x = chunks[get_sequence_parallel_rank()]
if vace_context is not None:
vace_hints = vace(
x, vace_context, context, t_mod, freqs,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
)
if tea_cache_update:
x = tea_cache.update(x)
else:
@@ -1536,12 +1326,8 @@ def model_fn_wan_video(
return vap(block, *inputs)
return custom_forward
# Block
for block_id, block in enumerate(dit.blocks):
if skip_9th_layer:
# This is only used in WanToDance
if block_id == 9:
continue
# Block
if vap is not None and block_id in vap.mot_layers_mapping:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
@@ -1570,23 +1356,18 @@ def model_fn_wan_video(
# VACE
if vace_context is not None and block_id in vace.vace_layers_mapping:
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
x = x + current_vace_hint * vace_scale
# Animate
if pose_latents is not None and face_pixel_values is not None:
x = animate_adapter.after_transformer_block(block_id, x, motion_vec)
# WanToDance
if hasattr(dit, "wantodance_enable_music_inject") and dit.wantodance_enable_music_inject:
x = dit.wantodance_after_transformer_block(block_id, x)
if tea_cache is not None:
tea_cache.store(x)
if hasattr(dit, "wantodance_enable_unimodel") and dit.wantodance_enable_unimodel and int(wantodance_fps + 0.5) != 30:
x = dit.head_global(x, t)
else:
x = dit.head(x, t)
x = dit.head(x, t)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)

View File

@@ -1,4 +1,4 @@
import torch, math, warnings
import torch, math
from PIL import Image
from typing import Union
from tqdm import tqdm
@@ -6,7 +6,7 @@ from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple, Iterable, Dict
from ..core.device.npu_compatible_device import get_device_type, IS_NPU_AVAILABLE
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.data.operators import ImageCropAndResize
@@ -54,7 +54,6 @@ class ZImagePipeline(BasePipeline):
ZImageUnit_PAIControlNet(),
]
self.model_fn = model_fn_z_image
self.compilable_models = ["dit"]
@staticmethod
@@ -64,7 +63,6 @@ class ZImagePipeline(BasePipeline):
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
vram_limit: float = None,
enable_npu_patch: bool = True,
):
# Initialize pipeline
pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype)
@@ -86,8 +84,6 @@ class ZImagePipeline(BasePipeline):
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
# NPU patch
apply_npu_patch(enable_npu_patch)
return pipe
@@ -300,7 +296,7 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
def process(self, pipe: ZImagePipeline, prompt, edit_image):
pipe.load_models_to_device(self.onload_model_names)
if hasattr(pipe, "dit") and pipe.dit is not None and pipe.dit.siglip_embedder is not None:
if hasattr(pipe, "dit") and pipe.dit.siglip_embedder is not None:
# Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods.
# We determine which encoding method to use based on the model architecture.
# If you are using two-stage split training,
@@ -671,19 +667,3 @@ def model_fn_z_image_turbo(
x = rearrange(x, "C B H W -> B C H W")
x = -x
return x
def apply_npu_patch(enable_npu_patch: bool=True):
if IS_NPU_AVAILABLE and enable_npu_patch:
from ..models.general_modules import RMSNorm
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
from ..models.z_image_dit import Attention
from ..core.npu_patch.npu_fused_operator import (
rms_norm_forward_npu,
rms_norm_forward_transformers_npu,
rotary_emb_Zimage_npu
)
warnings.warn("Replacing RMSNorm and Rope with NPU fusion operators to improve the performance of the model on NPU.Set enable_npu_patch=False to disable this feature.")
RMSNorm.forward = rms_norm_forward_npu
Qwen3RMSNorm.forward = rms_norm_forward_transformers_npu
Attention.apply_rotary_emb = rotary_emb_Zimage_npu

View File

@@ -116,7 +116,7 @@ class VideoData:
if self.height is not None and self.width is not None:
return self.height, self.width
else:
width, height = self.__getitem__(0).size
height, width, _ = self.__getitem__(0).shape
return height, width
def __getitem__(self, item):

View File

@@ -1,108 +0,0 @@
import torch
import torchaudio
def convert_to_mono(audio_tensor: torch.Tensor) -> torch.Tensor:
"""
Convert audio to mono by averaging channels.
Supports [C, T] or [B, C, T]. Output shape: [1, T] or [B, 1, T].
"""
return audio_tensor.mean(dim=-2, keepdim=True)
def convert_to_stereo(audio_tensor: torch.Tensor) -> torch.Tensor:
"""
Convert audio to stereo.
Supports [C, T] or [B, C, T]. Duplicate mono, keep stereo.
"""
if audio_tensor.size(-2) == 1:
return audio_tensor.repeat(1, 2, 1) if audio_tensor.dim() == 3 else audio_tensor.repeat(2, 1)
return audio_tensor
def resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor:
"""Resample waveform to target sample rate if needed."""
if source_rate == target_rate:
return waveform
resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
return resampled.to(dtype=waveform.dtype)
def read_audio_with_torchcodec(
path: str,
start_time: float = 0,
duration: float | None = None,
) -> tuple[torch.Tensor, int]:
"""
Read audio from file natively using torchcodec, with optional start time and duration.
Args:
path (str): The file path to the audio file.
start_time (float, optional): The start time in seconds to read from. Defaults to 0.
duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None.
Returns:
tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate.
The audio tensor shape is [C, T] where C is the number of channels and T is the number of audio frames.
"""
from torchcodec.decoders import AudioDecoder
decoder = AudioDecoder(path)
stop_seconds = None if duration is None else start_time + duration
waveform = decoder.get_samples_played_in_range(start_seconds=start_time, stop_seconds=stop_seconds).data
return waveform, decoder.metadata.sample_rate
def read_audio(
path: str,
start_time: float = 0,
duration: float | None = None,
resample: bool = False,
resample_rate: int = 48000,
backend: str = "torchcodec",
) -> tuple[torch.Tensor, int]:
"""
Read audio from file, with optional start time, duration, and resampling.
Args:
path (str): The file path to the audio file.
start_time (float, optional): The start time in seconds to read from. Defaults to 0.
duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None.
resample (bool, optional): Whether to resample the audio to a different sample rate. Defaults to False.
resample_rate (int, optional): The target sample rate for resampling if resample is True. Defaults to 48000.
backend (str, optional): The audio backend to use for reading. Defaults to "torchcodec".
Returns:
tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate.
The audio tensor shape is [C, T] where C is the number of channels and T is the number of audio frames.
"""
if backend == "torchcodec":
waveform, sample_rate = read_audio_with_torchcodec(path, start_time, duration)
else:
raise ValueError(f"Unsupported audio backend: {backend}")
if resample:
waveform = resample_waveform(waveform, sample_rate, resample_rate)
sample_rate = resample_rate
return waveform, sample_rate
def save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend: str = "torchcodec"):
"""
Save audio tensor to file.
Args:
waveform (torch.Tensor): The audio tensor to save. Shape can be [C, T] or [B, C, T].
sample_rate (int): The sample rate of the audio.
save_path (str): The file path to save the audio to.
backend (str, optional): The audio backend to use for saving. Defaults to "torchcodec".
"""
if waveform.dim() == 3:
waveform = waveform[0]
if backend == "torchcodec":
from torchcodec.encoders import AudioEncoder
encoder = AudioEncoder(waveform, sample_rate=sample_rate)
encoder.to_file(dest=save_path)
else:
raise ValueError(f"Unsupported audio backend: {backend}")

View File

@@ -1,134 +0,0 @@
import av
from fractions import Fraction
import torch
from PIL import Image
from tqdm import tqdm
from .audio import convert_to_stereo
def _resample_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
cc = audio_stream.codec_context
# Use the encoder's format/layout/rate as the *target*
target_format = cc.format or "fltp" # AAC → usually fltp
target_layout = cc.layout or "stereo"
target_rate = cc.sample_rate or frame_in.sample_rate
audio_resampler = av.audio.resampler.AudioResampler(
format=target_format,
layout=target_layout,
rate=target_rate,
)
audio_next_pts = 0
for rframe in audio_resampler.resample(frame_in):
if rframe.pts is None:
rframe.pts = audio_next_pts
audio_next_pts += rframe.samples
rframe.sample_rate = frame_in.sample_rate
container.mux(audio_stream.encode(rframe))
# flush audio encoder
for packet in audio_stream.encode():
container.mux(packet)
def _write_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
if samples.ndim == 1:
samples = samples.unsqueeze(0)
samples = convert_to_stereo(samples)
assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2"
samples = samples.T
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
if samples.dtype != torch.int16:
samples = torch.clip(samples, -1.0, 1.0)
samples = (samples * 32767.0).to(torch.int16)
frame_in = av.AudioFrame.from_ndarray(
samples.contiguous().reshape(1, -1).cpu().numpy(),
format="s16",
layout="stereo",
)
frame_in.sample_rate = audio_sample_rate
_resample_audio(container, audio_stream, frame_in)
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
"""
Prepare the audio stream for writing.
"""
audio_stream = container.add_stream("aac")
supported_sample_rates = audio_stream.codec_context.codec.audio_rates
if supported_sample_rates:
best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
if best_rate != audio_sample_rate:
print(f"Using closest supported audio sample rate: {best_rate}")
else:
best_rate = audio_sample_rate
audio_stream.codec_context.sample_rate = best_rate
audio_stream.codec_context.layout = "stereo"
audio_stream.codec_context.time_base = Fraction(1, best_rate)
return audio_stream
def write_video_audio(
video: list[Image.Image],
audio: torch.Tensor | None,
output_path: str,
fps: int = 24,
audio_sample_rate: int | None = None,
) -> None:
"""
Writes a sequence of images and an audio tensor to a video file.
This function utilizes PyAV (or a similar multimedia library) to encode a list of PIL images into a video stream
and multiplex a PyTorch tensor as the audio stream into the output container.
Args:
video (list[Image.Image]): A list of PIL Image objects representing the video frames.
The length of this list determines the total duration of the video based on the FPS.
audio (torch.Tensor | None): The audio data as a PyTorch tensor.
The shape is typically (channels, samples). If no audio is required, pass None.
channels can be 1 or 2. 1 for mono, 2 for stereo.
output_path (str): The file path (including extension) where the output video will be saved.
fps (int, optional): The frame rate (frames per second) for the video. Defaults to 24.
audio_sample_rate (int | None, optional): The sample rate (e.g., 44100, 48000) for the audio.
If the audio tensor is provided and this is None, the function attempts to infer the rate
based on the audio tensor's length and the video duration.
Raises:
ValueError: If an audio tensor is provided but the sample rate cannot be determined.
"""
duration = len(video) / fps
if audio_sample_rate is None:
audio_sample_rate = int(audio.shape[-1] / duration)
width, height = video[0].size
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
if audio is not None:
if audio_sample_rate is None:
raise ValueError("audio_sample_rate is required when audio is provided")
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
for frame in tqdm(video, total=len(video)):
frame = av.VideoFrame.from_image(frame)
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():
container.mux(packet)
if audio is not None:
_write_audio(container, audio_stream, audio, audio_sample_rate)
container.close()

View File

@@ -1,43 +0,0 @@
import av
import numpy as np
from io import BytesIO
from .audio_video import write_video_audio as write_video_audio_ltx2
def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None:
container = av.open(output_file, "w", format="mp4")
try:
stream = container.add_stream("libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"})
# Round to nearest multiple of 2 for compatibility with video codecs
height = image_array.shape[0] // 2 * 2
width = image_array.shape[1] // 2 * 2
image_array = image_array[:height, :width]
stream.height = height
stream.width = width
av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(format="yuv420p")
container.mux(stream.encode(av_frame))
container.mux(stream.encode())
finally:
container.close()
def decode_single_frame(video_file: str) -> np.array:
container = av.open(video_file)
try:
stream = next(s for s in container.streams if s.type == "video")
frame = next(container.decode(stream))
finally:
container.close()
return frame.to_ndarray(format="rgb24")
def ltx2_preprocess(image: np.array, crf: float = 33) -> np.array:
if crf == 0:
return image
with BytesIO() as output_file:
encode_single_frame(output_file, image, crf)
video_bytes = output_file.getvalue()
with BytesIO(video_bytes) as video_file:
image_array = decode_single_frame(video_file)
return image_array

View File

@@ -1,4 +1,4 @@
import torch, warnings
import torch
class GeneralLoRALoader:
@@ -26,11 +26,7 @@ class GeneralLoRALoader:
keys.pop(0)
keys.pop(-1)
target_name = ".".join(keys)
# Alpha: Deprecated but retained for compatibility.
key_alpha = key.replace(lora_B_key + ".weight", "alpha").replace(lora_B_key + ".default.weight", "alpha")
if key_alpha == key or key_alpha not in lora_state_dict:
key_alpha = None
lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key), key_alpha)
lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key))
return lora_name_dict
@@ -40,10 +36,6 @@ class GeneralLoRALoader:
for name in name_dict:
weight_up = state_dict[name_dict[name][0]]
weight_down = state_dict[name_dict[name][1]]
if name_dict[name][2] is not None:
warnings.warn("Alpha detected in the LoRA file. This may be a LoRA model not trained by DiffSynth-Studio. To ensure compatibility, the LoRA weights will be converted to weight * alpha / rank.")
alpha = state_dict[name_dict[name][2]] / weight_down.shape[0]
weight_down = weight_down * alpha
state_dict_[name + f".lora_B{suffix}"] = weight_up
state_dict_[name + f".lora_A{suffix}"] = weight_down
return state_dict_

View File

@@ -1 +0,0 @@
Please see `docs/en/Research_Tutorial/inference_time_scaling.md` or `docs/zh/Research_Tutorial/inference_time_scaling.md` for more details.

View File

@@ -1 +0,0 @@
from .ses import ses_search

View File

@@ -1,117 +0,0 @@
import torch
import pywt
import numpy as np
from tqdm import tqdm
def split_dwt(z_tensor_cpu, wavelet_name, dwt_level):
all_clow_np = []
all_chigh_list = []
z_tensor_cpu = z_tensor_cpu.cpu().float()
for i in range(z_tensor_cpu.shape[0]):
z_numpy_ch = z_tensor_cpu[i].numpy()
coeffs_ch = pywt.wavedec2(z_numpy_ch, wavelet_name, level=dwt_level, mode='symmetric', axes=(-2, -1))
clow_np = coeffs_ch[0]
chigh_list = coeffs_ch[1:]
all_clow_np.append(clow_np)
all_chigh_list.append(chigh_list)
all_clow_tensor = torch.from_numpy(np.stack(all_clow_np, axis=0))
return all_clow_tensor, all_chigh_list
def reconstruct_dwt(c_low_tensor_cpu, c_high_coeffs, wavelet_name, original_shape):
H_high, W_high = original_shape
c_low_tensor_cpu = c_low_tensor_cpu.cpu().float()
clow_np = c_low_tensor_cpu.numpy()
if clow_np.ndim == 4 and clow_np.shape[0] == 1:
clow_np = clow_np[0]
coeffs_combined = [clow_np] + c_high_coeffs
z_recon_np = pywt.waverec2(coeffs_combined, wavelet_name, mode='symmetric', axes=(-2, -1))
if z_recon_np.shape[-2] != H_high or z_recon_np.shape[-1] != W_high:
z_recon_np = z_recon_np[..., :H_high, :W_high]
z_recon_tensor = torch.from_numpy(z_recon_np)
if z_recon_tensor.ndim == 3:
z_recon_tensor = z_recon_tensor.unsqueeze(0)
return z_recon_tensor
def ses_search(
base_latents,
objective_reward_fn,
total_eval_budget=30,
popsize=10,
k_elites=5,
wavelet_name="db1",
dwt_level=4,
):
latent_h, latent_w = base_latents.shape[-2], base_latents.shape[-1]
c_low_init, c_high_fixed_batch = split_dwt(base_latents, wavelet_name, dwt_level)
c_high_fixed = c_high_fixed_batch[0]
c_low_shape = c_low_init.shape[1:]
mu = torch.zeros_like(c_low_init.view(-1).cpu())
sigma_sq = torch.ones_like(mu) * 1.0
best_overall = {"fitness": -float('inf'), "score": -float('inf'), "c_low": c_low_init[0]}
eval_count = 0
elite_db = []
n_generations = (total_eval_budget // popsize) + 5
pbar = tqdm(total=total_eval_budget, desc="[SES] Searching", unit="img")
for gen in range(n_generations):
if eval_count >= total_eval_budget: break
std = torch.sqrt(torch.clamp(sigma_sq, min=1e-9))
z_noise = torch.randn(popsize, mu.shape[0])
samples_flat = mu + z_noise * std
samples_reshaped = samples_flat.view(popsize, *c_low_shape)
batch_results = []
for i in range(popsize):
if eval_count >= total_eval_budget: break
c_low_sample = samples_reshaped[i].unsqueeze(0)
z_recon = reconstruct_dwt(c_low_sample, c_high_fixed, wavelet_name, (latent_h, latent_w))
z_recon = z_recon.to(base_latents.device, dtype=base_latents.dtype)
# img = pipeline_callback(z_recon)
# score = scorer.get_score(img, prompt)
score = objective_reward_fn(z_recon)
res = {
"score": score,
"c_low": c_low_sample.cpu()
}
batch_results.append(res)
if score > best_overall['score']:
best_overall = res
eval_count += 1
pbar.update(1)
if not batch_results: break
elite_db.extend(batch_results)
elite_db.sort(key=lambda x: x['score'], reverse=True)
elite_db = elite_db[:k_elites]
elites_flat = torch.stack([x['c_low'].view(-1) for x in elite_db])
mu_new = torch.mean(elites_flat, dim=0)
if len(elite_db) > 1:
sigma_sq_new = torch.var(elites_flat, dim=0, unbiased=True) + 1e-7
else:
sigma_sq_new = sigma_sq
mu = mu_new
sigma_sq = sigma_sq_new
pbar.close()
best_c_low = best_overall['c_low']
final_latents = reconstruct_dwt(best_c_low, c_high_fixed, wavelet_name, (latent_h, latent_w))
return final_latents.to(base_latents.device, dtype=base_latents.dtype)

View File

@@ -1,6 +0,0 @@
def AnimaDiTStateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:
value = state_dict[key]
new_state_dict[key.replace("net.", "")] = value
return new_state_dict

View File

@@ -1,21 +0,0 @@
def ErnieImageTextEncoderStateDictConverter(state_dict):
"""
Maps checkpoint keys from multimodal Mistral3Model format
to text-only Ministral3Model format.
Checkpoint keys (Mistral3Model):
language_model.model.layers.0.input_layernorm.weight
language_model.model.norm.weight
Model keys (ErnieImageTextEncoder → self.model = Ministral3Model):
model.layers.0.input_layernorm.weight
model.norm.weight
Mapping: language_model. → model.
"""
new_state_dict = {}
for key in state_dict:
if key.startswith("language_model.model."):
new_key = key.replace("language_model.model.", "model.", 1)
new_state_dict[new_key] = state_dict[key]
return new_state_dict

View File

@@ -1,20 +0,0 @@
def JoyAIImageTextEncoderStateDictConverter(state_dict):
"""Convert HuggingFace Qwen3VL checkpoint keys to DiffSynth wrapper keys.
Mapping (checkpoint -> wrapper):
- lm_head.weight -> model.lm_head.weight
- model.language_model.* -> model.model.language_model.*
- model.visual.* -> model.model.visual.*
"""
state_dict_ = {}
for key in state_dict:
if key == "lm_head.weight":
new_key = "model.lm_head.weight"
elif key.startswith("model.language_model."):
new_key = "model.model." + key[len("model."):]
elif key.startswith("model.visual."):
new_key = "model.model." + key[len("model."):]
else:
new_key = key
state_dict_[new_key] = state_dict[key]
return state_dict_

View File

@@ -1,32 +0,0 @@
def LTX2AudioEncoderStateDictConverter(state_dict):
# Not used
state_dict_ = {}
for name in state_dict:
if name.startswith("audio_vae.encoder."):
new_name = name.replace("audio_vae.encoder.", "")
state_dict_[new_name] = state_dict[name]
elif name.startswith("audio_vae.per_channel_statistics."):
new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.")
state_dict_[new_name] = state_dict[name]
return state_dict_
def LTX2AudioDecoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("audio_vae.decoder."):
new_name = name.replace("audio_vae.decoder.", "")
state_dict_[new_name] = state_dict[name]
elif name.startswith("audio_vae.per_channel_statistics."):
new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.")
state_dict_[new_name] = state_dict[name]
return state_dict_
def LTX2VocoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("vocoder."):
new_name = name[len("vocoder."):]
state_dict_[new_name] = state_dict[name]
return state_dict_

View File

@@ -1,9 +0,0 @@
def LTXModelStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("model.diffusion_model."):
new_name = name.replace("model.diffusion_model.", "")
if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."):
continue
state_dict_[new_name] = state_dict[name]
return state_dict_

View File

@@ -1,31 +0,0 @@
def LTX2TextEncoderStateDictConverter(state_dict):
state_dict_ = {}
for key in state_dict:
if key.startswith("language_model.model."):
new_key = key.replace("language_model.model.", "model.language_model.")
elif key.startswith("vision_tower."):
new_key = key.replace("vision_tower.", "model.vision_tower.")
elif key.startswith("multi_modal_projector."):
new_key = key.replace("multi_modal_projector.", "model.multi_modal_projector.")
elif key.startswith("language_model.lm_head."):
new_key = key.replace("language_model.lm_head.", "lm_head.")
else:
continue
state_dict_[new_key] = state_dict[key]
state_dict_["lm_head.weight"] = state_dict_.get("model.language_model.embed_tokens.weight")
return state_dict_
def LTX2TextEncoderPostModulesStateDictConverter(state_dict):
state_dict_ = {}
for key in state_dict:
if key.startswith("text_embedding_projection."):
new_key = key.replace("text_embedding_projection.", "feature_extractor_linear.")
elif key.startswith("model.diffusion_model.video_embeddings_connector."):
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "embeddings_connector.")
elif key.startswith("model.diffusion_model.audio_embeddings_connector."):
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "audio_embeddings_connector.")
else:
continue
state_dict_[new_key] = state_dict[key]
return state_dict_

View File

@@ -1,24 +0,0 @@
def LTX2VideoEncoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("vae.encoder."):
new_name = name.replace("vae.encoder.", "")
state_dict_[new_name] = state_dict[name]
elif name.startswith("vae.per_channel_statistics."):
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
if new_name not in ["per_channel_statistics.channel", "per_channel_statistics.mean-of-stds", "per_channel_statistics.mean-of-stds_over_std-of-means"]:
state_dict_[new_name] = state_dict[name]
return state_dict_
def LTX2VideoDecoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("vae.decoder."):
new_name = name.replace("vae.decoder.", "")
state_dict_[new_name] = state_dict[name]
elif name.startswith("vae.per_channel_statistics."):
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
if new_name not in ["per_channel_statistics.channel", "per_channel_statistics.mean-of-stds", "per_channel_statistics.mean-of-stds_over_std-of-means"]:
state_dict_[new_name] = state_dict[name]
return state_dict_

View File

@@ -1,7 +0,0 @@
def SDTextEncoderStateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:
if key.startswith("text_model.") and "position_ids" not in key:
new_key = "model." + key
new_state_dict[new_key] = state_dict[key]
return new_state_dict

View File

@@ -1,18 +0,0 @@
def SDVAEStateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:
if ".query." in key:
new_key = key.replace(".query.", ".to_q.")
new_state_dict[new_key] = state_dict[key]
elif ".key." in key:
new_key = key.replace(".key.", ".to_k.")
new_state_dict[new_key] = state_dict[key]
elif ".value." in key:
new_key = key.replace(".value.", ".to_v.")
new_state_dict[new_key] = state_dict[key]
elif ".proj_attn." in key:
new_key = key.replace(".proj_attn.", ".to_out.0.")
new_state_dict[new_key] = state_dict[key]
else:
new_state_dict[key] = state_dict[key]
return new_state_dict

View File

@@ -1,13 +0,0 @@
import torch
def SDXLTextEncoder2StateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:
if key == "text_projection.weight":
val = state_dict[key]
new_state_dict["model.text_projection.weight"] = val.float() if val.dtype == torch.float16 else val
elif key.startswith("text_model.") and "position_ids" not in key:
new_key = "model." + key
val = state_dict[key]
new_state_dict[new_key] = val.float() if val.dtype == torch.float16 else val
return new_state_dict

View File

@@ -1,3 +0,0 @@
def ZImageDiTStateDictConverter(state_dict):
state_dict_ = {name.replace("model.diffusion_model.", ""): state_dict[name] for name in state_dict}
return state_dict_

View File

@@ -1 +1 @@
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, usp_vace_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp

View File

@@ -1,13 +1,10 @@
import torch
from typing import Optional
from einops import rearrange
from yunchang.kernels import AttnType
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ... import IS_NPU_AVAILABLE
from ...core.device import parse_nccl_backend, parse_device_type
from ...core.gradient import gradient_checkpoint_forward
@@ -34,16 +31,13 @@ def sinusoidal_embedding_1d(dim, position):
def pad_freqs(original_tensor, target_len):
seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len
original_tensor_device = original_tensor.device
if original_tensor.device == "npu":
original_tensor = original_tensor.cpu()
padding_tensor = torch.ones(
pad_size,
s1,
s2,
dtype=original_tensor.dtype,
device=original_tensor.device)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0).to(device=original_tensor_device)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
return padded_tensor
def rope_apply(x, freqs, num_heads):
@@ -57,7 +51,7 @@ def rope_apply(x, freqs, num_heads):
sp_rank = get_sequence_parallel_rank()
freqs = pad_freqs(freqs, s_per_rank * sp_size)
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device.type == "npu" else freqs_rank
freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device == "npu" else freqs_rank
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
return x_out.to(x.dtype)
@@ -117,39 +111,6 @@ def usp_dit_forward(self,
return x
def usp_vace_forward(
self, x, vace_context, context, t_mod, freqs,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
):
# Compute full sequence length from the sharded x
full_seq_len = x.shape[1] * get_sequence_parallel_world_size()
# Embed vace_context via patch embedding
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = torch.cat([
torch.cat([u, u.new_zeros(1, full_seq_len - u.size(1), u.size(2))],
dim=1) for u in c
])
# Chunk VACE context along sequence dim BEFORE processing through blocks
c = torch.chunk(c, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
# Process through vace_blocks (self_attn already monkey-patched to usp_attn_forward)
for block in self.vace_blocks:
c = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
c, x, context, t_mod, freqs
)
# Hints are already sharded per-rank
hints = torch.unbind(c)[:-1]
return hints
def usp_attn_forward(self, x, freqs):
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))
@@ -161,12 +122,7 @@ def usp_attn_forward(self, x, freqs):
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
attn_type = AttnType.FA
ring_impl_type = "basic"
if IS_NPU_AVAILABLE:
attn_type = AttnType.NPU
ring_impl_type = "basic_npu"
x = xFuserLongContextAttention(attn_type=attn_type, ring_impl_type=ring_impl_type)(
x = xFuserLongContextAttention()(
None,
query=q,
key=k,
@@ -176,31 +132,4 @@ def usp_attn_forward(self, x, freqs):
del q, k, v
getattr(torch, parse_device_type(x.device)).empty_cache()
return self.o(x)
def get_current_chunk(x, dim=1):
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=dim)
ndims = len(chunks[0].shape)
pad_list = [0] * (2 * ndims)
pad_end_index = 2 * (ndims - 1 - dim) + 1
max_size = chunks[0].size(dim)
chunks = [
torch.nn.functional.pad(
chunk,
tuple(pad_list[:pad_end_index] + [max_size - chunk.size(dim)] + pad_list[pad_end_index+1:]),
value=0
)
for chunk in chunks
]
x = chunks[get_sequence_parallel_rank()]
return x
def gather_all_chunks(x, seq_len=None, dim=1):
x = get_sp_group().all_gather(x, dim=dim)
if seq_len is not None:
slices = [slice(None)] * x.ndim
slices[dim] = slice(0, seq_len)
x = x[tuple(slices)]
return x
return self.o(x)

View File

@@ -1,5 +0,0 @@
# Make sure to modify __release_datetime__ to release time when making official release.
__version__ = '2.0.0'
# default release datetime for branches under active development is set
# to be a time far-far-away-into-the-future
__release_datetime__ = '2099-10-13 08:56:12'

View File

@@ -1,28 +0,0 @@
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# Set the OS, Python version and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.10"
# Build documentation in the "docs/" directory with Sphinx
sphinx:
configuration: docs/en/conf.py
# Optionally build your docs in additional formats such as PDF and ePub
# formats:
# - pdf
# - epub
# Optional but recommended, declare the Python requirements required
# to build your documentation
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
install:
- requirements: docs/requirements.txt

View File

@@ -1,6 +1,6 @@
# `diffsynth.core.attention`: Attention Mechanism Implementation
`diffsynth.core.attention` provides routing mechanisms for attention mechanism implementations, automatically selecting efficient attention implementations based on available packages in the `Python` environment and [environment variables](../../Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation).
`diffsynth.core.attention` provides routing mechanisms for attention mechanism implementations, automatically selecting efficient attention implementations based on available packages in the `Python` environment and [environment variables](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation).
## Attention Mechanism
@@ -46,7 +46,7 @@ Note that the dimension of the Attention Score in the attention mechanism ( $\te
* xFormers: [GitHub](https://github.com/facebookresearch/xformers), [Documentation](https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops)
* PyTorch: [GitHub](https://github.com/pytorch/pytorch), [Documentation](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
To call attention implementations other than `PyTorch`, please follow the instructions on their GitHub pages to install the corresponding packages. `DiffSynth-Studio` will automatically route to the corresponding implementation based on available packages in the Python environment, or can be controlled through [environment variables](../../Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation).
To call attention implementations other than `PyTorch`, please follow the instructions on their GitHub pages to install the corresponding packages. `DiffSynth-Studio` will automatically route to the corresponding implementation based on available packages in the Python environment, or can be controlled through [environment variables](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation).
```python
from diffsynth.core.attention import attention_forward

View File

@@ -8,9 +8,9 @@ This document introduces the model download and loading functionalities in `diff
### Downloading and Loading Models from Remote Sources
Taking the model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) as an example, after filling in `model_id` and `origin_file_pattern` in `ModelConfig`, the model can be automatically downloaded. By default, it downloads to the `./models` path, which can be modified through the [environment variable DIFFSYNTH_MODEL_BASE_PATH](../../Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path).
Taking the model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) as an example, after filling in `model_id` and `origin_file_pattern` in `ModelConfig`, the model can be automatically downloaded. By default, it downloads to the `./models` path, which can be modified through the [environment variable DIFFSYNTH_MODEL_BASE_PATH](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path).
By default, even if the model has already been downloaded, the program will still query the remote for any missing files. To completely disable remote requests, set the [environment variable DIFFSYNTH_SKIP_DOWNLOAD](../../Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`.
By default, even if the model has already been downloaded, the program will still query the remote for any missing files. To completely disable remote requests, set the [environment variable DIFFSYNTH_SKIP_DOWNLOAD](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`.
```python
from diffsynth.core import ModelConfig
@@ -51,7 +51,7 @@ config = ModelConfig(path=[
### VRAM Management Configuration
`ModelConfig` also contains VRAM management configuration information. See [VRAM Management](../../Pipeline_Usage/VRAM_management.md#more-usage-methods) for details.
`ModelConfig` also contains VRAM management configuration information. See [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md#more-usage-methods) for details.
## Model File Loading
@@ -103,11 +103,11 @@ print(hash_model_file([
The model hash value is only related to the keys and tensor shapes in the state dict of the model file, and is unrelated to the numerical values of the model parameters, file saving time, and other information. When calculating the model hash value of `.safetensors` format files, `hash_model_file` is almost instantly completed without reading the model parameters. However, when calculating the model hash value of `.bin`, `.pth`, `.ckpt`, and other binary files, all model parameters need to be read, so **we do not recommend developers to continue using these formats of files.**
By [writing model Config](../../Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config) and filling in model hash value and other information into `diffsynth/configs/model_configs.py`, developers can let `DiffSynth-Studio` automatically identify the model type and load it.
By [writing model Config](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config) and filling in model hash value and other information into `diffsynth/configs/model_configs.py`, developers can let `DiffSynth-Studio` automatically identify the model type and load it.
## Model Loading
`load_model` is the external entry for loading models in `diffsynth.core.loader`. It will call [skip_model_initialization](../../API_Reference/core/vram.md#skipping-model-parameter-initialization) to skip model parameter initialization. If [Disk Offload](../../Pipeline_Usage/VRAM_management.md#disk-offload) is enabled, it calls [DiskMap](../../API_Reference/core/vram.md#state-dict-disk-mapping) for lazy loading. If Disk Offload is not enabled, it calls [load_state_dict](#model-file-loading) to load model parameters. If necessary, it will also call [state dict converter](../../Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) for model format conversion. Finally, it calls `model.eval()` to switch to inference mode.
`load_model` is the external entry for loading models in `diffsynth.core.loader`. It will call [skip_model_initialization](/docs/en/API_Reference/core/vram.md#skipping-model-parameter-initialization) to skip model parameter initialization. If [Disk Offload](/docs/en/Pipeline_Usage/VRAM_management.md#disk-offload) is enabled, it calls [DiskMap](/docs/en/API_Reference/core/vram.md#state-dict-disk-mapping) for lazy loading. If Disk Offload is not enabled, it calls [load_state_dict](#model-file-loading) to load model parameters. If necessary, it will also call [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) for model format conversion. Finally, it calls `model.eval()` to switch to inference mode.
Here is a usage example with Disk Offload enabled:

View File

@@ -31,7 +31,7 @@ state_dict = load_state_dict(path, device="cpu")
model.load_state_dict(state_dict, assign=True)
```
In `DiffSynth-Studio`, all pretrained models follow this loading logic. After developers [integrate models](../../Developer_Guide/Integrating_Your_Model.md), they can directly load models quickly using this approach.
In `DiffSynth-Studio`, all pretrained models follow this loading logic. After developers [integrate models](/docs/en/Developer_Guide/Integrating_Your_Model.md), they can directly load models quickly using this approach.
## State Dict Disk Mapping
@@ -57,10 +57,10 @@ state_dict = DiskMap(path, device="cpu") # Fast
print(state_dict["img_in.weight"])
```
`DiskMap` is the basic component of Disk Offload in `DiffSynth-Studio`. After developers [configure fine-grained VRAM management schemes](../../Developer_Guide/Enabling_VRAM_management.md), they can directly enable Disk Offload.
`DiskMap` is the basic component of Disk Offload in `DiffSynth-Studio`. After developers [configure fine-grained VRAM management schemes](/docs/en/Developer_Guide/Enabling_VRAM_management.md), they can directly enable Disk Offload.
`DiskMap` is a functionality implemented using the characteristics of `.safetensors` files. Therefore, when using `.bin`, `.pth`, `.ckpt`, and other binary files, model parameters are fully loaded, which causes Disk Offload to not support these formats of files. **We do not recommend developers to continue using these formats of files.**
## Replacable Modules for VRAM Management
When `DiffSynth-Studio`'s VRAM management is enabled, the modules inside the model will be replaced with replacable modules in `diffsynth.core.vram.layers`. For usage, see [Fine-grained VRAM Management Scheme](../../Developer_Guide/Enabling_VRAM_management.md#writing-fine-grained-vram-management-schemes).
When `DiffSynth-Studio`'s VRAM management is enabled, the modules inside the model will be replaced with replacable modules in `diffsynth.core.vram.layers`. For usage, see [Fine-grained VRAM Management Scheme](/docs/en/Developer_Guide/Enabling_VRAM_management.md#writing-fine-grained-vram-management-schemes).

View File

@@ -1,6 +1,6 @@
# Building a Pipeline
After [integrating the required models for the Pipeline](../Developer_Guide/Integrating_Your_Model.md), you also need to build a `Pipeline` for model inference. This document provides a standardized process for building a `Pipeline`. Developers can also refer to existing `Pipeline` implementations for construction.
After [integrating the required models for the Pipeline](/docs/en/Developer_Guide/Integrating_Your_Model.md), you also need to build a `Pipeline` for model inference. This document provides a standardized process for building a `Pipeline`. Developers can also refer to existing `Pipeline` implementations for construction.
The `Pipeline` implementation is located in `diffsynth/pipelines`. Each `Pipeline` contains the following essential key components:
@@ -79,7 +79,7 @@ This includes the following parts:
return pipe
```
Developers need to implement the logic for fetching models. The corresponding model names are the `"model_name"` in the [model Config filled in during model integration](../Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config).
Developers need to implement the logic for fetching models. The corresponding model names are the `"model_name"` in the [model Config filled in during model integration](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config).
Some models also need to load `tokenizer`. Extra `tokenizer_config` parameters can be added to `from_pretrained` as needed, and this part can be implemented after fetching the models.

View File

@@ -1,6 +1,6 @@
# Fine-Grained VRAM Management Scheme
This document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](../Pipeline_Usage/VRAM_management.md).
This document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md).
## How Much VRAM Does a 20B Model Need?
@@ -124,7 +124,7 @@ module_map={
}
```
In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](../Pipeline_Usage/VRAM_management.md#more-usage-methods).
In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md#more-usage-methods).
Call `enable_vram_management` to enable VRAM management. Note that the `device` when loading the model is `cpu`, consistent with `offload_device`:
@@ -171,7 +171,7 @@ The above code only requires 2G VRAM to run the `forward` of a 20B model.
## Disk Offload
[Disk Offload](../Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. Usually, when the above code can run smoothly, Disk Offload can be directly enabled:
[Disk Offload](/docs/en/Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. Usually, when the above code can run smoothly, Disk Offload can be directly enabled:
```python
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
@@ -212,7 +212,7 @@ with torch.no_grad():
output = model(**inputs)
```
Disk Offload is an extremely special VRAM management scheme. It only supports `.safetensors` format files, not binary files such as `.bin`, `.pth`, `.ckpt`, and does not support [state dict converter](../Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape.
Disk Offload is an extremely special VRAM management scheme. It only supports `.safetensors` format files, not binary files such as `.bin`, `.pth`, `.ckpt`, and does not support [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape.
If there are situations where Disk Offload cannot run normally but non-Disk Offload can run normally, please submit an issue to us on GitHub.
@@ -227,7 +227,7 @@ To make it easier for users to use the VRAM management function, we write the fi
}
```# Fine-Grained VRAM Management Scheme
This document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](../Pipeline_Usage/VRAM_management.md).
This document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md).
## How Much VRAM Does a 20B Model Need?
@@ -351,7 +351,7 @@ module_map={
}
```
In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](../Pipeline_Usage/VRAM_management.md#more-usage-methods).
In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md#more-usage-methods).
Call `enable_vram_management` to enable VRAM management. Note that the `device` when loading the model is `cpu`, consistent with `offload_device`:
@@ -398,7 +398,7 @@ The above code only requires 2G VRAM to run the `forward` of a 20B model.
## Disk Offload
[Disk Offload](../Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. Usually, when the above code can run smoothly, Disk Offload can be directly enabled:
[Disk Offload](/docs/en/Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. Usually, when the above code can run smoothly, Disk Offload can be directly enabled:
```python
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
@@ -439,7 +439,7 @@ with torch.no_grad():
output = model(**inputs)
```
Disk Offload is an extremely special VRAM management scheme. It only supports `.safetensors` format files, not binary files such as `.bin`, `.pth`, `.ckpt`, and does not support [state dict converter](../Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape.
Disk Offload is an extremely special VRAM management scheme. It only supports `.safetensors` format files, not binary files such as `.bin`, `.pth`, `.ckpt`, and does not support [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape.
If there are situations where Disk Offload cannot run normally but non-Disk Offload can run normally, please submit an issue to us on GitHub.

View File

@@ -183,4 +183,4 @@ Loaded model: {
## Step 5: Writing Model VRAM Management Scheme
`DiffSynth-Studio` supports complex VRAM management. See [Enabling VRAM Management](../Developer_Guide/Enabling_VRAM_management.md) for details.
`DiffSynth-Studio` supports complex VRAM management. See [Enabling VRAM Management](/docs/en/Developer_Guide/Enabling_VRAM_management.md) for details.

View File

@@ -1,6 +1,6 @@
# Integrating Model Training
After [integrating models](../Developer_Guide/Integrating_Your_Model.md) and [implementing Pipeline](../Developer_Guide/Building_a_Pipeline.md), the next step is to integrate model training functionality.
After [integrating models](/docs/en/Developer_Guide/Integrating_Your_Model.md) and [implementing Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md), the next step is to integrate model training functionality.
## Training-Inference Consistent Pipeline Modification

View File

@@ -1,20 +0,0 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

View File

@@ -1,139 +0,0 @@
# Anima
Anima is an image generation model trained and open-sourced by CircleStone Labs and Comfy Org.
## Installation
Before using this project for model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more installation information, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
The following code demonstrates how to quickly load the [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) model for inference. VRAM management is enabled by default, allowing the framework to automatically control model parameter loading based on available VRAM. Minimum 8GB VRAM required.
```python
from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": "disk",
"onload_device": "disk",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AnimaImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config),
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config),
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
image = pipe(prompt, seed=0, num_inference_steps=50)
image.save("image.jpg")
```
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|
|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)|
Special training scripts:
* Differential LoRA Training: [doc](../Training/Differential_LoRA.md)
* FP8 Precision Training: [doc](../Training/FP8_Precision.md)
* Two-Stage Split Training: [doc](../Training/Split_Training.md)
* End-to-End Direct Distillation: [doc](../Training/Direct_Distill.md)
## Model Inference
Models are loaded through `AnimaImagePipeline.from_pretrained`, see [Model Inference](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
Input parameters for `AnimaImagePipeline` inference include:
* `prompt`: Text description of the desired image content.
* `negative_prompt`: Content to exclude from the generated image (default: `""`).
* `cfg_scale`: Classifier-free guidance parameter (default: 4.0).
* `input_image`: Input image for image-to-image generation (default: `None`).
* `denoising_strength`: Controls similarity to input image (default: 1.0).
* `height`: Image height (must be multiple of 16, default: 1024).
* `width`: Image width (must be multiple of 16, default: 1024).
* `seed`: Random seed (default: `None`).
* `rand_device`: Device for random noise generation (default: `"cpu"`).
* `num_inference_steps`: Inference steps (default: 30).
* `sigma_shift`: Scheduler sigma offset (default: `None`).
* `progress_bar_cmd`: Progress bar implementation (default: `tqdm.tqdm`).
For VRAM constraints, enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). Recommended low-VRAM configurations are provided in the "Model Overview" table above.
## Model Training
Anima models are trained through [`examples/anima/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/train.py) with parameters including:
* General Training Parameters
* Dataset Configuration
* `--dataset_base_path`: Dataset root directory.
* `--dataset_metadata_path`: Metadata file path.
* `--dataset_repeat`: Dataset repetition per epoch.
* `--dataset_num_workers`: Dataloader worker count.
* `--data_file_keys`: Metadata fields to load (comma-separated).
* Model Loading
* `--model_paths`: Model paths (JSON format).
* `--model_id_with_origin_paths`: Model IDs with origin paths (e.g., `"anima-team/anima-1B:text_encoder/*.safetensors"`).
* `--extra_inputs`: Additional pipeline inputs (e.g., `controlnet_inputs` for ControlNet).
* `--fp8_models`: FP8-formatted models (same format as `--model_paths`).
* Training Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Training epochs.
* `--trainable_models`: Trainable components (e.g., `dit`, `vae`, `text_encoder`).
* `--find_unused_parameters`: Handle unused parameters in DDP training.
* `--weight_decay`: Weight decay value.
* `--task`: Training task (default: `sft`).
* Output Configuration
* `--output_path`: Model output directory.
* `--remove_prefix_in_ckpt`: Remove state dict prefixes.
* `--save_steps`: Model saving interval.
* LoRA Configuration
* `--lora_base_model`: Target model for LoRA.
* `--lora_target_modules`: Target modules for LoRA.
* `--lora_rank`: LoRA rank.
* `--lora_checkpoint`: LoRA checkpoint path.
* `--preset_lora_path`: Preloaded LoRA checkpoint path.
* `--preset_lora_model`: Model to merge LoRA with (e.g., `dit`).
* Gradient Configuration
* `--use_gradient_checkpointing`: Enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Offload checkpointing to CPU.
* `--gradient_accumulation_steps`: Gradient accumulation steps.
* Image Resolution
* `--height`: Image height (empty for dynamic resolution).
* `--width`: Image width (empty for dynamic resolution).
* `--max_pixels`: Maximum pixel area for dynamic resolution.
* Anima-Specific Parameters
* `--tokenizer_path`: Tokenizer path for text-to-image models.
* `--tokenizer_t5xxl_path`: T5-XXL tokenizer path.
We provide a sample image dataset for testing:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```
For training script details, refer to [Model Training](../Pipeline_Usage/Model_Training.md). For advanced training techniques, see [Training Framework Documentation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/).

View File

@@ -1,134 +0,0 @@
# ERNIE-Image
ERNIE-Image is a powerful image generation model with 8B parameters developed by Baidu, featuring a compact and efficient architecture with strong instruction-following capability. Based on an 8B DiT backbone, it delivers performance comparable to larger (20B+) models in certain scenarios while maintaining parameter efficiency. It offers reliable performance in instruction understanding and execution, text generation (English/Chinese/Japanese), and overall stability.
## Installation
Before performing model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
Running the following code will load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 3G VRAM.
```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
model_configs=[
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")
```
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
## Model Inference
The model is loaded via `ErnieImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
The input parameters for `ErnieImagePipeline` inference include:
* `prompt`: The prompt describing the content to appear in the image.
* `negative_prompt`: The negative prompt describing what should not appear in the image, default value is `""`.
* `cfg_scale`: Classifier-free guidance parameter, default value is 4.0.
* `height`: Image height, must be a multiple of 16, default value is 1024.
* `width`: Image width, must be a multiple of 16, default value is 1024.
* `seed`: Random seed. Default is `None`, meaning completely random.
* `rand_device`: The computing device for generating random Gaussian noise matrices, default is `"cuda"`. When set to `cuda`, different GPUs will produce different results.
* `num_inference_steps`: Number of inference steps, default value is 50.
If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low-VRAM configurations for each model in the "Model Overview" table above.
## Model Training
ERNIE-Image series models are trained uniformly via [`examples/ernie_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/train.py). The script parameters include:
* General Training Parameters
* Dataset Configuration
* `--dataset_base_path`: Root directory of the dataset.
* `--dataset_metadata_path`: Path to the dataset metadata file.
* `--dataset_repeat`: Number of dataset repeats per epoch.
* `--dataset_num_workers`: Number of processes per DataLoader.
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
* Model Loading Configuration
* `--model_paths`: Paths to load models from, in JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"PaddlePaddle/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors"`, separated by commas.
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
* Basic Training Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of epochs.
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
* `--weight_decay`: Weight decay magnitude.
* `--task`: Training task, defaults to `sft`.
* Output Configuration
* `--output_path`: Path to save the model.
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
* `--save_steps`: Interval in training steps to save the model.
* LoRA Configuration
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path to LoRA checkpoint.
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
* Gradient Configuration
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Resolution Configuration
* `--height`: Height of the image. Leave empty to enable dynamic resolution.
* `--width`: Width of the image. Leave empty to enable dynamic resolution.
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
* ERNIE-Image Specific Parameters
* `--tokenizer_path`: Path to the tokenizer, leave empty to auto-download from remote.
We provide an example image dataset for testing, which can be downloaded with the following command:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -14,7 +14,7 @@ cd DiffSynth-Studio
pip install -e .
```
For more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).
For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md).
## Quick Start
@@ -81,31 +81,31 @@ graph LR;
| Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - | - |
| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev.py) |
| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) |
| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) |
| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) |
| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) |
| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) |
| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) |
| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) |
| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) |
| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - |
| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - |
| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Step1X-Edit.py) |
| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLEX.2-preview.py) |
| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Nexus-Gen.py) |
| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py) |
| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) |
| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) |
| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) |
| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) |
| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) |
| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) |
| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) |
| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) |
| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - |
| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - |
| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](/examples/flux/model_inference/Step1X-Edit.py) | [code](/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](/examples/flux/model_training/full/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_lora/Step1X-Edit.py) |
| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](/examples/flux/model_inference/FLEX.2-preview.py) | [code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py) |
| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](/examples/flux/model_training/full/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_lora/Nexus-Gen.py) |
Special Training Scripts:
* Differential LoRA Training: [doc](../Training/Differential_LoRA.md)
* FP8 Precision Training: [doc](../Training/FP8_Precision.md)
* Two-stage Split Training: [doc](../Training/Split_Training.md)
* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md)
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/flux/model_training/special/differential_training/)
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/flux/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/flux/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh)
## Model Inference
Models are loaded via `FluxImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models).
Models are loaded via `FluxImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models).
Input parameters for `FluxImagePipeline` inference include:
@@ -143,11 +143,11 @@ Input parameters for `FluxImagePipeline` inference include:
* `flex_control_stop`: Flex model control stop timestep.
* `nexus_gen_reference_image`: Nexus-Gen model reference image.
If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
## Model Training
FLUX series models are uniformly trained through [`examples/flux/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/train.py), and the script parameters include:
FLUX series models are uniformly trained through [`examples/flux/model_training/train.py`](/examples/flux/model_training/train.py), and the script parameters include:
* General Training Parameters
* Dataset Basic Configuration
@@ -195,7 +195,7 @@ FLUX series models are uniformly trained through [`examples/flux/model_training/
We have built a sample image dataset for your testing. You can download this dataset with the following command:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
```
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).

View File

@@ -21,7 +21,7 @@ cd DiffSynth-Studio
pip install -e .
```
For more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).
For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md).
## Quick Start
@@ -61,22 +61,22 @@ image.save("image.jpg")
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
Special Training Scripts:
* Differential LoRA Training: [doc](../Training/Differential_LoRA.md)
* FP8 Precision Training: [doc](../Training/FP8_Precision.md)
* Two-stage Split Training: [doc](../Training/Split_Training.md)
* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md)
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md)
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md)
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md)
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md)
## Model Inference
Models are loaded via `Flux2ImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models).
Models are loaded via `Flux2ImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models).
Input parameters for `Flux2ImagePipeline` inference include:
@@ -95,11 +95,11 @@ Input parameters for `Flux2ImagePipeline` inference include:
* `tile_stride`: Tile stride during VAE encoding/decoding stages, default is 64, only effective when `tiled=True`, must be less than or equal to `tile_size`.
* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`.
If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
## Model Training
FLUX.2 series models are uniformly trained through [`examples/flux2/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/train.py), and the script parameters include:
FLUX.2 series models are uniformly trained through [`examples/flux2/model_training/train.py`](/examples/flux2/model_training/train.py), and the script parameters include:
* General Training Parameters
* Dataset Basic Configuration
@@ -145,7 +145,7 @@ FLUX.2 series models are uniformly trained through [`examples/flux2/model_traini
We have built a sample image dataset for your testing. You can download this dataset with the following command:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
```
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).

View File

@@ -1,154 +0,0 @@
# JoyAI-Image
JoyAI-Image is a unified multi-modal foundation model open-sourced by JD.com, supporting image understanding, text-to-image generation, and instruction-guided image editing.
## Installation
Before performing model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
Running the following code will load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 4GB VRAM.
```python
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
# Download dataset
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# Use first sample from dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
output = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=30,
cfg_scale=5.0,
)
output.save("output_joyai_edit_low_vram.png")
```
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
## Model Inference
The model is loaded via `JoyAIImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
The input parameters for `JoyAIImagePipeline` inference include:
* `prompt`: Text prompt describing the desired image editing effect.
* `negative_prompt`: Negative prompt specifying what should not appear in the result, defaults to empty string.
* `cfg_scale`: Classifier-free guidance scale factor, defaults to 5.0. Higher values make the output more closely follow the prompt.
* `edit_image`: Image to be edited.
* `denoising_strength`: Denoising strength controlling how much the input image is repainted, defaults to 1.0.
* `height`: Height of the output image, defaults to 1024. Must be divisible by 16.
* `width`: Width of the output image, defaults to 1024. Must be divisible by 16.
* `seed`: Random seed for reproducibility. Set to `None` for random seed.
* `max_sequence_length`: Maximum sequence length for the text encoder, defaults to 4096.
* `num_inference_steps`: Number of inference steps, defaults to 30. More steps typically yield better quality.
* `tiled`: Whether to enable tiling for reduced VRAM usage, defaults to False.
* `tile_size`: Tile size, defaults to (30, 52).
* `tile_stride`: Tile stride, defaults to (15, 26).
* `shift`: Shift parameter for the scheduler, controlling the Flow Match scheduling curve, defaults to 4.0.
* `progress_bar_cmd`: Progress bar display mode, defaults to tqdm.
## Model Training
Models in the joyai_image series are trained uniformly via `examples/joyai_image/model_training/train.py`. The script parameters include:
* General Training Parameters
* Dataset Configuration
* `--dataset_base_path`: Root directory of the dataset.
* `--dataset_metadata_path`: Path to the dataset metadata file.
* `--dataset_repeat`: Number of dataset repeats per epoch.
* `--dataset_num_workers`: Number of processes per DataLoader.
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
* Model Loading Configuration
* `--model_paths`: Paths to load models from, in JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
* Basic Training Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of epochs.
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
* `--weight_decay`: Weight decay magnitude.
* `--task`: Training task, defaults to `sft`.
* Output Configuration
* `--output_path`: Path to save the model.
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
* `--save_steps`: Interval in training steps to save the model.
* LoRA Configuration
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path to LoRA checkpoint.
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
* Gradient Configuration
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Resolution Configuration
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
* `--num_frames`: Number of frames for video (video generation models only).
* JoyAI-Image Specific Parameters
* `--processor_path`: Path to the processor for processing text and image encoder inputs.
* `--initialize_model_on_cpu`: Whether to initialize models on CPU. By default, models are initialized on the accelerator device.
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -1,171 +0,0 @@
# LTX-2
LTX-2 is a series of audio-video generation models developed by Lightricks.
## Installation
Before using this project for model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information about installation, please refer to [Installation Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
Run the following code to quickly load the [Lightricks/LTX-2.3](https://www.modelscope.cn/models/Lightricks/LTX-2.3) model and perform inference. VRAM management has been enabled, and the framework will automatically control model parameter loading based on remaining VRAM. It can run with a minimum of 8GB VRAM.
```python
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = LTX2AudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"),
)
prompt = "Two cute orange cats, wearing boxing gloves, stand in a boxing ring and fight each other. They are punching each other fast and yelling: 'I will win!'"
negative_prompt = pipe.default_negative_prompt["LTX-2.3"]
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=43,
height=1024, width=1536, num_frames=121,
tiled=True, use_two_stage_pipeline=True,
)
write_video_audio_ltx2(video=video, audio=audio, output_path='video.mp4', fps=24, audio_sample_rate=pipe.audio_vocoder.output_sampling_rate)
```
## Model Overview
|Model ID|Additional Parameters|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|-|
|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py)|
|[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2.3: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)|
|[Lightricks/LTX-2.3: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2.3: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2.3: A2V](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2.3: Retake](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_video`,`retake_video_regions`,`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py)|-|-|-|-|
|[Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|
|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-|
## Model Inference
Models are loaded through `LTX2AudioVideoPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
Input parameters for `LTX2AudioVideoPipeline` inference include:
* `prompt`: Prompt describing the content appearing in the video.
* `negative_prompt`: Negative prompt describing content that should not appear in the video, default value is `""`.
* `cfg_scale`: Classifier-free guidance parameter, default value is 3.0.
* `input_images`: List of input images for image-to-video generation.
* `input_images_indexes`: Frame index list of input images in the video.
* `input_images_strength`: Strength of input images, default value is 1.0.
* `denoising_strength`: Denoising strength, range is 01, default value is 1.0.
* `seed`: Random seed. Default is `None`, which means completely random.
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different results will be generated on different GPUs.
* `height`: Video height, must be a multiple of 32 (single-stage) or 64 (two-stage).
* `width`: Video width, must be a multiple of 32 (single-stage) or 64 (two-stage).
* `num_frames`: Number of video frames, default value is 121, must be a multiple of 8 + 1.
* `num_inference_steps`: Number of inference steps, default value is 40.
* `tiled`: Whether to enable VAE tiling inference, default is `True`. When set to `True`, it can significantly reduce VRAM usage during VAE encoding/decoding stages, with slight errors and minor inference time extension.
* `tile_size_in_pixels`: Pixel tiling size during VAE encoding/decoding stages, default is 512.
* `tile_overlap_in_pixels`: Pixel tiling overlap size during VAE encoding/decoding stages, default is 128.
* `tile_size_in_frames`: Frame tiling size during VAE encoding/decoding stages, default is 128.
* `tile_overlap_in_frames`: Frame tiling overlap size during VAE encoding/decoding stages, default is 24.
* `use_two_stage_pipeline`: Whether to use two-stage pipeline, default is `False`.
* `use_distilled_pipeline`: Whether to use distilled pipeline, default is `False`.
* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be set to `lambda x:x` to hide the progress bar.
If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the previous "Supported Inference Scripts" section.
## Model Training
LTX-2 series models are uniformly trained through [`examples/ltx2/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/train.py), and the script parameters include:
* General Training Parameters
* Dataset Basic Configuration
* `--dataset_base_path`: Root directory of the dataset.
* `--dataset_metadata_path`: Metadata file path of the dataset.
* `--dataset_repeat`: Number of times the dataset is repeated in each epoch.
* `--dataset_num_workers`: Number of processes for each DataLoader.
* `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`.
* Model Loading Configuration
* `--model_paths`: Paths of models to be loaded. JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors"`. Separated by commas.
* `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., extra parameters when training image editing models, separated by `,`.
* `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA).
* Training Basic Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of epochs.
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
* `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training.
* `--weight_decay`: Weight decay size, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
* `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model.
* Output Configuration
* `--output_path`: Model saving path.
* `--remove_prefix_in_ckpt`: Remove prefix in the state dict of the model file.
* `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch.
* LoRA Configuration
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint.
* `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training.
* `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`.
* Gradient Configuration
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Video Width/Height Configuration
* `--height`: Height of the video. Leave `height` and `width` blank to enable dynamic resolution.
* `--width`: Width of the video. Leave `height` and `width` blank to enable dynamic resolution.
* `--max_pixels`: Maximum pixel area of video frames. When dynamic resolution is enabled, video frames with resolution larger than this value will be downscaled, and video frames with resolution smaller than this value will remain unchanged.
* `--num_frames`: Number of frames in the video.
* LTX-2 Series Specific Parameters
* `--tokenizer_path`: Path of the tokenizer, applicable to text-to-video models, leave blank to automatically download from remote.
* `--frame_rate`: frame rate of the training videos.
We have built a sample video dataset for your testing. You can download this dataset with the following command:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -2,7 +2,7 @@
## Qwen-Image
Documentation: [./Qwen-Image.md](../Model_Details/Qwen-Image.md)
Documentation: [./Qwen-Image.md](/docs/en/Model_Details/Qwen-Image.md)
<details>
@@ -69,23 +69,23 @@ graph LR;
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |
| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |
| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |
| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |
| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |
| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |
## FLUX Series
Documentation: [./FLUX.md](../Model_Details/FLUX.md)
Documentation: [./FLUX.md](/docs/en/Model_Details/FLUX.md)
<details>
@@ -149,24 +149,24 @@ graph LR;
| Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - | - |
| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev.py) |
| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) |
| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) |
| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) |
| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) |
| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) |
| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) |
| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) |
| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) |
| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - |
| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - |
| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Step1X-Edit.py) |
| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLEX.2-preview.py) |
| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Nexus-Gen.py) |
| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py) |
| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) |
| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) |
| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) |
| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) |
| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) |
| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) |
| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) |
| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) |
| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - |
| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - |
| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](/examples/flux/model_inference/Step1X-Edit.py) | [code](/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](/examples/flux/model_training/full/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_lora/Step1X-Edit.py) |
| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](/examples/flux/model_inference/FLEX.2-preview.py) | [code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py) |
| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](/examples/flux/model_training/full/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_lora/Nexus-Gen.py) |
## Wan Series
Documentation: [./Wan.md](../Model_Details/Wan.md)
Documentation: [./Wan.md](/docs/en/Model_Details/Wan.md)
<details>
@@ -254,38 +254,38 @@ graph LR;
| Model ID | Extra Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) |
| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) |
| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) |
| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) |
| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) |
| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) |
| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) |
| [Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) |
| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) |
| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) |
| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) |
| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) |
| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) |
| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) |
| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) |
| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) |
| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) |
| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) |
| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) |
| [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) |
| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) |
| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) |
| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) |
| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) |
| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) |
| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) |
| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) |
| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) |
| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) |
| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) |
| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) |
| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) |
| [Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) |
| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) |
| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) |
| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) |
| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) |
| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) |
| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) |
| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](/examples/wanvideo/model_inference/LongCat-Video.py) | [code](/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) |
| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) |
| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) |
| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) |
| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) |
| [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) |
| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) |
| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) |
| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) |
| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) |
| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) |
* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/direct_distill/)
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/wanvideo/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/wanvideo/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/wanvideo/model_training/special/direct_distill/)

View File

@@ -14,7 +14,7 @@ cd DiffSynth-Studio
pip install -e .
```
For more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).
For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md).
## Quick Start
@@ -80,44 +80,35 @@ graph LR;
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[FireRedTeam/FireRed-Image-Edit-1.0](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.0)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/FireRed-Image-Edit-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.0.py)|
|[FireRedTeam/FireRed-Image-Edit-1.1](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/FireRed-Image-Edit-1.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.1.py)|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered-Control-V2.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control-V2.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control-V2.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control-V2.py)|
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |
| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |
| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |
| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |
| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
Special Training Scripts:
* Differential LoRA Training: [doc](../Training/Differential_LoRA.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/differential_training/)
* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)
DeepSpeed ZeRO Stage 3 Training: The Qwen-Image series models support DeepSpeed ZeRO Stage 3 training, which partitions the model across multiple GPUs. Taking full parameter training of the Qwen-Image model as an example, the following modifications are required:
* `--config_file examples/qwen_image/model_training/full/accelerate_config_zero3.yaml`
* `--initialize_model_on_cpu`
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/qwen_image/model_training/special/differential_training/)
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/qwen_image/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/qwen_image/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)
## Model Inference
Models are loaded via `QwenImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models).
Models are loaded via `QwenImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models).
Input parameters for `QwenImagePipeline` inference include:
@@ -148,11 +139,11 @@ Input parameters for `QwenImagePipeline` inference include:
* `tile_stride`: Tile stride during VAE encoding/decoding stages, default is 64, only effective when `tiled=True`, must be less than or equal to `tile_size`.
* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`.
If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
## Model Training
Qwen-Image series models are uniformly trained through [`examples/qwen_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/train.py), and the script parameters include:
Qwen-Image series models are uniformly trained through [`examples/qwen_image/model_training/train.py`](/examples/qwen_image/model_training/train.py), and the script parameters include:
* General Training Parameters
* Dataset Basic Configuration
@@ -199,7 +190,7 @@ Qwen-Image series models are uniformly trained through [`examples/qwen_image/mod
We have built a sample image dataset for your testing. You can download this dataset with the following command:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
```
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).

Some files were not shown because too many files have changed in this diff Show More