Mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-04-24 15:06:17 +00:00
Compare commits

101 Commits

| SHA1 |
|---|
| 3da625432e |
| 002e3cdb74 |
| 29bf66cdc9 |
| a80fb84220 |
| 394db06d86 |
| 1186379139 |
| f2e3427566 |
| c53c813c12 |
| b0680ef711 |
| f5a3201d42 |
| 95cfb77881 |
| 5c89a15b9a |
| 9d09e0431c |
| a604d76339 |
| 36c203da57 |
| 079e51c9f3 |
| 8f18e24597 |
| 45d973e87d |
| c80fec2a56 |
| b5d04ceb30 |
| 960d8c62c0 |
| f77b6357c5 |
| 166e6d2d38 |
| 5e7e3db0af |
| ae8cb139e8 |
| e2a3a987da |
| f7b9ae7d57 |
| 5d198287f0 |
| 5bccd60c80 |
| 078fc551d9 |
| 52ba5d414e |
| ba0626e38f |
| 4ec4d9c20a |
| 7a80f10fa4 |
| 3bd5188b3e |
| 7650e9381e |
| 8c9ddc9274 |
| 681df93a85 |
| 4741542523 |
| c927062546 |
| f3ebd6f714 |
| 959471f083 |
| d9228074bd |
| b272253956 |
| 7bc5611fb8 |
| f7d23c6551 |
| 13eff18e7d |
| a38954b72c |
| d40efe897f |
| c9c2561791 |
| 0139b042e0 |
| ed9e4374af |
| 2a0eb9c383 |
| 73b13f4c86 |
| 75ebd797da |
| 31ba103d8e |
| c5aaa1da41 |
| 6bcb99fd2e |
| ab8f455c46 |
| add6f88324 |
| 430b495100 |
| 62ba8a3f2e |
| 237d178733 |
| b3ef224042 |
| f43b18ec21 |
| 6d671db5d2 |
| 07f5d88ac9 |
| 880231b4be |
| b3f6c3275f |
| 29cd5c7612 |
| ff4be1c7c7 |
| 6b0fb1601f |
| 4b400c07eb |
| 6a6ae6d791 |
| 1a380a6b62 |
| 5ca74923e8 |
| 8b9a094c1b |
| 5996c2b068 |
| 8fc7e005a6 |
| a18966c300 |
| a87910bc65 |
| f48662e863 |
| 8d8bfc7f54 |
| 8e15dcd289 |
| 586ac9d8a6 |
| 625b5ff16d |
| ee73a29885 |
| 288bbc7128 |
| 5002ac74dc |
| 863a6ba597 |
| b08bc1470d |
| 96143aa26b |
| 71cea4371c |
| fc11fd4297 |
| bd3c5822a1 |
| 96fb0f3afe |
| b68663426f |
| 0e6976a0ae |
| 94b57e9677 |
| 3fb037d33a |
| 6383ec358c |
.gitignore (vendored, 1 change)

@@ -2,6 +2,7 @@
 /models
 /scripts
 /diffusers
+/.vscode
 *.pkl
 *.safetensors
 *.pth

README.md (424 changes)

@@ -7,11 +7,14 @@
 [](https://github.com/modelscope/DiffSynth-Studio/issues)
 [](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
 [](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
+[](https://discord.gg/Mm9suEeUDc)
 
 [Switch to the Chinese version](./README_zh.md)
 
 ## Introduction
 
+> DiffSynth-Studio Documentation: [Chinese version](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/), [English version](https://diffsynth-studio-doc.readthedocs.io/en/latest/)
+
 Welcome to the magical world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by the [ModelScope Community](https://www.modelscope.cn/). We hope to foster technological innovation through framework construction, aggregate the power of the open-source community, and explore the boundaries of generative model technology!
 
 DiffSynth currently includes two open-source projects:
@@ -23,16 +26,32 @@ DiffSynth currently includes two open-source projects:
 * ModelScope AIGC Zone (for Chinese users): https://modelscope.cn/aigc/home
 * ModelScope Civision (for global users): https://modelscope.ai/civision/home
 
-> DiffSynth-Studio Documentation: [Chinese version](/docs/zh/README.md), [English version](/docs/en/README.md)
 
 We believe that a well-developed open-source code framework can lower the threshold for technical exploration. We have achieved many [interesting technologies](#innovative-achievements) based on this codebase. Perhaps you also have many wild ideas, and with DiffSynth-Studio, you can quickly realize these ideas. For this reason, we have prepared detailed documentation for developers. We hope that through these documents, developers can understand the principles of Diffusion models, and we look forward to expanding the boundaries of technology together with you.
 
 ## Update History
 
 > DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.
 
-> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
+> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
 
-- **February 10, 2026** Added inference support for the LTX-2 audio-video generation model. See the documentation for details. Support for model training will be implemented in the future.
+- **April 23, 2026** ACE-Step has been open-sourced; welcome a new member to the audio model family! Support includes text-to-music generation, low-VRAM inference, and LoRA training. For details, please refer to the [documentation](/docs/en/Model_Details/ACE-Step.md) and [example code](/examples/ace_step/).
+
+- **April 14, 2026** JoyAI-Image has been open-sourced; welcome a new member to the image editing model family! Support includes instruction-guided image editing, low-VRAM inference, and training. For details, please refer to the [documentation](/docs/en/Model_Details/JoyAI-Image.md) and [example code](/examples/joyai_image/).
+
+- **March 19, 2026**: Added support for the [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
+
+- **March 12, 2026**: Added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. Features include text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting, with complete inference and training support. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).
+
+- **March 3, 2026**: We released the [DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2) model, an updated version of Qwen-Image-Layered-Control. In addition to the originally supported text-guided functionality, it adds brush-controlled layer separation.
+
+- **March 2, 2026** Added support for [Anima](https://modelscope.cn/models/circlestone-labs/Anima). For details, please refer to the [documentation](docs/en/Model_Details/Anima.md). This is an interesting anime-style image generation model, and we look forward to its future updates.
+
+<details>
+
+<summary>More</summary>
+
+- **February 26, 2026** Added full and LoRA training support for the LTX-2 audio-video generation model. See the [documentation](/docs/en/Model_Details/LTX-2.md) for details.
+
+- **February 10, 2026** Added inference support for the LTX-2 audio-video generation model. See the [documentation](/docs/en/Model_Details/LTX-2.md) for details. Support for model training will be implemented in the future.
+
 - **February 2, 2026** The first document of the Research Tutorial series is now available, guiding you through training a small 0.1B text-to-image model from scratch. For details, see the [documentation](/docs/en/Research_Tutorial/train_from_scratch.md) and [model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). We hope DiffSynth-Studio can evolve into a more powerful training framework for Diffusion models.
@@ -57,9 +76,6 @@ We believe that a well-developed open-source code framework can lower the thresh
 - [Differential LoRA Training](/docs/zh/Training/Differential_LoRA.md): This is a training technique we used in [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), now available for LoRA training of any model.
 - [FP8 Training](/docs/zh/Training/FP8_Precision.md): FP8 can be applied to any non-training model during training, i.e., models with gradients turned off or gradients that only affect LoRA weights.
 
-<details>
-
-<summary>More</summary>
 
 - **November 4, 2025** Supported the [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) model, which is trained based on Wan 2.1 and supports generating corresponding actions based on reference videos.
 
 - **October 30, 2025** Supported the [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which supports text-to-video, image-to-video, and video continuation. This model uses the Wan framework for inference and training in this project.
@@ -341,6 +357,60 @@ Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)
 
 </details>
 
+#### Anima: [/docs/en/Model_Details/Anima.md](/docs/en/Model_Details/Anima.md)
+
+<details>
+
+<summary>Quick Start</summary>
+
+Run the following code to quickly load the [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 8GB VRAM.
+
+```python
+from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
+import torch
+
+vram_config = {
+    "offload_dtype": "disk",
+    "offload_device": "disk",
+    "onload_dtype": "disk",
+    "onload_device": "disk",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
+pipe = AnimaImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config),
+        ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config),
+        ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config),
+    ],
+    tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+    tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
+    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait."
+negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
+image = pipe(prompt, seed=0, num_inference_steps=50)
+image.save("image.jpg")
+```
+
+</details>
+
+<details>
+
+<summary>Examples</summary>
+
+Example code for Anima is located at: [/examples/anima/](/examples/anima/)
+
+| Model ID | Inference | Low VRAM Inference | Full Training | Validation after Full Training | LoRA Training | Validation after LoRA Training |
+|-|-|-|-|-|-|-|
+|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](/examples/anima/model_inference/anima-preview.py)|[code](/examples/anima/model_inference_low_vram/anima-preview.py)|[code](/examples/anima/model_training/full/anima-preview.sh)|[code](/examples/anima/model_training/validate_full/anima-preview.py)|[code](/examples/anima/model_training/lora/anima-preview.sh)|[code](/examples/anima/model_training/validate_lora/anima-preview.py)|
+
+</details>
 
 #### Qwen-Image: [/docs/en/Model_Details/Qwen-Image.md](/docs/en/Model_Details/Qwen-Image.md)
 
 <details>
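A note on the `vram_limit` expression that recurs in these quick-start examples: `torch.cuda.mem_get_info` returns a `(free_bytes, total_bytes)` tuple for the device, so the examples budget the card's total VRAM minus roughly 0.5 GB of headroom. A minimal sketch, assuming a CUDA device is present (the variable names here are illustrative):

```python
import torch

# torch.cuda.mem_get_info("cuda") -> (free_bytes, total_bytes); index [1] is total VRAM.
total_bytes = torch.cuda.mem_get_info("cuda")[1]
vram_limit_gb = total_bytes / (1024 ** 3) - 0.5  # total VRAM in GiB, minus ~0.5 GiB headroom
print(f"VRAM budget passed to the pipelines: {vram_limit_gb:.1f} GB")
```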
@@ -420,9 +490,12 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q
 |[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
 |[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
 |[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
+|[FireRedTeam/FireRed-Image-Edit-1.0](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.0)|[code](/examples/qwen_image/model_inference/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.0.py)|
+|[FireRedTeam/FireRed-Image-Edit-1.1](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.1)|[code](/examples/qwen_image/model_inference/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.1.sh)|[code](/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.1.sh)|[code](/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.1.py)|
 |[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
 |[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
 |[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
+|[DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control-V2.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control-V2.py)|
 |[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
 |[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
 |[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
@@ -529,6 +602,143 @@ Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/)
 
 </details>
 
+#### ERNIE-Image: [/docs/en/Model_Details/ERNIE-Image.md](/docs/en/Model_Details/ERNIE-Image.md)
+
+<details>
+
+<summary>Quick Start</summary>
+
+Running the following code will quickly load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
+
+```python
+from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
+import torch
+
+vram_config = {
+    "offload_dtype": torch.bfloat16,
+    "offload_device": "cpu",
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
+pipe = ErnieImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
+        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
+        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+    ],
+    tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
+    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+image = pipe(
+    prompt="一只黑白相间的中华田园犬",  # "A black-and-white Chinese rural dog"
+    negative_prompt="",
+    height=1024,
+    width=1024,
+    seed=42,
+    num_inference_steps=50,
+    cfg_scale=4.0,
+)
+image.save("output.jpg")
+```
+
+</details>
+
+<details>
+
+<summary>Examples</summary>
+
+Example code for ERNIE-Image is available at: [/examples/ernie_image/](/examples/ernie_image/)
+
+| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
+|-|-|-|-|-|-|-|
+|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
+|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
+
+</details>
+
+#### JoyAI-Image: [/docs/en/Model_Details/JoyAI-Image.md](/docs/en/Model_Details/JoyAI-Image.md)
+
+<details>
+
+<summary>Quick Start</summary>
+
+Running the following code will quickly load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 4GB VRAM.
+
+```python
+from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
+import torch
+from PIL import Image
+from modelscope import dataset_snapshot_download
+
+# Download dataset
+dataset_snapshot_download(
+    dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
+    local_dir="data/diffsynth_example_dataset",
+    allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
+)
+
+vram_config = {
+    "offload_dtype": torch.bfloat16,
+    "offload_device": "cpu",
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
+
+pipe = JoyAIImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
+        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
+        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
+    ],
+    processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
+    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+# Use first sample from dataset
+dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
+prompt = "将裙子改为粉色"  # "Change the dress to pink"
+edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
+
+output = pipe(
+    prompt=prompt,
+    edit_image=edit_image,
+    height=1024,
+    width=1024,
+    seed=0,
+    num_inference_steps=30,
+    cfg_scale=5.0,
+)
+
+output.save("output_joyai_edit_low_vram.png")
+```
+
+</details>
+
+<details>
+
+<summary>Examples</summary>
+
+Example code for JoyAI-Image is available at: [/examples/joyai_image/](/examples/joyai_image/)
+
+| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
+|-|-|-|-|-|-|-|
+|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
+
+</details>
 
 ### Video Synthesis
 
 https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
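The `vram_config` dictionaries shared by these quick-start examples describe where model weights live at each stage of VRAM management (roughly: fully offloaded, staged for upcoming use, being moved into place, and actively computing). The ERNIE-Image and JoyAI-Image examples stage evicted weights in CPU RAM, while the Anima example pushes them all the way to disk for the tightest budgets. A minimal sketch of the two variants, using the same keys as the examples above (the interpretation of the stages is our reading, not an official specification):

```python
import torch

# Weights staged through CPU RAM between GPU computations
# (as in the ERNIE-Image and JoyAI-Image quick starts).
vram_config_cpu_offload = {
    "offload_dtype": torch.bfloat16, "offload_device": "cpu",
    "onload_dtype": torch.bfloat16, "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16, "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16, "computation_device": "cuda",
}

# Weights evicted to disk instead of RAM (as in the Anima quick start);
# slower to reload, but minimizes both VRAM and system-memory pressure.
vram_config_disk_offload = {
    **vram_config_cpu_offload,
    "offload_dtype": "disk", "offload_device": "disk",
    "onload_dtype": "disk", "onload_device": "disk",
}
```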
@@ -556,12 +766,26 @@ vram_config = {
     "computation_dtype": torch.bfloat16,
     "computation_device": "cuda",
 }
+"""
+Official model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
+Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
+For the LTX-2 base models, both the official checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
+and the repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are supported.
+We have repackaged the official checkpoints in the DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of the different submodules,
+and to avoid redundant memory usage when users only want to use part of the model.
+"""
+# Use the repackaged model configs from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading.
 pipe = LTX2AudioVideoPipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
     device="cuda",
     model_configs=[
         ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
-        ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
+        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
+        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
+        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
+        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
+        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
+        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
         ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
     ],
     tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
@@ -569,6 +793,20 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
     vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
 )
+
+# Use the following model configs if you want to initialize the model from the official checkpoints in "Lightricks/LTX-2":
+# pipe = LTX2AudioVideoPipeline.from_pretrained(
+#     torch_dtype=torch.bfloat16,
+#     device="cuda",
+#     model_configs=[
+#         ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
+#         ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
+#         ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
+#     ],
+#     tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
+#     stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
+#     vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+# )
+
 prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
 negative_prompt = (
     "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
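The practical payoff of the repackaged checkpoints is that each submodule lives in its own file, so a workflow that only needs part of the model can, in principle, list only the corresponding files. A hypothetical sketch (whether a pipeline call succeeds with a reduced set depends on which features it exercises; the import path is assumed, and the full example above loads every submodule):

```python
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig  # import path assumed

# Hypothetical reduced load: decoder-side submodules only,
# e.g. for decoding pre-computed audio/video latents.
decoder_only_configs = [
    ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors"),
    ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors"),
    ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors"),
]
```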
@@ -613,7 +851,19 @@ Example code for LTX-2 is available at: [/examples/ltx2/](/examples/ltx2/)
 
 | Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
 |-|-|-|-|-|-|-|-|
-|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
+|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py)|
+|[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-|
+|[Lightricks/LTX-2.3: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-|
+|[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)|
+|[Lightricks/LTX-2.3: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage.py)|-|-|-|-|
+|[Lightricks/LTX-2.3: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-DistilledPipeline.py)|-|-|-|-|
+|[Lightricks/LTX-2.3: A2V](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](/examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py)|-|-|-|-|
+|[Lightricks/LTX-2.3: Retake](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_video`,`retake_video_regions`,`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py)|-|-|-|-|
+|[Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|
+|[Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|
+|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|
+|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
+|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
 |[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
 |[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
 |[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
@@ -728,39 +978,123 @@ graph LR;
|
|||||||
|
|
||||||
Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
|
Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
|
||||||
|
|
||||||
| Model ID | Extra Args | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
| Model ID | Extra Inputs | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||||
|
|-|-|-|-|-|-|-|-|
|
||||||
|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||||
|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||||
|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||||
|
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|
||||||
|
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|
||||||
|
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|
||||||
|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||||
|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||||
|
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|
||||||
|
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|[openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py)|
|[openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py)|
|[Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py)|
|[Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py)|
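
The second column of the table above lists the extra inputs each model's pipeline accepts beyond the prompt arguments. As a minimal, hypothetical sketch of how they are passed (the pipeline object is assumed to be loaded exactly as in the linked inference scripts, and the keyword names are taken straight from the table):

```python
from PIL import Image

def first_last_frame_to_video(pipe, prompt: str, first_frame_path: str, last_frame_path: str):
    # `pipe` is a Wan video pipeline loaded as in the linked inference scripts.
    # `input_image` / `end_image` are the extra inputs listed in the table;
    # rows that list no extra inputs take only the prompt-related arguments.
    return pipe(
        prompt=prompt,
        input_image=Image.open(first_frame_path),
        end_image=Image.open(last_frame_path),
        seed=0,
    )
```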
</details>

### Audio Synthesis

#### ACE-Step: [/docs/en/Model_Details/ACE-Step.md](/docs/en/Model_Details/ACE-Step.md)

<details>

<summary>Quick Start</summary>

Running the following code will quickly load the [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
```python
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch

# VRAM management config: where (and in which dtype) weights live when idle,
# versus when they are prepared and used for computation on the GPU.
vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}

pipe = AceStepPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
        ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
        ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)

prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'

audio = pipe(
    prompt=prompt,
    lyrics=lyrics,
    duration=160,
    bpm=100,
    keyscale="B minor",
    timesignature="4",
    vocal_language="zh",
    seed=42,
)

save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
```
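
A note on the `vram_limit` argument above: `torch.cuda.mem_get_info` returns a `(free_bytes, total_bytes)` tuple, so the expression budgets the card's total VRAM minus 0.5 GB of headroom. A minimal way to inspect the numbers on your own machine:

```python
import torch

free_bytes, total_bytes = torch.cuda.mem_get_info("cuda")
vram_limit_gb = total_bytes / (1024 ** 3) - 0.5  # the same budget as in the example above
print(f"free: {free_bytes / 1024 ** 3:.1f} GB, limit passed to the pipeline: {vram_limit_gb:.1f} GB")
```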

</details>

<details>

<summary>Examples</summary>

Example code for ACE-Step is available at: [/examples/ace_step/](/examples/ace_step/)

| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|

</details>

- Paper: [Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation](https://arxiv.org/abs/2602.03208)
- Sample Code: [/docs/en/Research_Tutorial/inference_time_scaling.md](/docs/en/Research_Tutorial/inference_time_scaling.md)
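
SES steers generation toward higher reward at inference time by searching over the initial noise, without touching model weights. The sketch below is not the SES algorithm itself (which evolves the noise in the frequency domain; see the paper and tutorial above), but the generic reward-guided best-of-N loop such methods build on; `pipe` and `reward_model` are hypothetical stand-ins for any image pipeline in this repo and any image-reward scorer.

```python
def best_of_n(pipe, reward_model, prompt: str, n: int = 8):
    # Generate n candidates from different seeds and keep the highest-reward one.
    best_image, best_score = None, float("-inf")
    for seed in range(n):
        image = pipe(prompt=prompt, seed=seed)
        score = reward_model(image, prompt)
        if score > best_score:
            best_image, best_score = image, score
    return best_image
```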

|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES|
|-|-|-|-|

https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea

</details>

## Contact Us

|Discord: https://discord.gg/Mm9suEeUDc|
|-|
|<img width="160" height="160" alt="Image" src="https://github.com/user-attachments/assets/29bdc97b-e35d-4fea-88d6-32e35182e458" />|
422
README_zh.md
[](https://github.com/modelscope/DiffSynth-Studio/issues)
[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
[](https://discord.gg/Mm9suEeUDc)

[Switch to English](./README.md)

## Introduction

> DiffSynth-Studio documentation: [Chinese version](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/), [English version](https://diffsynth-studio-doc.readthedocs.io/en/latest/)

Welcome to the magical world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by the [ModelScope Community](https://www.modelscope.cn/) team. We hope to incubate technological innovation through framework building, bring together the strength of the open-source community, and explore the boundaries of generative model technology!

DiffSynth currently includes two open-source projects:
* ModelScope AIGC zone (for users in China): https://modelscope.cn/aigc/home
* ModelScope Civision (for global users): https://modelscope.ai/civision/home

We believe that a mature open-source code framework lowers the barrier to technical exploration. Building on this codebase, we have produced quite a few [interesting techniques](#创新成果). Perhaps you also have plenty of bold ideas; with DiffSynth-Studio, you can bring them to life quickly. To that end, we have prepared detailed documentation for developers, hoping it helps you understand how Diffusion models work, and we look forward to pushing the boundaries of the technology together with you.

## Update History

> DiffSynth-Studio has gone through a major version update, and some legacy features are no longer maintained. If you need them, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) released before the major update.

> The project currently has a limited number of developers; most of the work is handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804), so new features progress slowly and issues may take a while to be answered and resolved. We are very sorry about this and ask for the community's understanding.

- **April 23, 2026** ACE-Step is open-sourced, welcome to the audio generation model family! It supports text-to-music inference, low-VRAM inference, and LoRA training. See the [documentation](/docs/zh/Model_Details/ACE-Step.md) and [example code](/examples/ace_step/) for details.

- **April 14, 2026** JoyAI-Image is open-sourced, welcome to the image editing model family! It supports instruction-guided image editing inference, low-VRAM inference, and training. See the [documentation](/docs/zh/Model_Details/JoyAI-Image.md) and [example code](/examples/joyai_image/) for details.

- **March 19, 2026** Added support for the [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including full training and inference capabilities. The [documentation](/docs/zh/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.

- **March 12, 2026** Added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. Supported features include text-to-audio-video, image-to-audio-video, IC-LoRA control, audio-to-video, and local audio-video inpainting; the framework provides full inference and training support. See the [documentation](/docs/zh/Model_Details/LTX-2.md) and [example code](/examples/ltx2/) for details.

- **March 3, 2026** We released the [DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2) model, an updated version of Qwen-Image-Layered-Control. In addition to the originally supported text guidance, it adds brush-controlled layer decomposition.

- **March 2, 2026** Added support for [Anima](https://modelscope.cn/models/circlestone-labs/Anima); see the [documentation](docs/zh/Model_Details/Anima.md). This is an interesting anime-style image generation model, and we look forward to its future model updates.

<details>

<summary>More</summary>

- **February 26, 2026** Added full fine-tuning and LoRA training support for the [LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) audio-video generation model; see the [documentation](docs/zh/Model_Details/LTX-2.md).

- **February 10, 2026** Added inference support for the [LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) audio-video generation model; see the [documentation](docs/zh/Model_Details/LTX-2.md). Training support will follow.

- **February 2, 2026** The first Research Tutorial document is online, walking you through training a small 0.1B text-to-image model from scratch; see the [documentation](/docs/zh/Research_Tutorial/train_from_scratch.md) and the [model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). We hope DiffSynth-Studio grows into an ever more capable Diffusion model training framework.

- [Differential LoRA training](/docs/zh/Training/Differential_LoRA.md): a training technique we used in [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), now available for LoRA training of any model.
- [FP8 training](/docs/zh/Training/FP8_Precision.md): during training, FP8 can be applied to any non-trained model, i.e. any model whose gradients are disabled or whose gradients only affect LoRA weights; see the sketch below.
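
As a conceptual sketch of why FP8 saves memory here (plain PyTorch FP8 storage with just-in-time upcasting, not DiffSynth-Studio's internal implementation): a frozen weight is stored in 8 bits and only cast up when used, roughly halving its footprint versus BF16.

```python
import torch

# Store a frozen weight in FP8 (1 byte per element instead of 2 for BF16),
# then upcast to BF16 just in time for computation. Requires PyTorch >= 2.1.
w_bf16 = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
w_fp8 = w_bf16.to(torch.float8_e4m3fn)      # storage only; no gradients involved
x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
y = x @ w_fp8.to(torch.bfloat16).T          # upcast at compute time
print(w_fp8.element_size(), w_bf16.element_size())  # 1 vs 2 bytes per element
```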

- **November 4, 2025** Added support for the [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) model, which is trained on top of Wan 2.1 and generates motion that follows a reference video.

- **October 30, 2025** Added support for the [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which supports text-to-video, image-to-video, and video continuation. In this project it reuses the Wan framework for inference and training.

</details>

#### Anima: [/docs/zh/Model_Details/Anima.md](/docs/zh/Model_Details/Anima.md)

<details>

<summary>Quick Start</summary>

Running the following code will quickly load the [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 8GB VRAM.
```python
|
||||||
|
from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": "disk",
|
||||||
|
"offload_device": "disk",
|
||||||
|
"onload_dtype": "disk",
|
||||||
|
"onload_device": "disk",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = AnimaImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
|
||||||
|
tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait."
|
||||||
|
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||||
|
image = pipe(prompt, seed=0, num_inference_steps=50)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
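
For reference, the `vram_limit` expression above resolves to the device's total memory minus a small safety margin; a standalone snippet showing the same computation, assuming a CUDA device is available:

```python
import torch

# torch.cuda.mem_get_info returns (free_bytes, total_bytes) for the device.
free_bytes, total_bytes = torch.cuda.mem_get_info("cuda")
vram_limit_gib = total_bytes / (1024 ** 3) - 0.5  # budget: total VRAM minus 0.5 GiB
print(f"free: {free_bytes / 1024 ** 3:.1f} GiB, budget: {vram_limit_gib:.1f} GiB")
```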

</details>

<details>

<summary>Example Code</summary>

Example code for Anima is located at: [/examples/anima/](/examples/anima/)

|Model ID|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](/examples/anima/model_inference/anima-preview.py)|[code](/examples/anima/model_inference_low_vram/anima-preview.py)|[code](/examples/anima/model_training/full/anima-preview.sh)|[code](/examples/anima/model_training/validate_full/anima-preview.py)|[code](/examples/anima/model_training/lora/anima-preview.sh)|[code](/examples/anima/model_training/validate_lora/anima-preview.py)|

</details>

#### Qwen-Image: [/docs/zh/Model_Details/Qwen-Image.md](/docs/zh/Model_Details/Qwen-Image.md)

<details>

<summary>Example Code</summary>

Example code for Qwen-Image is located at: [/examples/qwen_image/](/examples/qwen_image/)

|Model ID|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[FireRedTeam/FireRed-Image-Edit-1.0](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.0)|[code](/examples/qwen_image/model_inference/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.0.py)|
|[FireRedTeam/FireRed-Image-Edit-1.1](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.1)|[code](/examples/qwen_image/model_inference/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.1.sh)|[code](/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.1.py)|[code](/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.1.sh)|[code](/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.1.py)|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control-V2.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control-V2.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|

</details>

#### ERNIE-Image: [/docs/zh/Model_Details/ERNIE-Image.md](/docs/zh/Model_Details/ERNIE-Image.md)

<details>

<summary>Quick Start</summary>

Run the following code to quickly load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model and run inference. VRAM management is enabled: the framework automatically decides how model parameters are loaded based on the remaining VRAM, so the model runs on as little as 3 GB of VRAM.

```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch

# Keep parameters in bfloat16 on the CPU and move them to the GPU on demand.
vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
    # Total GPU memory in GiB, minus a 0.5 GiB safety margin.
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)

image = pipe(
    prompt="一只黑白相间的中华田园犬",  # "A black-and-white Chinese rural dog"
    negative_prompt="",
    height=1024,
    width=1024,
    seed=42,
    num_inference_steps=50,
    cfg_scale=4.0,
)
image.save("output.jpg")
```
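
The Turbo variant listed in the table below loads through the same pipeline. A minimal sketch, continuing from the quick-start above (reusing its imports and `vram_config`); the file layout and step count here are our assumptions, not confirmed values:

```python
# Hypothetical: the Turbo repo is assumed to keep the same file layout as
# PaddlePaddle/ERNIE-Image; `vram_config` is the dict defined above.
pipe = ErnieImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PaddlePaddle/ERNIE-Image-Turbo", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
        ModelConfig(model_id="PaddlePaddle/ERNIE-Image-Turbo", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
        ModelConfig(model_id="PaddlePaddle/ERNIE-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image-Turbo", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# Distilled/turbo models typically need far fewer steps; 8 is illustrative.
image = pipe(prompt="a dog running on grass", height=1024, width=1024, seed=42, num_inference_steps=8)
image.save("output_turbo.jpg")
```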

</details>

<details>

<summary>Example Code</summary>

Example code for ERNIE-Image is located at: [/examples/ernie_image/](/examples/ernie_image/)

|Model ID|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|-|-|-|-|

</details>

#### JoyAI-Image: [/docs/zh/Model_Details/JoyAI-Image.md](/docs/zh/Model_Details/JoyAI-Image.md)

<details>

<summary>Quick Start</summary>

Run the following code to quickly load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model and run inference. VRAM management is enabled: the framework automatically decides how model parameters are loaded based on the remaining VRAM, so the model runs on as little as 4 GB of VRAM.

```python
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download

# Download the example dataset
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
    local_dir="data/diffsynth_example_dataset",
    allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)

# Keep parameters in bfloat16 on the CPU and move them to the GPU on demand.
vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}

pipe = JoyAIImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
    ],
    processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
    # Total GPU memory in GiB, minus a 0.5 GiB safety margin.
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)

# Use the first sample from the dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"  # "Change the dress to pink"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")

output = pipe(
    prompt=prompt,
    edit_image=edit_image,
    height=1024,
    width=1024,
    seed=0,
    num_inference_steps=30,
    cfg_scale=5.0,
)

output.save("output_joyai_edit_low_vram.png")
```

</details>

<details>

<summary>Example Code</summary>

Example code for JoyAI-Image is located at: [/examples/joyai_image/](/examples/joyai_image/)

|Model ID|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|

</details>

### Video Generation Models

https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314

#### LTX-2: [/docs/zh/Model_Details/LTX-2.md](/docs/zh/Model_Details/LTX-2.md)

<details>

<summary>Quick Start</summary>

```python
# (Imports and the opening of this example are elided in this excerpt;
# see /examples/ltx2/ for the complete scripts.)
vram_config = {
    # ... offload / onload / preparing settings elided in this excerpt;
    # they follow the same pattern as the quick-start examples above ...
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}

"""
Official model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage

For the LTX-2 base models, both the official checkpoint
(ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
and the repackaged checkpoints
(ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors"))
are supported.

We have repackaged the official checkpoints in the DiffSynth-Studio/LTX-2-Repackage repo
to support separate loading of the individual submodules and to avoid redundant memory
usage when users only need part of the model.
"""
# Use the repackaged ModelConfigs from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading.
pipe = LTX2AudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
        ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
    # Total GPU memory in GiB, minus a 0.5 GiB safety margin.
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)

# Use the following ModelConfigs instead to initialize the model from the official checkpoints in "Lightricks/LTX-2".
# pipe = LTX2AudioVideoPipeline.from_pretrained(
#     torch_dtype=torch.bfloat16,
#     device="cuda",
#     model_configs=[
#         ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
#         ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
#         ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
#     ],
#     tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
#     stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
#     vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
# )

prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
negative_prompt = (
    "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
    # ... (the remainder of the negative prompt and the inference call are elided
    # in this excerpt; see /examples/ltx2/ for the complete scripts)
)
```

</details>

<details>

<summary>Example Code</summary>

Example code for LTX-2 is located at: [/examples/ltx2/](/examples/ltx2/)

|Model ID|Extra Parameters|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|-|
|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py)|
|[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2.3: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)|
|[Lightricks/LTX-2.3: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2.3: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2.3: A2V](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](/examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2.3: Retake](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_video`,`retake_video_regions`,`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py)|-|-|-|-|
|[Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|
|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|

</details>

Example code for Wan is located at: [/examples/wanvideo/](/examples/wanvideo/)

|Model ID|Extra Parameters|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|
||||||
|
|[openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py)|
|
||||||
|
|[openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py)|
|
||||||
|
|[Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py)|
|
||||||
|
|[Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py)|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### Audio Generation Models

#### ACE-Step: [/docs/zh/Model_Details/ACE-Step.md](/docs/zh/Model_Details/ACE-Step.md)

<details>

<summary>Quick Start</summary>

Run the following code to quickly load the [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) model and run inference. VRAM management is enabled: the framework automatically controls how model parameters are loaded based on the remaining GPU memory, so the model runs with as little as 3 GB of VRAM.

```python
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch

# Keep parameters offloaded on CPU in bfloat16; move them to the GPU only for computation.
vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}

pipe = AceStepPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
        ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
        ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
    # Cap VRAM usage at the GPU's total memory minus a 0.5 GB safety margin.
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)

prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
    prompt=prompt,
    lyrics=lyrics,
    duration=160,
    bpm=100,
    keyscale="B minor",
    timesignature="4",
    vocal_language="zh",
    seed=42,
)

save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
```
</details>
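
The pipeline object built in the quick start can be reused for further generations without reloading any weights. Below is a minimal follow-up sketch that only changes the seed and duration; every call parameter is taken from the quick start above, and the output filename is arbitrary.

```python
# Reuse the already-loaded pipeline for a second take: only seed and duration change.
audio = pipe(
    prompt=prompt,
    lyrics=lyrics,
    duration=120,
    bpm=100,
    keyscale="B minor",
    timesignature="4",
    vocal_language="zh",
    seed=7,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-take2.wav")
```
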
<details>

<summary>Example Code</summary>

Example code for ACE-Step is located at: [/examples/ace_step/](/examples/ace_step/)

| Model ID | Inference | Low-VRAM Inference | Full Training | Validation after Full Training | LoRA Training | Validation after LoRA Training |
|-|-|-|-|-|-|-|
|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|

</details>

@@ -774,7 +1108,7 @@ DiffSynth-Studio is not just an engineering-oriented model framework, but also an incubator of innovation

- Paper: [Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation](https://arxiv.org/abs/2602.03208)
- Code example: [/docs/en/Research_Tutorial/inference_time_scaling.md](/docs/en/Research_Tutorial/inference_time_scaling.md)

|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES|
|-|-|-|-|
@@ -920,3 +1254,9 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47

https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea

</details>

## Contact Us

|Discord: https://discord.gg/Mm9suEeUDc|
|-|
|<img width="160" height="160" alt="Image" src="https://github.com/user-attachments/assets/29bdc97b-e35d-4fea-88d6-32e35182e458" />|

@@ -1,2 +1,2 @@
from .model_configs import MODEL_CONFIGS
from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS

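Each entry added in the hunks below is keyed by a `model_hash` fingerprint of the checkpoint's state dict, which is how the loader decides which model class and state-dict converter to instantiate. As a rough illustration of how such a registry can be resolved, here is a minimal sketch; the hash function is a stand-in assumption, since the real hashing scheme lives elsewhere in diffsynth and is not part of this diff.

```python
import hashlib

def hash_state_dict_keys(state_dict):
    # Stand-in assumption: fingerprint a checkpoint by its sorted parameter names.
    # diffsynth's actual hashing scheme may differ; this only illustrates the idea.
    joined = ",".join(sorted(state_dict.keys()))
    return hashlib.md5(joined.encode("utf-8")).hexdigest()

def find_model_configs(state_dict, model_configs):
    # Several entries may share one hash (e.g. the ACE-Step DiT and conditioner
    # below both use "ba29d8bddbb6ace65675f6a757a13c00" because they are packed
    # in the same file), so return every match.
    fingerprint = hash_state_dict_keys(state_dict)
    return [cfg for cfg in model_configs if cfg["model_hash"] == fingerprint]
```
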
@@ -307,6 +307,13 @@ wan_series = [
        "model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors")
        "model_hash": "eb18873fc0ba77b541eb7b62dbcd2059",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'wantodance_enable_music_inject': True, 'wantodance_music_inject_layers': [0, 4, 8, 12, 16, 20, 24, 27], 'wantodance_enable_refimage': True, 'has_ref_conv': True, 'wantodance_enable_refface': False, 'wantodance_enable_global': True, 'wantodance_enable_dynamicfps': True, 'wantodance_enable_unimodel': True}
    },
]

flux_series = [
@@ -534,6 +541,22 @@ flux2_series = [
    },
]

ernie_image_series = [
    {
        # Example: ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
        "model_hash": "584c13713849f1af4e03d5f1858b8b7b",
        "model_name": "ernie_image_dit",
        "model_class": "diffsynth.models.ernie_image_dit.ErnieImageDiT",
    },
    {
        # Example: ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors")
        "model_hash": "404ed9f40796a38dd34c1620f1920207",
        "model_name": "ernie_image_text_encoder",
        "model_class": "diffsynth.models.ernie_image_text_encoder.ErnieImageTextEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ernie_image_text_encoder.ErnieImageTextEncoderStateDictConverter",
    },
]

z_image_series = [
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
@@ -597,8 +620,22 @@ z_image_series = [
        "extra_kwargs": {"model_size": "0.6B"},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
    },
    {
        # To ensure compatibility with the `model.diffusion_model` prefix introduced by other frameworks.
        "model_hash": "8cf241a0d32f93d5de368502a086852f",
        "model_name": "z_image_dit",
        "model_class": "diffsynth.models.z_image_dit.ZImageDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_dit.ZImageDiTStateDictConverter",
    },
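    # Illustration (an assumption about ZImageDiTStateDictConverter's behavior,
    # not a verbatim excerpt): converters for this case typically map keys like
    # "model.diffusion_model.<name>" back to "<name>" so the weights match
    # ZImageDiT's native parameter names.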
]

"""
|
||||||
|
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
|
||||||
|
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
|
||||||
|
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
|
||||||
|
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
|
||||||
|
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
|
||||||
|
and avoid redundant memory usage when users only want to use part of the model.
|
||||||
|
"""
|
||||||
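# Sketch of the two loading styles described above, using the exact patterns
# from this file's "# Example:" comments:
#     Official single file (all submodules in one checkpoint):
#         ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
#     Repackaged (load only what you need, e.g. just the video VAE decoder):
#         ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors")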
ltx2_series = [
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
@@ -607,6 +644,13 @@ ltx2_series = [
        "model_class": "diffsynth.models.ltx2_dit.LTXModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors")
        "model_hash": "c567aaa37d5ed7454c73aa6024458661",
        "model_name": "ltx2_dit",
        "model_class": "diffsynth.models.ltx2_dit.LTXModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
@@ -614,6 +658,13 @@
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors")
        "model_hash": "7f7e904a53260ec0351b05f32153754b",
        "model_name": "ltx2_video_vae_encoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
@@ -621,6 +672,13 @@
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors")
        "model_hash": "dc6029ca2825147872b45e35a2dc3a97",
        "model_name": "ltx2_video_vae_decoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
@@ -628,6 +686,13 @@
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors")
        "model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
        "model_name": "ltx2_audio_vae_decoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
@@ -635,16 +700,37 @@
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors")
        "model_hash": "f471360f6b24bef702ab73133d9f8bb9",
        "model_name": "ltx2_audio_vocoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_audio_vae_encoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
        "model_hash": "29338f3b95e7e312a3460a482e4f4554",
        "model_name": "ltx2_audio_vae_encoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_text_encoder_post_modules",
        "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors")
        "model_hash": "981629689c8be92a712ab3c5eb4fc3f6",
        "model_name": "ltx2_text_encoder_post_modules",
        "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
@@ -662,5 +748,254 @@
        "model_name": "ltx2_latent_upsampler",
        "model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
        "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
        "model_name": "ltx2_dit",
        "model_class": "diffsynth.models.ltx2_dit.LTXModel",
        "extra_kwargs": {"apply_gated_attention": True, "cross_attention_adaln": True, "caption_channels": None},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
        "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
        "model_name": "ltx2_video_vae_encoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
        "extra_kwargs": {"encoder_version": "ltx-2.3"},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
        "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
        "model_name": "ltx2_video_vae_decoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
        "extra_kwargs": {"decoder_version": "ltx-2.3"},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
        "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
        "model_name": "ltx2_audio_vae_decoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
        "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
        "model_name": "ltx2_audio_vocoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2VocoderWithBWE",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
        "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
        "model_name": "ltx2_audio_vae_encoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
        "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
        "model_name": "ltx2_text_encoder_post_modules",
        "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
        "extra_kwargs": {"separated_audio_video": True, "embedding_dim_gemma": 3840, "num_layers_gemma": 49, "video_attention_heads": 32, "video_attention_head_dim": 128, "audio_attention_heads": 32, "audio_attention_head_dim": 64, "num_connector_layers": 8, "apply_gated_attention": True},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
        "model_hash": "aed408774d694a2452f69936c32febb5",
        "model_name": "ltx2_latent_upsampler",
        "model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
        "extra_kwargs": {"rational_resampler": False},
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="transformer.safetensors")
        "model_hash": "1c55afad76ed33c112a2978550b524d1",
        "model_name": "ltx2_dit",
        "model_class": "diffsynth.models.ltx2_dit.LTXModel",
        "extra_kwargs": {"apply_gated_attention": True, "cross_attention_adaln": True, "caption_channels": None},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_encoder.safetensors")
        "model_hash": "eecdc07c2ec30863b8a2b8b2134036cf",
        "model_name": "ltx2_video_vae_encoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
        "extra_kwargs": {"encoder_version": "ltx-2.3"},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_decoder.safetensors")
        "model_hash": "deda2f542e17ee25bc8c38fd605316ea",
        "model_name": "ltx2_video_vae_decoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
        "extra_kwargs": {"decoder_version": "ltx-2.3"},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors")
        "model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
        "model_name": "ltx2_audio_vae_decoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
        "model_hash": "29338f3b95e7e312a3460a482e4f4554",
        "model_name": "ltx2_audio_vae_encoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors")
        "model_hash": "cd436c99e69ec5c80f050f0944f02a15",
        "model_name": "ltx2_audio_vocoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2VocoderWithBWE",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors")
        "model_hash": "05da2aab1c4b061f72c426311c165a43",
        "model_name": "ltx2_text_encoder_post_modules",
        "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
        "extra_kwargs": {"separated_audio_video": True, "embedding_dim_gemma": 3840, "num_layers_gemma": 49, "video_attention_heads": 32, "video_attention_head_dim": 128, "audio_attention_heads": 32, "audio_attention_head_dim": 64, "num_connector_layers": 8, "apply_gated_attention": True},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
    },
]

anima_series = [
    {
        # Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors")
        "model_hash": "a9995952c2d8e63cf82e115005eb61b9",
        "model_name": "z_image_text_encoder",
        "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
        "extra_kwargs": {"model_size": "0.6B"},
    },
    {
        # Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors")
        "model_hash": "417673936471e79e31ed4d186d7a3f4a",
        "model_name": "anima_dit",
        "model_class": "diffsynth.models.anima_dit.AnimaDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.anima_dit.AnimaDiTStateDictConverter",
    }
]

mova_series = [
    # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors")
    {
        "model_hash": "8c57e12790e2c45a64817e0ce28cde2f",
        "model_name": "mova_audio_dit",
        "model_class": "diffsynth.models.mova_audio_dit.MovaAudioDit",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1], 'in_dim': 128, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 128, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
    },
    # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors")
    {
        "model_hash": "418517fb2b4e919d2cac8f314fcf82ac",
        "model_name": "mova_audio_vae",
        "model_class": "diffsynth.models.mova_audio_vae.DacVAE",
    },
    # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors")
    {
        "model_hash": "d1139dbbc8b4ab53cf4b4243d57bbceb",
        "model_name": "mova_dual_tower_bridge",
        "model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
    },
]

joyai_image_series = [
    {
        # Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth")
        "model_hash": "56592ddfd7d0249d3aa527d24161a863",
        "model_name": "joyai_image_dit",
        "model_class": "diffsynth.models.joyai_image_dit.JoyAIImageDiT",
    },
    {
        # Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model-*.safetensors")
        "model_hash": "2d11bf14bba8b4e87477c8199a895403",
        "model_name": "joyai_image_text_encoder",
        "model_class": "diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.joyai_image_text_encoder.JoyAIImageTextEncoderStateDictConverter",
    },
]

ace_step_series = [
    # === Standard DiT variants (24 layers, hidden_size=2048) ===
    # Covers: turbo, turbo-shift1, turbo-shift3, turbo-continuous, base, sft
    # All share an identical state_dict structure → same hash
    {
        # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
        "model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
        "model_name": "ace_step_dit",
        "model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
    },
    # === XL DiT variants (32 layers, hidden_size=2560) ===
    # Covers: xl-base, xl-sft, xl-turbo
    {
        # Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
        "model_hash": "3a28a410c2246f125153ef792d8bc828",
        "model_name": "ace_step_dit",
        "model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
        "extra_kwargs": {
            "hidden_size": 2560,
            "intermediate_size": 9728,
            "num_hidden_layers": 32,
            "num_attention_heads": 32,
            "num_key_value_heads": 8,
            "head_dim": 128,
            "encoder_hidden_size": 2048,
            "layer_types": ["sliding_attention", "full_attention"] * 16,
        },
    },
    # === Conditioner (shared by all DiT variants, same architecture) ===
    {
        # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
        "model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
        "model_name": "ace_step_conditioner",
        "model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
    },
    # === XL Conditioner (same architecture, but the checkpoint includes the XL decoder → different file hash) ===
    {
        # Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
        "model_hash": "3a28a410c2246f125153ef792d8bc828",
        "model_name": "ace_step_conditioner",
        "model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
    },
    # === Qwen3-Embedding (text encoder) ===
    {
        # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors")
        "model_hash": "3509bea17b0e8cffc3dd4a15cc7899d0",
        "model_name": "ace_step_text_encoder",
        "model_class": "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.AceStepTextEncoderStateDictConverter",
    },
    # === VAE (AutoencoderOobleck CNN) ===
    {
        # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
        "model_hash": "51420834e54474986a7f4be0e4d6f687",
        "model_name": "ace_step_vae",
        "model_class": "diffsynth.models.ace_step_vae.AceStepVAE",
    },
    # === Tokenizer (VAE latent discretization: tokenizer + detokenizer) ===
    {
        # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
        "model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
        "model_name": "ace_step_tokenizer",
        "model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
    },
    # === XL Tokenizer (XL models share the same tokenizer architecture) ===
    {
        # Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
        "model_hash": "3a28a410c2246f125153ef792d8bc828",
        "model_name": "ace_step_tokenizer",
        "model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
    },
]


MODEL_CONFIGS = (
    qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series
    + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
)
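
Because several ACE-Step entries intentionally share one file hash (the DiT, conditioner and tokenizer all live in the same checkpoint file), any lookup over this table has to match on `model_name` as well as `model_hash`. A minimal sketch of such a lookup, not the framework's actual loader logic:

def find_model_config(model_hash, model_name=None):
    # Several entries can share a hash (e.g. "ba29d8bddbb6ace65675f6a757a13c00" maps to
    # ace_step_dit, ace_step_conditioner and ace_step_tokenizer), so filter by name too.
    for config in MODEL_CONFIGS:
        if config["model_hash"] == model_hash:
            if model_name is None or config["model_name"] == model_name:
                return config
    return None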
@@ -243,4 +243,108 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
        "transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.anima_dit.AnimaDiT": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.mova_audio_dit.MovaAudioDit": {
        "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
        "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.mova_audio_vae.DacVAE": {
        "diffsynth.models.mova_audio_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ernie_image_dit.ErnieImageDiT": {
        "diffsynth.models.ernie_image_dit.ErnieImageRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ernie_image_text_encoder.ErnieImageTextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.ministral3.modeling_ministral3.Ministral3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.joyai_image_dit.Transformer3DModel": {
        "diffsynth.models.joyai_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.joyai_image_dit.ModulateWan": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionModel": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    # ACE-Step module maps
    "diffsynth.models.ace_step_dit.AceStepDiTModel": {
        "diffsynth.models.ace_step_dit.AceStepDiTLayer": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ace_step_vae.AceStepVAE": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.ace_step_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ace_step_tokenizer.AceStepTokenizer": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "vector_quantize_pytorch.ResidualFSQ": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
}


def QwenImageTextEncoder_Module_Map_Updater():
    current = VRAM_MANAGEMENT_MODULE_MAPS["diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder"]
    from packaging import version
    import transformers
    if version.parse(transformers.__version__) >= version.parse("5.2.0"):
        # Qwen2RMSNorm was renamed to Qwen2_5_VLRMSNorm in transformers 5.2.0, so we need to update the module map accordingly
        current.pop("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm", None)
        current["transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRMSNorm"] = "diffsynth.core.vram.layers.AutoWrappedModule"
    return current


VERSION_CHECKER_MAPS = {
    "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": QwenImageTextEncoder_Module_Map_Updater,
}
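
How the version checker is dispatched is not shown in this hunk; a plausible call site, stated as an assumption rather than the framework's actual logic:

# Hypothetical usage sketch: refresh the module map for the installed
# transformers version before enabling VRAM management.
updater = VERSION_CHECKER_MAPS["diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder"]
module_map = updater()  # returns the map patched for the installed transformers version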
@@ -1,6 +1,9 @@
import math, warnings
import torch, torchvision, imageio, os
import imageio.v3 as iio
from PIL import Image
import torchaudio
from diffsynth.utils.data.audio import read_audio


class DataProcessingPipeline:
@@ -105,27 +108,59 @@ class ToList(DataProcessingOperator):
        return [data]


class FrameSamplerByRateMixin:
    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_rate=24, fix_frame_rate=False):
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.frame_rate = frame_rate
        self.fix_frame_rate = fix_frame_rate

    def get_reader(self, data: str):
        return imageio.get_reader(data)

    def get_available_num_frames(self, reader):
        if not self.fix_frame_rate:
            return reader.count_frames()
        meta_data = reader.get_meta_data()
        total_original_frames = int(reader.count_frames())
        duration = meta_data["duration"] if "duration" in meta_data else total_original_frames / meta_data['fps']
        total_available_frames = math.floor(duration * self.frame_rate)
        return int(total_available_frames)

    def get_num_frames(self, reader):
        num_frames = self.num_frames
        total_frames = self.get_available_num_frames(reader)
        if int(total_frames) < num_frames:
            num_frames = total_frames
        while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
            num_frames -= 1
        return num_frames

    def map_single_frame_id(self, new_sequence_id: int, raw_frame_rate: float, total_raw_frames: int) -> int:
        if not self.fix_frame_rate:
            return new_sequence_id
        target_time_in_seconds = new_sequence_id / self.frame_rate
        raw_frame_index_float = target_time_in_seconds * raw_frame_rate
        frame_id = int(round(raw_frame_index_float))
        frame_id = min(frame_id, total_raw_frames - 1)
        return frame_id
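
A quick sanity check of the remapping arithmetic (illustrative, not part of the source): with `fix_frame_rate=True` and a target rate of 24 fps, output frame 12 lies at t = 12 / 24 = 0.5 s, which in a 30 fps source is raw frame round(0.5 × 30) = 15.

sampler = FrameSamplerByRateMixin(frame_rate=24, fix_frame_rate=True)
assert sampler.map_single_frame_id(12, raw_frame_rate=30.0, total_raw_frames=300) == 15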

class LoadVideo(DataProcessingOperator, FrameSamplerByRateMixin):
    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x, frame_rate=24, fix_frame_rate=False):
        FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)
        # frame_processor is built into the video loader for efficiency.
        self.frame_processor = frame_processor

    def __call__(self, data: str):
        reader = self.get_reader(data)
        raw_frame_rate = reader.get_meta_data()['fps']
        num_frames = self.get_num_frames(reader)
        total_raw_frames = reader.count_frames()
        frames = []
        for frame_id in range(num_frames):
            frame_id = self.map_single_frame_id(frame_id, raw_frame_rate, total_raw_frames)
            frame = reader.get_data(frame_id)
            frame = Image.fromarray(frame)
            frame = self.frame_processor(frame)
@@ -218,3 +253,51 @@ class LoadAudio(DataProcessingOperator):
        import librosa
        input_audio, sample_rate = librosa.load(data, sr=self.sr)
        return input_audio


class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin):

    def __init__(self, num_frames=121, time_division_factor=8, time_division_remainder=1, frame_rate=24, fix_frame_rate=True):
        FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)

    def __call__(self, data: str):
        try:
            reader = self.get_reader(data)
            num_frames = self.get_num_frames(reader)
            duration = num_frames / self.frame_rate
            waveform, sample_rate = torchaudio.load(data)
            target_samples = int(duration * sample_rate)
            current_samples = waveform.shape[-1]
            if current_samples > target_samples:
                waveform = waveform[..., :target_samples]
            elif current_samples < target_samples:
                padding = target_samples - current_samples
                waveform = torch.nn.functional.pad(waveform, (0, padding))
            return waveform, sample_rate
        except Exception:
            warnings.warn(f"Cannot load audio in {data}. The audio will be `None`.")
            return None
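
The trim/pad arithmetic, worked through once with assumed values (not taken from the source):

# Assumed: num_frames=121, frame_rate=24, sample_rate=16000.
duration = 121 / 24                      # ≈ 5.0417 s of audio matching the video clip
target_samples = int(duration * 16000)   # 80666 samples; longer audio is trimmed, shorter is zero-padded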

class LoadPureAudioWithTorchaudio(DataProcessingOperator):

    def __init__(self, target_sample_rate=None, target_duration=None):
        self.target_sample_rate = target_sample_rate
        self.target_duration = target_duration
        self.resample = True if target_sample_rate is not None else False

    def __call__(self, data: str):
        try:
            waveform, sample_rate = read_audio(data, resample=self.resample, resample_rate=self.target_sample_rate)
            if self.target_duration is not None:
                target_samples = int(self.target_duration * sample_rate)
                current_samples = waveform.shape[-1]
                if current_samples > target_samples:
                    waveform = waveform[..., :target_samples]
                elif current_samples < target_samples:
                    padding = target_samples - current_samples
                    waveform = torch.nn.functional.pad(waveform, (0, padding))
            return waveform, sample_rate
        except Exception as e:
            warnings.warn(f"Cannot load audio in '{data}' due to '{e}'. The audio will be `None`.")
            return None
@@ -42,6 +42,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
            max_pixels=1920*1080, height=None, width=None,
            height_division_factor=16, width_division_factor=16,
            num_frames=81, time_division_factor=4, time_division_remainder=1,
            frame_rate=24, fix_frame_rate=False,
    ):
        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
@@ -53,6 +54,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
                (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
                    frame_rate=frame_rate, fix_frame_rate=fix_frame_rate,
                )),
            ])),
        ])
@@ -1,12 +1,32 @@
import torch


try:
    import deepspeed
    _HAS_DEEPSPEED = True
except ModuleNotFoundError:
    _HAS_DEEPSPEED = False


def create_custom_forward(module):
    def custom_forward(*inputs, **kwargs):
        return module(*inputs, **kwargs)
    return custom_forward


def create_custom_forward_use_reentrant(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward


def judge_args_requires_grad(*args):
    for arg in args:
        if isinstance(arg, torch.Tensor) and arg.requires_grad:
            return True
    return False


def gradient_checkpoint_forward(
    model,
    use_gradient_checkpointing,
@@ -14,6 +34,17 @@ def gradient_checkpoint_forward(
    *args,
    **kwargs,
):
    if use_gradient_checkpointing and _HAS_DEEPSPEED and deepspeed.checkpointing.is_configured():
        all_args = args + tuple(kwargs.values())
        if not judge_args_requires_grad(*all_args):
            # Reentrant checkpointing needs at least one input that requires grad;
            # if none does, run the un-checkpointed forward instead.
            model_output = model(*args, **kwargs)
        else:
            model_output = deepspeed.checkpointing.checkpoint(
                create_custom_forward_use_reentrant(model),
                *all_args,
            )
        return model_output
    if use_gradient_checkpointing_offload:
        with torch.autograd.graph.save_on_cpu():
            model_output = torch.utils.checkpoint.checkpoint(
@@ -417,7 +417,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
    def lora_forward(self, x, out):
        if self.lora_merger is None:
            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
                out = out + x @ lora_A.T.to(device=x.device, dtype=x.dtype) @ lora_B.T.to(device=x.device, dtype=x.dtype)
        else:
            lora_output = []
            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
@@ -94,19 +94,22 @@ class BasePipeline(torch.nn.Module):
        return self

    def check_resize_height_width(self, height, width, num_frames=None, verbose=1):
        # Shape check
        if height % self.height_division_factor != 0:
            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
            if verbose > 0:
                print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
        if width % self.width_division_factor != 0:
            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
            if verbose > 0:
                print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
        if num_frames is None:
            return height, width
        else:
            if num_frames % self.time_division_factor != self.time_division_remainder:
                num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
                if verbose > 0:
                    print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
            return height, width, num_frames

@@ -144,6 +147,12 @@ class BasePipeline(torch.nn.Module):
        video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
        return video

    def output_audio_format_check(self, audio_output):
        # Standard output format: [C, T]; output dtype: float.
        # Remove the batch dim if present.
        if audio_output.ndim == 3:
            audio_output = audio_output.squeeze(0)
        return audio_output.float().cpu()

    def load_models_to_device(self, model_names):
        if self.vram_management_enabled:
@@ -330,6 +339,38 @@ class BasePipeline(torch.nn.Module):
            noise_pred = noise_pred_posi
        return noise_pred

    def compile_pipeline(self, mode: str = "default", dynamic: bool = True, fullgraph: bool = False, compile_models: list = None, **kwargs):
        """
        Compile the pipeline with torch.compile. The models to compile are determined by the pipeline's `compilable_models` attribute.
        If a model has a `_repeated_blocks` attribute, those blocks are compiled individually (regional compilation); otherwise the whole model is compiled.
        See https://docs.pytorch.org/docs/stable/generated/torch.compile.html#torch.compile for details about the compilation arguments.

        Args:
            mode: The compilation mode passed to `torch.compile`; options are "default", "reduce-overhead", "max-autotune" and "max-autotune-no-cudagraphs". Defaults to "default".
            dynamic: Whether to enable dynamic-shape compilation to support varying input shapes, passed to `torch.compile`. Defaults to True (recommended).
            fullgraph: Whether to use full-graph compilation, passed to `torch.compile`. Defaults to False (recommended).
            compile_models: The list of model names to compile. If None, the models in `pipeline.compilable_models` are compiled. Defaults to None.
            **kwargs: Other arguments for `torch.compile`.
        """
        compile_models = compile_models or getattr(self, "compilable_models", [])
        if len(compile_models) == 0:
            print("No compilable models in the pipeline. Skip compilation.")
            return
        for name in compile_models:
            model = getattr(self, name, None)
            if model is None:
                print(f"Model '{name}' not found in the pipeline.")
                continue
            repeated_blocks = getattr(model, "_repeated_blocks", None)
            # Regional compilation for repeated blocks.
            if repeated_blocks is not None:
                for submod in model.modules():
                    if submod.__class__.__name__ in repeated_blocks:
                        submod.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
            # Compile the whole model.
            else:
                model.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
            print(f"{name} is compiled with mode={mode}, dynamic={dynamic}, fullgraph={fullgraph}.")
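
A hypothetical call site (the `pipe` variable and the model name "dit" are illustrative; `compilable_models` supplies the default list):

pipe.compile_pipeline(mode="max-autotune", dynamic=True, compile_models=["dit"])

Compiling only the repeated transformer blocks keeps compile time low while still covering the hot loop, which is why regional compilation is preferred when `_repeated_blocks` is available.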

class PipelineUnitGraph:
    def __init__(self):
@@ -4,7 +4,7 @@ from typing_extensions import Literal

class FlowMatchScheduler():

    def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image", "ACE-Step"] = "FLUX.1"):
        self.set_timesteps_fn = {
            "FLUX.1": FlowMatchScheduler.set_timesteps_flux,
            "Wan": FlowMatchScheduler.set_timesteps_wan,
@@ -13,6 +13,8 @@ class FlowMatchScheduler():
            "Z-Image": FlowMatchScheduler.set_timesteps_z_image,
            "LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
            "Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
            "ERNIE-Image": FlowMatchScheduler.set_timesteps_ernie_image,
            "ACE-Step": FlowMatchScheduler.set_timesteps_ace_step,
        }.get(template, FlowMatchScheduler.set_timesteps_flux)
        self.num_train_timesteps = 1000

@@ -129,6 +131,38 @@ class FlowMatchScheduler():
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def set_timesteps_ernie_image(num_inference_steps=50, denoising_strength=1.0, shift=3.0):
        sigma_min = 0.0
        sigma_max = 1.0
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        if shift is not None and shift != 1.0:
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def set_timesteps_ace_step(num_inference_steps=8, denoising_strength=1.0, shift=3.0):
        """ACE-Step Flow Matching scheduler.

        Sigmas range from 1.0 to 0.0; the returned timesteps are the sigmas scaled by 1000.
        Shift transformation: t = shift * t / (1 + (shift - 1) * t)

        Args:
            num_inference_steps: Number of diffusion steps.
            denoising_strength: Denoising strength (1.0 = full denoising).
            shift: Timestep shift parameter (default 3.0 for turbo).
        """
        num_train_timesteps = 1000
        sigma_start = denoising_strength
        sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps + 1)[:-1]
        if shift is not None and shift != 1.0:
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps
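
The shift keeps the endpoints fixed (t = 1 maps to 1, t = 0 to 0) while bending the schedule toward high noise. A worked check, not from the source:

import torch
sigmas = torch.linspace(1.0, 0.0, 9)[:-1]    # num_inference_steps = 8
shifted = 3 * sigmas / (1 + 2 * sigmas)      # shift = 3
print(shifted[0].item(), shifted[4].item())  # 1.0 and 0.75 (raw sigma 0.5 -> 0.75)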
    @staticmethod
    def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
        sigma_min = 0.0
@@ -146,6 +180,18 @@ class FlowMatchScheduler():
            timesteps[timestep_id] = timestep
        return sigmas, timesteps

    @staticmethod
    def set_timesteps_joyai_image(num_inference_steps=100, denoising_strength=1.0, shift=None):
        sigma_min = 0.0
        sigma_max = 1.0
        shift = 4.0 if shift is None else shift
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
        num_train_timesteps = 1000
@@ -28,6 +28,36 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
    return loss


def FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs):
    max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
    min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))

    timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
    timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)

    # video
    noise = torch.randn_like(inputs["input_latents"])
    inputs["video_latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
    training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)

    # audio
    if inputs.get("audio_input_latents") is not None:
        audio_noise = torch.randn_like(inputs["audio_input_latents"])
        inputs["audio_latents"] = pipe.scheduler.add_noise(inputs["audio_input_latents"], audio_noise, timestep)
        training_target_audio = pipe.scheduler.training_target(inputs["audio_input_latents"], audio_noise, timestep)

    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep)

    loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
    loss = loss * pipe.scheduler.training_weight(timestep)
    if inputs.get("audio_input_latents") is not None:
        loss_audio = torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float())
        loss_audio = loss_audio * pipe.scheduler.training_weight(timestep)
        loss = loss + loss_audio
    return loss


def DirectDistillLoss(pipe: BasePipeline, **inputs):
    pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
    pipe.scheduler.training = True
@@ -91,7 +121,9 @@ class TrajectoryImitationLoss(torch.nn.Module):
            progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
            latents_ = trajectory_teacher[progress_id_teacher]

            denom = sigma_ - sigma
            denom = torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6)
            target = (latents_ - inputs_shared["latents"]) / denom
            loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
        return loss
@@ -29,19 +29,19 @@ def launch_training_task(
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
    model.to(device=accelerator.device)
    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
    initialize_deepspeed_gradient_checkpointing(accelerator)
    for epoch_id in range(num_epochs):
        for data in tqdm(dataloader):
            with accelerator.accumulate(model):
                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
    model_logger.on_training_end(accelerator, model, save_steps)
@@ -70,3 +70,19 @@ def launch_data_process_task(
    save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
    data = model(data)
    torch.save(data, save_path)


def initialize_deepspeed_gradient_checkpointing(accelerator: Accelerator):
    if getattr(accelerator.state, "deepspeed_plugin", None) is not None:
        ds_config = accelerator.state.deepspeed_plugin.deepspeed_config
        if "activation_checkpointing" in ds_config:
            import deepspeed
            act_config = ds_config["activation_checkpointing"]
            deepspeed.checkpointing.configure(
                mpu_=None,
                partition_activations=act_config.get("partition_activations", False),
                checkpoint_in_cpu=act_config.get("cpu_checkpointing", False),
                contiguous_checkpointing=act_config.get("contiguous_memory_optimization", False)
            )
        else:
            print("No activation_checkpointing section found in the DeepSpeed config; skipping DeepSpeed gradient checkpointing initialization.")
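
For reference, a minimal DeepSpeed config fragment that would satisfy the check above (the values are assumptions; the key names mirror the `act_config.get(...)` lookups in the function):

ds_config = {
    "activation_checkpointing": {
        "partition_activations": True,
        "cpu_checkpointing": True,
        "contiguous_memory_optimization": False,
    },
}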
@@ -1,9 +1,32 @@
import torch, json, os, inspect
from ..core import ModelConfig, load_state_dict
from ..utils.controlnet import ControlNetInput
from .base_pipeline import PipelineUnit
from peft import LoraConfig, inject_adapter_in_model


class GeneralUnit_RemoveCache(PipelineUnit):
    def __init__(self, required_params=tuple(), force_remove_params_shared=tuple(), force_remove_params_posi=tuple(), force_remove_params_nega=tuple()):
        super().__init__(take_over=True)
        self.required_params = required_params
        self.force_remove_params_shared = force_remove_params_shared
        self.force_remove_params_posi = force_remove_params_posi
        self.force_remove_params_nega = force_remove_params_nega

    def process_params(self, inputs, required_params, force_remove_params):
        inputs_ = {}
        for name, param in inputs.items():
            if name in required_params and name not in force_remove_params:
                inputs_[name] = param
        return inputs_

    def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
        inputs_shared = self.process_params(inputs_shared, self.required_params, self.force_remove_params_shared)
        inputs_posi = self.process_params(inputs_posi, self.required_params, self.force_remove_params_posi)
        inputs_nega = self.process_params(inputs_nega, self.required_params, self.force_remove_params_nega)
        return inputs_shared, inputs_posi, inputs_nega


class DiffusionTrainingModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -231,14 +254,30 @@ class DiffusionTrainingModule(torch.nn.Module):
        setattr(pipe, lora_base_model, model)

    def split_pipeline_units(
        self, task, pipe,
        trainable_models=None, lora_base_model=None,
        # TODO: set `remove_unnecessary_params` to `True` by default
        remove_unnecessary_params=False,
        # TODO: move `loss_required_params` to `loss.py`
        loss_required_params=("input_latents", "max_timestep_boundary", "min_timestep_boundary", "first_frame_latents", "video_latents", "audio_input_latents", "num_inference_steps"),
        force_remove_params_shared=tuple(),
        force_remove_params_posi=tuple(),
        force_remove_params_nega=tuple(),
    ):
        models_require_backward = []
        if trainable_models is not None:
            models_require_backward += trainable_models.split(",")
        if lora_base_model is not None:
            models_require_backward += [lora_base_model]
        if task.endswith(":data_process"):
            other_units, pipe.units = pipe.split_pipeline_units(models_require_backward)
            if remove_unnecessary_params:
                required_params = list(loss_required_params) + [i for i in inspect.signature(self.pipe.model_fn).parameters]
                for unit in other_units:
                    required_params.extend(unit.fetch_input_params())
                required_params = sorted(list(set(required_params)))
                pipe.units.append(GeneralUnit_RemoveCache(required_params, force_remove_params_shared, force_remove_params_posi, force_remove_params_nega))
        elif task.endswith(":train"):
            pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
        return pipe
695 diffsynth/models/ace_step_conditioner.py Normal file
@@ -0,0 +1,695 @@
# Copyright 2025 The ACE-Step Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward

from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutput
from transformers.processing_utils import Unpack
from transformers.utils import can_return_tuple, logging
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3MLP,
    Qwen3RMSNorm,
    Qwen3RotaryEmbedding,
    apply_rotary_pos_emb,
)

logger = logging.get_logger(__name__)


def create_4d_mask(
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
    is_sliding_window: bool = False,
    is_causal: bool = True,
) -> torch.Tensor:
    indices = torch.arange(seq_len, device=device)
    diff = indices.unsqueeze(1) - indices.unsqueeze(0)
    valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
    if is_causal:
        valid_mask = valid_mask & (diff >= 0)
    if is_sliding_window and sliding_window is not None:
        if is_causal:
            valid_mask = valid_mask & (diff <= sliding_window)
        else:
            valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)
    valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)
    if attention_mask is not None:
        padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
        valid_mask = valid_mask & padding_mask_4d
    min_dtype = torch.finfo(dtype).min
    mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
    mask_tensor.masked_fill_(valid_mask, 0.0)
    return mask_tensor
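
A quick illustration of the mask convention (not from the source): attendable positions hold 0 and blocked positions hold the dtype's most negative value, so the mask can be added to attention logits directly.

mask = create_4d_mask(seq_len=3, dtype=torch.float32, device=torch.device("cpu"))
print(mask.shape)       # torch.Size([1, 1, 3, 3])
print(mask[0, 0] == 0)  # lower-triangular True pattern: the causal case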

def pack_sequences(hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor):
    hidden_cat = torch.cat([hidden1, hidden2], dim=1)
    mask_cat = torch.cat([mask1, mask2], dim=1)
    B, L, D = hidden_cat.shape
    sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True)
    hidden_left = torch.gather(hidden_cat, 1, sort_idx.unsqueeze(-1).expand(B, L, D))
    lengths = mask_cat.sum(dim=1)
    new_mask = (torch.arange(L, dtype=torch.long, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1))
    return hidden_left, new_mask
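
A tiny worked example of the left-packing (illustrative): two padded sequences with valid lengths 2 and 1 are concatenated, and the stable argsort moves the three valid tokens to the front.

h1 = torch.arange(6.0).reshape(1, 3, 2)        # tokens a0, a1, pad
h2 = 10 + torch.arange(6.0).reshape(1, 3, 2)   # tokens b0, pad, pad
m1 = torch.tensor([[1, 1, 0]])
m2 = torch.tensor([[1, 0, 0]])
packed, mask = pack_sequences(h1, h2, m1, m2)
print(mask)  # tensor([[ True,  True,  True, False, False, False]])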

class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)


class AceStepAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        attention_bias: bool,
        attention_dropout: float,
        layer_types: list,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
        is_cross_attention: bool = False,
        is_causal: bool = False,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.num_key_value_groups = num_attention_heads // num_key_value_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = attention_dropout
        if is_cross_attention:
            is_causal = False
        self.is_causal = is_causal
        self.is_cross_attention = is_cross_attention

        self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.attention_type = layer_types[layer_idx]
        self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)

        is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None

        if is_cross_attention:
            encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
            if past_key_value is not None:
                is_updated = past_key_value.is_updated.get(self.layer_idx)
                curr_past_key_value = past_key_value.cross_attention_cache
                if not is_updated:
                    key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
                    value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
                    key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
                    past_key_value.is_updated[self.layer_idx] = True
                else:
                    key_states = curr_past_key_value.layers[self.layer_idx].keys
                    value_states = curr_past_key_value.layers[self.layer_idx].values
            else:
                key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
                value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
        else:
            key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            if position_embeddings is not None:
                cos, sin = position_embeddings
                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            if past_key_value is not None:
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        if self.num_key_value_groups > 1:
            key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
            value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)

        attn_output = attention_forward(
            query_states, key_states, value_states,
            q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
            attn_mask=attention_mask,
        )
        attn_weights = None

        attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
|
||||||
|
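# Illustrative sketch (not part of the diff): how the cross-attention K/V cache above is
# reused across calls. `EncoderDecoderCache` and `DynamicCache` are the transformers
# classes imported by this module; all tensor sizes here are made-up assumptions.
#
# >>> cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
# >>> attn = AceStepAttention(hidden_size=64, num_attention_heads=4, num_key_value_heads=2,
# ...                         rms_norm_eps=1e-6, attention_bias=False, attention_dropout=0.0,
# ...                         layer_types=["full_attention"], is_cross_attention=True)
# >>> x, ctx = torch.randn(1, 8, 64), torch.randn(1, 5, 64)
# >>> _ = attn(x, attention_mask=None, past_key_value=cache, encoder_hidden_states=ctx)
# >>> cache.is_updated[0]   # encoder K/V computed once, reused on later calls
# True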
class AceStepEncoderLayer(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        attention_bias: bool,
        attention_dropout: float,
        layer_types: list,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.layer_idx = layer_idx

        self.self_attn = AceStepAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            layer_idx=layer_idx,
            is_cross_attention=False,
            is_causal=False,
        )
        self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

        mlp_config = type('Config', (), {
            'hidden_size': hidden_size,
            'intermediate_size': intermediate_size,
            'hidden_act': 'silu',
        })()
        self.mlp = Qwen3MLP(mlp_config)
        self.attention_type = layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        return outputs
class AceStepLyricEncoder(nn.Module):
    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        use_cache: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        text_hidden_dim: int = 1024,
        num_lyric_encoder_hidden_layers: int = 8,
        **kwargs,
    ):
        super().__init__()
        self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
        self.text_hidden_dim = text_hidden_dim
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")

        self.embed_tokens = nn.Linear(text_hidden_dim, hidden_size)
        self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        rope_config = type('RopeConfig', (), {
            'hidden_size': hidden_size,
            'num_attention_heads': num_attention_heads,
            'num_key_value_heads': num_key_value_heads,
            'head_dim': head_dim,
            'max_position_embeddings': max_position_embeddings,
            'rope_theta': rope_theta,
            'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
            'rms_norm_eps': rms_norm_eps,
            'attention_bias': attention_bias,
            'attention_dropout': attention_dropout,
            'hidden_act': 'silu',
            'intermediate_size': intermediate_size,
            'layer_types': self.layer_types,
            'sliding_window': sliding_window,
            '_attn_implementation': self._attn_implementation,
        })()
        self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
        self.gradient_checkpointing = False

        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                layer_types=self.layer_types,
                head_dim=head_dim,
                sliding_window=sliding_window,
                layer_idx=layer_idx,
            )
            for layer_idx in range(num_lyric_encoder_hidden_layers)
        ])

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutput:
        output_attentions = output_attentions if output_attentions is not None else False
        output_hidden_states = output_hidden_states if output_hidden_states is not None else False

        assert input_ids is None, "Only `inputs_embeds` is supported for the lyric encoder."
        assert attention_mask is not None, "Attention mask must be provided for the lyric encoder."
        assert inputs_embeds is not None, "Inputs embeddings must be provided for the lyric encoder."

        inputs_embeds = self.embed_tokens(inputs_embeds)
        cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        seq_len = inputs_embeds.shape[1]
        dtype = inputs_embeds.dtype
        device = inputs_embeds.device

        full_attn_mask = create_4d_mask(
            seq_len=seq_len, dtype=dtype, device=device,
            attention_mask=attention_mask, sliding_window=None,
            is_sliding_window=False, is_causal=False
        )
        sliding_attn_mask = None
        if self.use_sliding_window:
            sliding_attn_mask = create_4d_mask(
                seq_len=seq_len, dtype=dtype, device=device,
                attention_mask=attention_mask, sliding_window=self.sliding_window,
                is_sliding_window=True, is_causal=False
            )

        self_attn_mask_mapping = {
            "full_attention": full_attn_mask,
            "sliding_attention": sliding_attn_mask,
        }

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for layer_module in self.layers[: self.num_lyric_encoder_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer_module(
                hidden_states, position_embeddings,
                self_attn_mask_mapping[layer_module.attention_type],
                position_ids, output_attentions,
                **flash_attn_kwargs,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
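# Illustrative sketch (not part of the diff): the lyric encoder consumes pre-computed
# text features (`inputs_embeds`), not token ids. The small config and shapes below are
# made-up assumptions.
#
# >>> enc = AceStepLyricEncoder(hidden_size=64, intermediate_size=128,
# ...                           num_attention_heads=4, num_key_value_heads=2, head_dim=16,
# ...                           num_lyric_encoder_hidden_layers=2, sliding_window=4)
# >>> feats = torch.randn(1, 12, 1024)                # [batch, lyric_len, text_hidden_dim]
# >>> out = enc(inputs_embeds=feats, attention_mask=torch.ones(1, 12))
# >>> out.last_hidden_state.shape
# torch.Size([1, 12, 64])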
class AceStepTimbreEncoder(nn.Module):
    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        use_cache: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        timbre_hidden_dim: int = 64,
        num_timbre_encoder_hidden_layers: int = 4,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.timbre_hidden_dim = timbre_hidden_dim
        self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
        self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")

        self.embed_tokens = nn.Linear(timbre_hidden_dim, hidden_size)
        self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        rope_config = type('RopeConfig', (), {
            'hidden_size': hidden_size,
            'num_attention_heads': num_attention_heads,
            'num_key_value_heads': num_key_value_heads,
            'head_dim': head_dim,
            'max_position_embeddings': max_position_embeddings,
            'rope_theta': rope_theta,
            'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
            'rms_norm_eps': rms_norm_eps,
            'attention_bias': attention_bias,
            'attention_dropout': attention_dropout,
            'hidden_act': 'silu',
            'intermediate_size': intermediate_size,
            'layer_types': self.layer_types,
            'sliding_window': sliding_window,
            '_attn_implementation': self._attn_implementation,
        })()
        self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
        self.gradient_checkpointing = False
        self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                layer_types=self.layer_types,
                head_dim=head_dim,
                sliding_window=sliding_window,
                layer_idx=layer_idx,
            )
            for layer_idx in range(num_timbre_encoder_hidden_layers)
        ])
    def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask):
        # Scatter N packed timbre embeddings [N, d] into a padded batch [B, max_count, d],
        # where refer_audio_order_mask[i] gives the batch index of packed entry i.
        N, d = timbre_embs_packed.shape
        device = timbre_embs_packed.device
        dtype = timbre_embs_packed.dtype
        B = int(refer_audio_order_mask.max().item() + 1)
        counts = torch.bincount(refer_audio_order_mask, minlength=B)
        max_count = counts.max().item()
        # Stable sort by batch id, preserving the original order within each batch.
        sorted_indices = torch.argsort(refer_audio_order_mask * N + torch.arange(N, device=device), stable=True)
        sorted_batch_ids = refer_audio_order_mask[sorted_indices]
        positions = torch.arange(N, device=device)
        batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]])
        positions_in_sorted = positions - batch_starts[sorted_batch_ids]
        inverse_indices = torch.empty_like(sorted_indices)
        inverse_indices[sorted_indices] = torch.arange(N, device=device)
        positions_in_batch = positions_in_sorted[inverse_indices]
        # Flatten (batch, slot) pairs into one index and scatter via a one-hot matmul.
        indices_2d = refer_audio_order_mask * max_count + positions_in_batch
        one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(dtype)
        timbre_embs_flat = one_hot.t() @ timbre_embs_packed
        timbre_embs_unpack = timbre_embs_flat.reshape(B, max_count, d)
        mask_flat = (one_hot.sum(dim=0) > 0).long()
        new_mask = mask_flat.reshape(B, max_count)
        return timbre_embs_unpack, new_mask
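    # Illustrative sketch (not part of the diff): unpacking three reference-audio
    # embeddings that belong to two batch items (two refs for item 0, one for item 1).
    # The method never touches `self`, so it can be called unbound; values are made-up.
    #
    # >>> packed = torch.tensor([[1.], [2.], [3.]])     # [N=3, d=1]
    # >>> order = torch.tensor([0, 0, 1])               # batch id of each packed entry
    # >>> embs, mask = AceStepTimbreEncoder.unpack_timbre_embeddings(None, packed, order)
    # >>> embs.squeeze(-1)                              # [B=2, max_count=2]
    # tensor([[1., 2.],
    #         [3., 0.]])
    # >>> mask
    # tensor([[1, 1],
    #         [1, 0]])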
    @can_return_tuple
    def forward(
        self,
        refer_audio_acoustic_hidden_states_packed: Optional[torch.FloatTensor] = None,
        refer_audio_order_mask: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutput:
        inputs_embeds = refer_audio_acoustic_hidden_states_packed
        inputs_embeds = self.embed_tokens(inputs_embeds)
        seq_len = inputs_embeds.shape[1]
        cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
        position_ids = cache_position.unsqueeze(0)

        dtype = inputs_embeds.dtype
        device = inputs_embeds.device

        full_attn_mask = create_4d_mask(
            seq_len=seq_len, dtype=dtype, device=device,
            attention_mask=attention_mask, sliding_window=None,
            is_sliding_window=False, is_causal=False
        )
        sliding_attn_mask = None
        if self.use_sliding_window:
            sliding_attn_mask = create_4d_mask(
                seq_len=seq_len, dtype=dtype, device=device,
                attention_mask=attention_mask, sliding_window=self.sliding_window,
                is_sliding_window=True, is_causal=False
            )

        self_attn_mask_mapping = {
            "full_attention": full_attn_mask,
            "sliding_attention": sliding_attn_mask,
        }

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for layer_module in self.layers[: self.num_timbre_encoder_hidden_layers]:
            layer_outputs = layer_module(
                hidden_states, position_embeddings,
                self_attn_mask_mapping[layer_module.attention_type],
                position_ids,
                **flash_attn_kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
        # Keep only the leading token embedding of each packed reference clip: [N, L, D] -> [N, D]
        hidden_states = hidden_states[:, 0, :]
        timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask)
        return timbre_embs_unpack, timbre_embs_mask
class AceStepConditionEncoder(nn.Module):
    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        use_cache: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        text_hidden_dim: int = 1024,
        timbre_hidden_dim: int = 64,
        num_lyric_encoder_hidden_layers: int = 8,
        num_timbre_encoder_hidden_layers: int = 4,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.text_hidden_dim = text_hidden_dim
        self.timbre_hidden_dim = timbre_hidden_dim
        self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
        self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
        self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")

        self.text_projector = nn.Linear(text_hidden_dim, hidden_size, bias=False)
        self.null_condition_emb = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.lyric_encoder = AceStepLyricEncoder(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            use_sliding_window=use_sliding_window,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            text_hidden_dim=text_hidden_dim,
            num_lyric_encoder_hidden_layers=num_lyric_encoder_hidden_layers,
        )
        self.timbre_encoder = AceStepTimbreEncoder(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            use_sliding_window=use_sliding_window,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            timbre_hidden_dim=timbre_hidden_dim,
            num_timbre_encoder_hidden_layers=num_timbre_encoder_hidden_layers,
        )
    def forward(
        self,
        text_hidden_states: Optional[torch.FloatTensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        lyric_hidden_states: Optional[torch.LongTensor] = None,
        lyric_attention_mask: Optional[torch.Tensor] = None,
        reference_latents: Optional[torch.Tensor] = None,
        refer_audio_order_mask: Optional[torch.LongTensor] = None,
    ):
        text_hidden_states = self.text_projector(text_hidden_states)
        lyric_encoder_outputs = self.lyric_encoder(
            inputs_embeds=lyric_hidden_states,
            attention_mask=lyric_attention_mask,
        )
        lyric_hidden_states = lyric_encoder_outputs.last_hidden_state
        timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder(reference_latents, refer_audio_order_mask)
        # Pack lyric + timbre first, then pack the result with the projected text features.
        encoder_hidden_states, encoder_attention_mask = pack_sequences(
            lyric_hidden_states, timbre_embs_unpack, lyric_attention_mask, timbre_embs_mask
        )
        encoder_hidden_states, encoder_attention_mask = pack_sequences(
            encoder_hidden_states, text_hidden_states, encoder_attention_mask, text_attention_mask
        )
        return encoder_hidden_states, encoder_attention_mask
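# Illustrative sketch (not part of the diff): end-to-end shapes through the condition
# encoder. The tiny config and all sizes are made-up assumptions; note that
# pack_sequences concatenates masks, so each mask must match the dtype of the mask it is
# packed against (the timbre mask is long, the mask pack_sequences returns is bool).
#
# >>> cond = AceStepConditionEncoder(hidden_size=64, intermediate_size=128,
# ...                                num_attention_heads=4, num_key_value_heads=2,
# ...                                head_dim=16, num_lyric_encoder_hidden_layers=2,
# ...                                num_timbre_encoder_hidden_layers=2, sliding_window=4)
# >>> text = torch.randn(1, 6, 1024)                  # projected to hidden_size inside
# >>> lyric = torch.randn(1, 12, 1024)
# >>> refs = torch.randn(2, 20, 64)                   # two packed reference clips
# >>> hs, mask = cond(text, torch.ones(1, 6, dtype=torch.bool),
# ...                 lyric, torch.ones(1, 12, dtype=torch.long),
# ...                 refs, torch.tensor([0, 0]))
# >>> hs.shape                                        # lyric(12) + timbre(2) + text(6)
# torch.Size([1, 20, 64])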
901 diffsynth/models/ace_step_dit.py Normal file
@@ -0,0 +1,901 @@
# Copyright 2025 The ACE-Step Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..core.attention.attention import attention_forward
from ..core import gradient_checkpoint_forward

from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutput
from transformers.processing_utils import Unpack
from transformers.utils import logging

from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3MLP,
    Qwen3RMSNorm,
    Qwen3RotaryEmbedding,
    apply_rotary_pos_emb,
)

logger = logging.get_logger(__name__)
def create_4d_mask(
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,  # [Batch, Seq_Len]
    sliding_window: Optional[int] = None,
    is_sliding_window: bool = False,
    is_causal: bool = True,
) -> torch.Tensor:
    """
    General 4D attention mask generator compatible with CPU/Mac/SDPA and eager mode.

    Supported use cases:
    1. Causal Full: is_causal=True, is_sliding_window=False (standard GPT)
    2. Causal Sliding: is_causal=True, is_sliding_window=True (Mistral/Qwen local window)
    3. Bidirectional Full: is_causal=False, is_sliding_window=False (BERT/Encoder)
    4. Bidirectional Sliding: is_causal=False, is_sliding_window=True (Longformer local)

    Returns:
        [Batch, 1, Seq_Len, Seq_Len] additive mask (0.0 for keep, -inf for mask)
    """
    # ------------------------------------------------------
    # 1. Construct basic geometry mask [Seq_Len, Seq_Len]
    # ------------------------------------------------------

    # Build index matrices
    # i (Query): [0, 1, ..., L-1]
    # j (Key):   [0, 1, ..., L-1]
    indices = torch.arange(seq_len, device=device)
    # diff = i - j
    diff = indices.unsqueeze(1) - indices.unsqueeze(0)

    # Initialize all True (all positions visible)
    valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)

    # (A) Handle causality (Causal)
    if is_causal:
        # i >= j => diff >= 0
        valid_mask = valid_mask & (diff >= 0)

    # (B) Handle sliding window
    if is_sliding_window and sliding_window is not None:
        if is_causal:
            # Causal sliding: only attend to past window steps
            # i - j <= window => diff <= window
            # (diff >= 0 already handled above)
            valid_mask = valid_mask & (diff <= sliding_window)
        else:
            # Bidirectional sliding: attend past and future window steps
            # |i - j| <= window => abs(diff) <= sliding_window
            valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)

    # Expand dimensions to [1, 1, Seq_Len, Seq_Len] for broadcasting
    valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)

    # ------------------------------------------------------
    # 2. Apply padding mask (Key Masking)
    # ------------------------------------------------------
    if attention_mask is not None:
        # attention_mask shape: [Batch, Seq_Len] (1=valid, 0=padding)
        # We want to mask out invalid keys (columns)
        # Expand shape: [Batch, 1, 1, Seq_Len]
        padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)

        # Broadcasting: Geometry Mask [1, 1, L, L] & Padding Mask [B, 1, 1, L]
        # Result shape: [B, 1, L, L]
        valid_mask = valid_mask & padding_mask_4d

    # ------------------------------------------------------
    # 3. Convert to additive mask
    # ------------------------------------------------------
    # Get the minimal value for current dtype
    min_dtype = torch.finfo(dtype).min

    # Create result tensor filled with -inf by default
    mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)

    # Set valid positions to 0.0
    mask_tensor.masked_fill_(valid_mask, 0.0)

    return mask_tensor
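# Illustrative sketch (not part of the diff): two of the four mask geometries on a
# length-4 sequence with a window of 1; positions equal to 0.0 are attendable.
#
# >>> causal = create_4d_mask(4, torch.float32, torch.device("cpu"))
# >>> (causal[0, 0] == 0).int()                  # lower-triangular: standard GPT
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)
# >>> local = create_4d_mask(4, torch.float32, torch.device("cpu"),
# ...                        sliding_window=1, is_sliding_window=True, is_causal=False)
# >>> (local[0, 0] == 0).int()                   # bidirectional band of width 1
# tensor([[1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [0, 1, 1, 1],
#         [0, 0, 1, 1]], dtype=torch.int32)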
def pack_sequences(hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor):
    """
    Pack two sequences by concatenating and sorting them based on mask values.

    Args:
        hidden1: First hidden states tensor of shape [B, L1, D]
        hidden2: Second hidden states tensor of shape [B, L2, D]
        mask1: First mask tensor of shape [B, L1]
        mask2: Second mask tensor of shape [B, L2]

    Returns:
        Tuple of (packed_hidden_states, new_mask) where:
        - packed_hidden_states: Packed hidden states with valid tokens (mask=1) first, shape [B, L1+L2, D]
        - new_mask: New mask tensor indicating valid positions, shape [B, L1+L2]
    """
    # Step 1: Concatenate hidden states and masks along sequence dimension
    hidden_cat = torch.cat([hidden1, hidden2], dim=1)  # [B, L, D]
    mask_cat = torch.cat([mask1, mask2], dim=1)  # [B, L]

    B, L, D = hidden_cat.shape

    # Step 2: Sort indices so that mask values of 1 come before 0
    sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True)  # [B, L]

    # Step 3: Reorder hidden states using sorted indices
    hidden_left = torch.gather(hidden_cat, 1, sort_idx.unsqueeze(-1).expand(B, L, D))

    # Step 4: Create new mask based on valid sequence lengths
    lengths = mask_cat.sum(dim=1)  # [B]
    new_mask = (torch.arange(L, dtype=torch.long, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1))

    return hidden_left, new_mask
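# Illustrative sketch (not part of the diff): packing pushes padded positions of the
# first sequence behind the valid tokens of the second. Values are made-up assumptions.
#
# >>> h1 = torch.tensor([[[1.], [9.]]])              # [B=1, L1=2, D=1]; token 9 is padding
# >>> h2 = torch.tensor([[[2.]]])                    # [B=1, L2=1, D=1]
# >>> m1 = torch.tensor([[1, 0]])
# >>> m2 = torch.tensor([[1]])
# >>> packed, mask = pack_sequences(h1, h2, m1, m2)
# >>> packed.squeeze(-1), mask
# (tensor([[1., 2., 9.]]), tensor([[ True,  True, False]]))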
class TimestepEmbedding(nn.Module):
    """
    Timestep embedding module for diffusion models.

    Converts timestep values into high-dimensional embeddings using sinusoidal
    positional encoding, followed by MLP layers. Used for conditioning diffusion
    models on timestep information.
    """
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        scale: float = 1,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
        self.act1 = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True)
        self.in_channels = in_channels

        self.act2 = nn.SiLU()
        self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6)
        self.scale = scale

    def timestep_embedding(self, t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t: A 1-D tensor of N indices, one per batch element. These may be fractional.
            dim: The dimension of the output embeddings.
            max_period: Controls the minimum frequency of the embeddings.

        Returns:
            An (N, D) tensor of positional embeddings.
        """
        t = t * self.scale
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.in_channels)
        temb = self.linear_1(t_freq.to(t.dtype))
        temb = self.act1(temb)
        temb = self.linear_2(temb)
        timestep_proj = self.time_proj(self.act2(temb)).unflatten(1, (6, -1))
        return temb, timestep_proj
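# Illustrative sketch (not part of the diff): the module returns both the pooled timestep
# embedding and a 6-way projection that AceStepDiTLayer later splits into AdaLN
# (shift, scale, gate) pairs. Sizes are made-up assumptions.
#
# >>> te = TimestepEmbedding(in_channels=256, time_embed_dim=32)
# >>> temb, proj = te(torch.tensor([0.5, 0.9]))
# >>> temb.shape, proj.shape
# (torch.Size([2, 32]), torch.Size([2, 6, 32]))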
class AceStepAttention(nn.Module):
    """
    Multi-headed attention module for the AceStep model.

    Implements the attention mechanism from the 'Attention Is All You Need' paper,
    with support for both self-attention and cross-attention modes. Uses RMSNorm
    for query and key normalization, and supports sliding window attention for
    efficient long-sequence processing.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        attention_bias: bool,
        attention_dropout: float,
        layer_types: list,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
        is_cross_attention: bool = False,
        is_causal: bool = False,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.num_key_value_groups = num_attention_heads // num_key_value_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = attention_dropout
        if is_cross_attention:
            is_causal = False
        self.is_causal = is_causal
        self.is_cross_attention = is_cross_attention

        self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.attention_type = layer_types[layer_idx]
        self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project and normalize query states
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)

        # Determine if this is cross-attention (requires encoder_hidden_states)
        is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None

        # Cross-attention path: attend to encoder hidden states
        if is_cross_attention:
            encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
            if past_key_value is not None:
                is_updated = past_key_value.is_updated.get(self.layer_idx)
                # After the first generated token, we can reuse all key/value states from cache
                curr_past_key_value = past_key_value.cross_attention_cache

                # Conditions for calculating key and value states
                if not is_updated:
                    # Compute and cache K/V for the first time
                    key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
                    value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
                    # Update cache: save all key/value states to cache for fast auto-regressive generation
                    key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
                    # Set flag that this layer's cross-attention cache is updated
                    past_key_value.is_updated[self.layer_idx] = True
                else:
                    # Reuse cached key/value states for subsequent tokens
                    key_states = curr_past_key_value.layers[self.layer_idx].keys
                    value_states = curr_past_key_value.layers[self.layer_idx].values
            else:
                # No cache used, compute K/V directly
                key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
                value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)

        # Self-attention path: attend to the same sequence
        else:
            # Project and normalize key/value states for self-attention
            key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            # Apply rotary position embeddings (RoPE) if provided
            if position_embeddings is not None:
                cos, sin = position_embeddings
                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            # Update cache for auto-regressive generation
            if past_key_value is not None:
                # Sin and cos are specific to RoPE models; cache_position needed for the static cache
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # GQA (grouped-query attention) expansion: if num_key_value_heads < num_attention_heads
        if self.num_key_value_groups > 1:
            key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
            value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)

        # Use DiffSynth unified attention
        # Tensors are already in (batch, heads, seq, dim) format -> "b n s d"
        attn_output = attention_forward(
            query_states, key_states, value_states,
            q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
            attn_mask=attention_mask,
        )

        attn_weights = None  # attention_forward doesn't return weights

        # Flatten and project output: (B, n_heads, seq, dim) -> (B, seq, n_heads*dim)
        attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
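# Illustrative sketch (not part of the diff): the GQA expansion above repeats each K/V
# head `num_key_value_groups` times so the shapes match the query heads. Sizes are made-up.
#
# >>> kv = torch.randn(1, 2, 5, 8)                   # [batch, kv_heads=2, seq, head_dim]
# >>> groups = 3                                     # 6 query heads / 2 kv heads
# >>> expanded = kv.unsqueeze(2).expand(-1, -1, groups, -1, -1).flatten(1, 2)
# >>> expanded.shape
# torch.Size([1, 6, 5, 8])
# >>> torch.equal(expanded[0, 0], expanded[0, 2])    # query heads 0-2 share kv head 0
# True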
class AceStepEncoderLayer(nn.Module):
    """
    Encoder layer for the AceStep model.

    Consists of self-attention and MLP (feed-forward) sub-layers with residual connections.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        intermediate_size: int = 6144,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.layer_idx = layer_idx

        self.self_attn = AceStepAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            layer_idx=layer_idx,
            is_cross_attention=False,
            is_causal=False,
        )
        self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

        # MLP (feed-forward) sub-layer
        self.mlp = Qwen3MLP(
            config=type('Config', (), {
                'hidden_size': hidden_size,
                'intermediate_size': intermediate_size,
                'hidden_act': 'silu',
            })()
        )
        self.attention_type = layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[
        torch.FloatTensor,
        Optional[tuple[torch.FloatTensor, torch.FloatTensor]],
    ]:
        # Self-attention with residual connection
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            # Encoders don't use cache
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # MLP with residual connection
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
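# Illustrative sketch (not part of the diff): the `type('Config', (), {...})()` pattern
# above builds a throwaway object with attribute access, which is all Qwen3MLP reads
# from its config. An equivalent, arguably clearer spelling uses SimpleNamespace:
#
# >>> from types import SimpleNamespace
# >>> cfg = SimpleNamespace(hidden_size=64, intermediate_size=128, hidden_act='silu')
# >>> cfg.hidden_act
# 'silu'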
class AceStepDiTLayer(nn.Module):
    """
    DiT (Diffusion Transformer) layer for the AceStep model.

    Implements a transformer layer with three main components:
    1. Self-attention with adaptive layer norm (AdaLN)
    2. Cross-attention (optional) for conditioning on encoder outputs
    3. Feed-forward MLP with adaptive layer norm

    Uses scale-shift modulation from timestep embeddings for adaptive normalization.
    """
    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        intermediate_size: int,
        rms_norm_eps: float,
        attention_bias: bool,
        attention_dropout: float,
        layer_types: list,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
        use_cross_attention: bool = True,
    ):
        super().__init__()

        self.self_attn_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        self.self_attn = AceStepAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            layer_idx=layer_idx,
        )

        self.use_cross_attention = use_cross_attention
        if self.use_cross_attention:
            self.cross_attn_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
            self.cross_attn = AceStepAttention(
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                layer_types=layer_types,
                head_dim=head_dim,
                sliding_window=sliding_window,
                layer_idx=layer_idx,
                is_cross_attention=True,
            )

        self.mlp_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        self.mlp = Qwen3MLP(
            config=type('Config', (), {
                'hidden_size': hidden_size,
                'intermediate_size': intermediate_size,
                'hidden_act': 'silu',
            })()
        )

        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, hidden_size) / hidden_size**0.5)
        self.attention_type = layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[EncoderDecoderCache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:

        # Extract scale-shift parameters for adaptive layer norm from timestep embeddings
        # 6 values: (shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa)
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
            self.scale_shift_table.to(temb.device) + temb
        ).chunk(6, dim=1)

        # Step 1: Self-attention with adaptive layer norm (AdaLN)
        # Apply adaptive normalization: norm(x) * (1 + scale) + shift
        norm_hidden_states = (self.self_attn_norm(hidden_states) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
        attn_output, self_attn_weights = self.self_attn(
            hidden_states=norm_hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
        # Apply gated residual connection: x = x + attn_output * gate
        hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states)

        # Step 2: Cross-attention (if enabled) for conditioning on encoder outputs
        if self.use_cross_attention:
            norm_hidden_states = self.cross_attn_norm(hidden_states).type_as(hidden_states)
            attn_output, cross_attn_weights = self.cross_attn(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                **kwargs,
            )
            # Standard residual connection for cross-attention
            hidden_states = hidden_states + attn_output

        # Step 3: Feed-forward (MLP) with adaptive layer norm
        # Apply adaptive normalization for MLP: norm(x) * (1 + scale) + shift
        norm_hidden_states = (self.mlp_norm(hidden_states) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
        ff_output = self.mlp(norm_hidden_states)
        # Apply gated residual connection: x = x + mlp_output * gate
        hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs
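# Illustrative sketch (not part of the diff): how the per-layer scale_shift_table and the
# 6-way timestep projection combine into AdaLN (shift, scale, gate) triples. Sizes are
# made-up assumptions; `proj` would come from TimestepEmbedding above.
#
# >>> hidden_size = 32
# >>> table = torch.randn(1, 6, hidden_size) / hidden_size**0.5
# >>> proj = torch.randn(2, 6, hidden_size)          # [batch, 6, hidden]
# >>> shift, scale, gate, c_shift, c_scale, c_gate = (table + proj).chunk(6, dim=1)
# >>> x = torch.randn(2, 10, hidden_size)
# >>> modulated = x * (1 + scale) + shift            # broadcasts over the sequence axis
# >>> modulated.shape
# torch.Size([2, 10, 32])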
class Lambda(nn.Module):
    """
    Wrapper module for arbitrary lambda functions.

    Allows using lambda functions in nn.Sequential by wrapping them in a Module.
    Useful for simple transformations like transpose operations.
    """
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)
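# Illustrative sketch (not part of the diff): Lambda lets the channel-first Conv1d live
# inside an nn.Sequential that takes channel-last [B, L, C] tensors, as proj_in does below.
#
# >>> block = nn.Sequential(
# ...     Lambda(lambda x: x.transpose(1, 2)),       # [B, L, C] -> [B, C, L]
# ...     nn.Conv1d(4, 8, kernel_size=2, stride=2),
# ...     Lambda(lambda x: x.transpose(1, 2)),       # back to [B, L', C']
# ... )
# >>> block(torch.randn(1, 10, 4)).shape
# torch.Size([1, 5, 8])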
class AceStepDiTModel(nn.Module):
    """
    DiT (Diffusion Transformer) model for AceStep.

    Main diffusion model that generates audio latents conditioned on text, lyrics,
    and timbre. Uses patch-based processing with transformer layers, timestep
    conditioning, and cross-attention to encoder outputs.
    """
    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        use_cache: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        patch_size: int = 2,
        in_channels: int = 192,
        audio_acoustic_hidden_dim: int = 64,
        encoder_hidden_size: Optional[int] = None,
        **kwargs,
    ):
        super().__init__()

        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.use_cache = use_cache
        encoder_hidden_size = encoder_hidden_size or hidden_size

        # Rotary position embeddings for transformer layers
        rope_config = type('RopeConfig', (), {
            'hidden_size': hidden_size,
            'num_attention_heads': num_attention_heads,
            'num_key_value_heads': num_key_value_heads,
            'head_dim': head_dim,
            'max_position_embeddings': max_position_embeddings,
            'rope_theta': rope_theta,
            'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
            'rms_norm_eps': rms_norm_eps,
            'attention_bias': attention_bias,
            'attention_dropout': attention_dropout,
            'hidden_act': 'silu',
            'intermediate_size': intermediate_size,
            'layer_types': self.layer_types,
            'sliding_window': sliding_window,
        })()
        self.rotary_emb = Qwen3RotaryEmbedding(rope_config)

        # Stack of DiT transformer layers
        self.layers = nn.ModuleList([
            AceStepDiTLayer(
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                intermediate_size=intermediate_size,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                layer_types=self.layer_types,
                head_dim=head_dim,
                sliding_window=sliding_window,
                layer_idx=layer_idx,
            )
            for layer_idx in range(num_hidden_layers)
        ])

        self.patch_size = patch_size

        # Input projection: patch embedding using 1D convolution
        self.proj_in = nn.Sequential(
            Lambda(lambda x: x.transpose(1, 2)),
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=hidden_size,
                kernel_size=patch_size,
                stride=patch_size,
                padding=0,
            ),
            Lambda(lambda x: x.transpose(1, 2)),
        )

        # Timestep embeddings for diffusion conditioning
        self.time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_size)
        self.time_embed_r = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_size)

        # Project encoder hidden states to model dimension
        self.condition_embedder = nn.Linear(encoder_hidden_size, hidden_size, bias=True)

        # Output normalization and projection
        self.norm_out = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        self.proj_out = nn.Sequential(
            Lambda(lambda x: x.transpose(1, 2)),
            nn.ConvTranspose1d(
                in_channels=hidden_size,
                out_channels=audio_acoustic_hidden_dim,
                kernel_size=patch_size,
                stride=patch_size,
                padding=0,
            ),
            Lambda(lambda x: x.transpose(1, 2)),
        )
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, hidden_size) / hidden_size**0.5)

        self.gradient_checkpointing = False
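    # Illustrative sketch (not part of the diff): proj_in halves the sequence length
    # (patch_size=2) while proj_out's transposed convolution restores it. Sizes match the
    # defaults above but are otherwise made-up.
    #
    # >>> proj_in = nn.Conv1d(192, 2048, kernel_size=2, stride=2)
    # >>> proj_out = nn.ConvTranspose1d(2048, 64, kernel_size=2, stride=2)
    # >>> x = torch.randn(1, 192, 100)               # channel-first view of the latents
    # >>> patched = proj_in(x)
    # >>> patched.shape, proj_out(patched).shape
    # (torch.Size([1, 2048, 50]), torch.Size([1, 64, 100]))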
    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        timestep_r: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        context_latents: torch.Tensor,
        use_cache: Optional[bool] = False,
        past_key_values: Optional[EncoderDecoderCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        return_hidden_states: Optional[int] = None,
        custom_layers_config: Optional[dict] = None,
        enable_early_exit: bool = False,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ):

        use_cache = use_cache if use_cache is not None else self.use_cache

        # Disable cache during training or when gradient checkpointing is enabled
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False
        if self.training:
            use_cache = False

        # Initialize cache if needed (only during inference for auto-regressive generation)
        if not self.training and use_cache and past_key_values is None:
            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

        # Compute timestep embeddings for diffusion conditioning
        # Two embeddings: one for timestep t, one for the timestep difference (t - r)
        temb_t, timestep_proj_t = self.time_embed(timestep)
        temb_r, timestep_proj_r = self.time_embed_r(timestep - timestep_r)
        # Combine embeddings
        temb = temb_t + temb_r
        timestep_proj = timestep_proj_t + timestep_proj_r

        # Concatenate context latents (source latents + chunk masks) with hidden states
        hidden_states = torch.cat([context_latents, hidden_states], dim=-1)
        # Record original sequence length for later restoration after padding
        original_seq_len = hidden_states.shape[1]
        # Apply padding if sequence length is not divisible by patch_size
        # This ensures proper patch extraction
        pad_length = 0
        if hidden_states.shape[1] % self.patch_size != 0:
            pad_length = self.patch_size - (hidden_states.shape[1] % self.patch_size)
            hidden_states = F.pad(hidden_states, (0, 0, 0, pad_length), mode='constant', value=0)

        # Project input to patches and project encoder states
        hidden_states = self.proj_in(hidden_states)
        encoder_hidden_states = self.condition_embedder(encoder_hidden_states)

        # Cache positions
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )

        # Position IDs
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        seq_len = hidden_states.shape[1]
        encoder_seq_len = encoder_hidden_states.shape[1]
        dtype = hidden_states.dtype
        device = hidden_states.device

        # Initialize mask variables
        full_attn_mask = None
        sliding_attn_mask = None
        encoder_attn_mask = None
        decoder_attn_mask = None
        # The reference implementation discards the passed-in attention_mask for 4D mask
        # construction (line 1384: attention_mask = None)
        attention_mask = None

        # 1. Full Attention (Bidirectional, Global)
        full_attn_mask = create_4d_mask(
            seq_len=seq_len,
            dtype=dtype,
            device=device,
            attention_mask=attention_mask,
            sliding_window=None,
            is_sliding_window=False,
            is_causal=False
        )
        max_len = max(seq_len, encoder_seq_len)

        encoder_attn_mask = create_4d_mask(
            seq_len=max_len,
            dtype=dtype,
            device=device,
            attention_mask=attention_mask,
            sliding_window=None,
            is_sliding_window=False,
            is_causal=False
        )
        encoder_attn_mask = encoder_attn_mask[:, :, :seq_len, :encoder_seq_len]

        # 2. Sliding Attention (Bidirectional, Local)
        if self.use_sliding_window:
            sliding_attn_mask = create_4d_mask(
                seq_len=seq_len,
                dtype=dtype,
                device=device,
                attention_mask=attention_mask,
                sliding_window=self.sliding_window,
                is_sliding_window=True,
                is_causal=False
            )

        # Build mask mapping
|
||||||
|
self_attn_mask_mapping = {
|
||||||
|
"full_attention": full_attn_mask,
|
||||||
|
"sliding_attention": sliding_attn_mask,
|
||||||
|
"encoder_attention_mask": encoder_attn_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create position embeddings to be shared across all decoder layers
|
||||||
|
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||||
|
all_cross_attentions = () if output_attentions else None
|
||||||
|
|
||||||
|
# Handle early exit for custom layer configurations
|
||||||
|
max_needed_layer = float('inf')
|
||||||
|
if custom_layers_config is not None and enable_early_exit:
|
||||||
|
max_needed_layer = max(custom_layers_config.keys())
|
||||||
|
output_attentions = True
|
||||||
|
if all_cross_attentions is None:
|
||||||
|
all_cross_attentions = ()
|
||||||
|
|
||||||
|
# Process through transformer layers
|
||||||
|
for index_block, layer_module in enumerate(self.layers):
|
||||||
|
# Early exit optimization
|
||||||
|
if index_block > max_needed_layer:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Prepare layer arguments
|
||||||
|
layer_args = (
|
||||||
|
hidden_states,
|
||||||
|
position_embeddings,
|
||||||
|
timestep_proj,
|
||||||
|
self_attn_mask_mapping[layer_module.attention_type],
|
||||||
|
position_ids,
|
||||||
|
past_key_values,
|
||||||
|
output_attentions,
|
||||||
|
use_cache,
|
||||||
|
cache_position,
|
||||||
|
encoder_hidden_states,
|
||||||
|
self_attn_mask_mapping["encoder_attention_mask"],
|
||||||
|
)
|
||||||
|
layer_kwargs = flash_attn_kwargs
|
||||||
|
|
||||||
|
# Use gradient checkpointing if enabled
|
||||||
|
layer_outputs = gradient_checkpoint_forward(
|
||||||
|
layer_module,
|
||||||
|
use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload,
|
||||||
|
*layer_args,
|
||||||
|
**layer_kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if output_attentions and self.layers[index_block].use_cross_attention:
|
||||||
|
# layer_outputs structure: (hidden_states, self_attn_weights, cross_attn_weights)
|
||||||
|
if len(layer_outputs) >= 3:
|
||||||
|
all_cross_attentions += (layer_outputs[2],)
|
||||||
|
|
||||||
|
if return_hidden_states:
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
# Extract scale-shift parameters for adaptive output normalization
|
||||||
|
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||||
|
shift = shift.to(hidden_states.device)
|
||||||
|
scale = scale.to(hidden_states.device)
|
||||||
|
|
||||||
|
# Apply adaptive layer norm: norm(x) * (1 + scale) + shift
|
||||||
|
hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
|
||||||
|
# Project output: de-patchify back to original sequence format
|
||||||
|
hidden_states = self.proj_out(hidden_states)
|
||||||
|
|
||||||
|
# Crop back to original sequence length to ensure exact length match (remove padding)
|
||||||
|
hidden_states = hidden_states[:, :original_seq_len, :]
|
||||||
|
|
||||||
|
outputs = (hidden_states, past_key_values)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (all_cross_attentions,)
|
||||||
|
return outputs
|
||||||
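A minimal standalone sketch of the scale-shift output normalization applied at the end of the forward pass above, with toy tensors; torch.nn.RMSNorm stands in for Qwen3RMSNorm and all shapes here are illustrative assumptions:

import torch

hidden_size = 8
temb = torch.randn(2, hidden_size)                        # [B, C] combined timestep embedding
scale_shift_table = torch.randn(1, 2, hidden_size) / hidden_size**0.5
x = torch.randn(2, 16, hidden_size)                       # [B, L, C] hidden states

# The learned table plus the timestep embedding yields per-channel shift and scale.
shift, scale = (scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)  # each [B, 1, C]
norm = torch.nn.RMSNorm(hidden_size)                      # stand-in for Qwen3RMSNorm
out = norm(x) * (1 + scale) + shift                       # broadcasts over the sequence dim
print(out.shape)                                          # torch.Size([2, 16, 8])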
53
diffsynth/models/ace_step_text_encoder.py
Normal file
@@ -0,0 +1,53 @@
import torch


class AceStepTextEncoder(torch.nn.Module):
    def __init__(
        self,
    ):
        super().__init__()
        from transformers import Qwen3Config, Qwen3Model

        config = Qwen3Config(
            attention_bias=False,
            attention_dropout=0.0,
            bos_token_id=151643,
            dtype="bfloat16",
            eos_token_id=151643,
            head_dim=128,
            hidden_act="silu",
            hidden_size=1024,
            initializer_range=0.02,
            intermediate_size=3072,
            layer_types=["full_attention"] * 28,
            max_position_embeddings=32768,
            max_window_layers=28,
            model_type="qwen3",
            num_attention_heads=16,
            num_hidden_layers=28,
            num_key_value_heads=8,
            pad_token_id=151643,
            rms_norm_eps=1e-06,
            rope_scaling=None,
            rope_theta=1000000,
            sliding_window=None,
            tie_word_embeddings=True,
            use_cache=True,
            use_sliding_window=False,
            vocab_size=151669,
        )

        self.model = Qwen3Model(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor,
    ):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return outputs.last_hidden_state
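For orientation, a hedged usage sketch of the encoder above; the tokenizer checkpoint name is an assumption (not pinned down by this diff), and a freshly constructed model has random weights, so this only demonstrates shapes:

import torch
from transformers import AutoTokenizer
from diffsynth.models.ace_step_text_encoder import AceStepTextEncoder

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")  # assumed tokenizer source
encoder = AceStepTextEncoder().eval()

batch = tokenizer(["a dreamy synth-pop track"], return_tensors="pt")
embeddings = encoder(batch["input_ids"], batch["attention_mask"])
print(embeddings.shape)  # [1, seq_len, 1024], hidden_size from the config above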
732
diffsynth/models/ace_step_tokenizer.py
Normal file
@@ -0,0 +1,732 @@
# Copyright 2025 The ACE-Step Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ACE-Step Audio Tokenizer — VAE latent discretization pathway.

Contains:
- AceStepAudioTokenizer: continuous VAE latent → discrete FSQ tokens
- AudioTokenDetokenizer: discrete tokens → continuous VAE-latent-shaped features

Only used in cover song mode (is_covers=True). Bypassed in text-to-music.
"""
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange

from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward

from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutput
from transformers.processing_utils import Unpack
from transformers.utils import can_return_tuple, logging
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3MLP,
    Qwen3RMSNorm,
    Qwen3RotaryEmbedding,
    apply_rotary_pos_emb,
)
from vector_quantize_pytorch import ResidualFSQ

logger = logging.get_logger(__name__)


def create_4d_mask(
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
    is_sliding_window: bool = False,
    is_causal: bool = True,
) -> torch.Tensor:
    indices = torch.arange(seq_len, device=device)
    diff = indices.unsqueeze(1) - indices.unsqueeze(0)
    valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
    if is_causal:
        valid_mask = valid_mask & (diff >= 0)
    if is_sliding_window and sliding_window is not None:
        if is_causal:
            valid_mask = valid_mask & (diff <= sliding_window)
        else:
            valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)
    valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)
    if attention_mask is not None:
        padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
        valid_mask = valid_mask & padding_mask_4d
    min_dtype = torch.finfo(dtype).min
    mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
    mask_tensor.masked_fill_(valid_mask, 0.0)
    return mask_tensor
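# Illustrative check (a sketch, not part of this file): with seq_len=4,
# is_causal=False, is_sliding_window=True and sliding_window=1, each position
# may attend only where |i - j| <= 1, so the additive mask is 0 on the
# tridiagonal band and torch.finfo(dtype).min elsewhere:
#
#   create_4d_mask(4, torch.float32, "cpu",
#                  sliding_window=1, is_sliding_window=True, is_causal=False)[0, 0]
#   -> [[   0,    0, -inf, -inf],
#       [   0,    0,    0, -inf],
#       [-inf,    0,    0,    0],
#       [-inf, -inf,    0,    0]]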


class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)


class AceStepAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        attention_bias: bool,
        attention_dropout: float,
        layer_types: list,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
        is_cross_attention: bool = False,
        is_causal: bool = False,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.num_key_value_groups = num_attention_heads // num_key_value_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = attention_dropout
        if is_cross_attention:
            is_causal = False
        self.is_causal = is_causal
        self.is_cross_attention = is_cross_attention

        self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.attention_type = layer_types[layer_idx]
        self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)

        is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None

        if is_cross_attention:
            encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
            if past_key_value is not None:
                is_updated = past_key_value.is_updated.get(self.layer_idx)
                curr_past_key_value = past_key_value.cross_attention_cache
                if not is_updated:
                    key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
                    value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
                    key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
                    past_key_value.is_updated[self.layer_idx] = True
                else:
                    key_states = curr_past_key_value.layers[self.layer_idx].keys
                    value_states = curr_past_key_value.layers[self.layer_idx].values
            else:
                key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
                value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)

        else:
            key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            if position_embeddings is not None:
                cos, sin = position_embeddings
                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            if past_key_value is not None:
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        if self.num_key_value_groups > 1:
            key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
            value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)

        attn_output = attention_forward(
            query_states, key_states, value_states,
            q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
            attn_mask=attention_mask,
        )
        attn_weights = None

        attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
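# Note on the grouped-query expansion above (illustrative numbers): with the
# defaults used elsewhere in this file (num_attention_heads=16,
# num_key_value_heads=8), num_key_value_groups is 2, so K/V of shape
# [B, 8, S, D] go through unsqueeze(2).expand(-1, -1, 2, -1, -1).flatten(1, 2)
# to [B, 16, S, D]: each key/value head is duplicated for the two query heads
# that share it.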


class AceStepEncoderLayer(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        attention_bias: bool,
        attention_dropout: float,
        layer_types: list,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.layer_idx = layer_idx

        self.self_attn = AceStepAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            layer_idx=layer_idx,
            is_cross_attention=False,
            is_causal=False,
        )
        self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

        mlp_config = type('Config', (), {
            'hidden_size': hidden_size,
            'intermediate_size': intermediate_size,
            'hidden_act': 'silu',
        })()
        self.mlp = Qwen3MLP(mlp_config)
        self.attention_type = layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        return outputs


class AttentionPooler(nn.Module):
    """Pools every pool_window_size frames into 1 representation via transformer + CLS token."""

    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        num_attention_pooler_hidden_layers: int = 2,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Default matches target library config (24 alternating entries).
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
        self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")

        self.embed_tokens = nn.Linear(hidden_size, hidden_size)
        self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        # Slice layer_types to our own layer count
        pooler_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
        rope_config = type('RopeConfig', (), {
            'hidden_size': hidden_size,
            'num_attention_heads': num_attention_heads,
            'num_key_value_heads': num_key_value_heads,
            'head_dim': head_dim,
            'max_position_embeddings': max_position_embeddings,
            'rope_theta': rope_theta,
            'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
            'rms_norm_eps': rms_norm_eps,
            'attention_bias': attention_bias,
            'attention_dropout': attention_dropout,
            'hidden_act': 'silu',
            'intermediate_size': intermediate_size,
            'layer_types': pooler_layer_types,
            'sliding_window': sliding_window,
            '_attn_implementation': self._attn_implementation,
        })()
        self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
        self.gradient_checkpointing = False
        self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                layer_types=pooler_layer_types,
                head_dim=head_dim,
                sliding_window=sliding_window,
                layer_idx=layer_idx,
            )
            for layer_idx in range(num_attention_pooler_hidden_layers)
        ])

    @can_return_tuple
    def forward(
        self,
        x,
        attention_mask: Optional[torch.Tensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        B, T, P, D = x.shape
        x = self.embed_tokens(x)
        special_tokens = self.special_token.expand(B, T, 1, -1).to(x.device)
        x = torch.cat([special_tokens, x], dim=2)
        x = rearrange(x, "b t p c -> (b t) p c")

        cache_position = torch.arange(0, x.shape[1], device=x.device)
        position_ids = cache_position.unsqueeze(0)
        hidden_states = x
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        seq_len = x.shape[1]
        dtype = x.dtype
        device = x.device

        full_attn_mask = create_4d_mask(
            seq_len=seq_len, dtype=dtype, device=device,
            attention_mask=attention_mask, sliding_window=None,
            is_sliding_window=False, is_causal=False
        )
        sliding_attn_mask = None
        if self.use_sliding_window:
            sliding_attn_mask = create_4d_mask(
                seq_len=seq_len, dtype=dtype, device=device,
                attention_mask=attention_mask, sliding_window=self.sliding_window,
                is_sliding_window=True, is_causal=False
            )

        self_attn_mask_mapping = {
            "full_attention": full_attn_mask,
            "sliding_attention": sliding_attn_mask,
        }

        for layer_module in self.layers:
            layer_outputs = layer_module(
                hidden_states, position_embeddings,
                attention_mask=self_attn_mask_mapping[layer_module.attention_type],
                **flash_attn_kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
        cls_output = hidden_states[:, 0, :]
        return rearrange(cls_output, "(b t) c -> b t c", b=B)
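# Shape walkthrough for AttentionPooler.forward (illustrative, assuming the
# caller's pool window P=5): x [B, T, 5, 2048] -> embed_tokens -> prepend one
# CLS token per window -> [(B*T), 6, 2048] -> transformer layers -> take
# position 0 -> [B, T, 2048]. Each window of frames is summarized by its
# CLS token.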


class AceStepAudioTokenizer(nn.Module):
    """Converts continuous acoustic features (VAE latents) into discrete quantized tokens.

    Input: [B, T, 64] (VAE latent dim)
    Output: quantized [B, T/5, 2048], indices [B, T/5, 1]
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        audio_acoustic_hidden_dim: int = 64,
        pool_window_size: int = 5,
        fsq_dim: int = 2048,
        fsq_input_levels: list = None,
        fsq_input_num_quantizers: int = 1,
        num_attention_pooler_hidden_layers: int = 2,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Default matches target library config (24 alternating entries).
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
        self.pool_window_size = pool_window_size
        self.fsq_dim = fsq_dim
        self.fsq_input_levels = fsq_input_levels or [8, 8, 8, 5, 5, 5]
        self.fsq_input_num_quantizers = fsq_input_num_quantizers
        self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
        self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")

        self.audio_acoustic_proj = nn.Linear(audio_acoustic_hidden_dim, hidden_size)
        # Slice layer_types for the attention pooler
        pooler_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
        self.attention_pooler = AttentionPooler(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=pooler_layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            use_sliding_window=use_sliding_window,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
        )
        self.quantizer = ResidualFSQ(
            dim=self.fsq_dim,
            levels=self.fsq_input_levels,
            num_quantizers=self.fsq_input_num_quantizers,
            force_quantization_f32=False,  # avoid autocast bug in vector_quantize_pytorch
        )

    @can_return_tuple
    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        hidden_states = self.audio_acoustic_proj(hidden_states)
        hidden_states = self.attention_pooler(hidden_states)
        quantized, indices = self.quantizer(hidden_states)
        return quantized, indices

    def tokenize(self, x):
        """Convenience: takes [B, T, 64], rearranges to patches, runs forward."""
        x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=self.pool_window_size)
        return self.forward(x)
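# End-to-end shapes for tokenize (illustrative): x [B, T, 64] with T divisible
# by pool_window_size=5 -> rearrange -> [B, T/5, 5, 64] -> audio_acoustic_proj
# -> [B, T/5, 5, 2048] -> AttentionPooler -> [B, T/5, 2048] -> ResidualFSQ with
# levels [8, 8, 8, 5, 5, 5] (8*8*8*5*5*5 = 64000 codes per quantizer) ->
# (quantized [B, T/5, 2048], indices [B, T/5, 1]).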


class AudioTokenDetokenizer(nn.Module):
    """Converts quantized audio tokens back to continuous acoustic representations.

    Input: [B, T/5, hidden_size] (quantized vectors)
    Output: [B, T, 64] (VAE-latent-shaped continuous features)
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        pool_window_size: int = 5,
        audio_acoustic_hidden_dim: int = 64,
        num_attention_pooler_hidden_layers: int = 2,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Default matches target library config (24 alternating entries).
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.pool_window_size = pool_window_size
        self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
        self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
        self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")

        self.embed_tokens = nn.Linear(hidden_size, hidden_size)
        self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        # Slice layer_types to our own layer count (use num_audio_decoder_hidden_layers)
        detok_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
        rope_config = type('RopeConfig', (), {
            'hidden_size': hidden_size,
            'num_attention_heads': num_attention_heads,
            'num_key_value_heads': num_key_value_heads,
            'head_dim': head_dim,
            'max_position_embeddings': max_position_embeddings,
            'rope_theta': rope_theta,
            'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
            'rms_norm_eps': rms_norm_eps,
            'attention_bias': attention_bias,
            'attention_dropout': attention_dropout,
            'hidden_act': 'silu',
            'intermediate_size': intermediate_size,
            'layer_types': detok_layer_types,
            'sliding_window': sliding_window,
            '_attn_implementation': self._attn_implementation,
        })()
        self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
        self.gradient_checkpointing = False
        self.special_tokens = nn.Parameter(torch.randn(1, pool_window_size, hidden_size) * 0.02)
        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                layer_types=detok_layer_types,
                head_dim=head_dim,
                sliding_window=sliding_window,
                layer_idx=layer_idx,
            )
            for layer_idx in range(num_attention_pooler_hidden_layers)
        ])
        self.proj_out = nn.Linear(hidden_size, audio_acoustic_hidden_dim)

    @can_return_tuple
    def forward(
        self,
        x,
        attention_mask: Optional[torch.Tensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        B, T, D = x.shape
        x = self.embed_tokens(x)
        x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1)
        special_tokens = self.special_tokens.expand(B, T, -1, -1)
        x = x + special_tokens.to(x.device)
        x = rearrange(x, "b t p c -> (b t) p c")

        cache_position = torch.arange(0, x.shape[1], device=x.device)
        position_ids = cache_position.unsqueeze(0)
        hidden_states = x
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        seq_len = x.shape[1]
        dtype = x.dtype
        device = x.device

        full_attn_mask = create_4d_mask(
            seq_len=seq_len, dtype=dtype, device=device,
            attention_mask=attention_mask, sliding_window=None,
            is_sliding_window=False, is_causal=False
        )
        sliding_attn_mask = None
        if self.use_sliding_window:
            sliding_attn_mask = create_4d_mask(
                seq_len=seq_len, dtype=dtype, device=device,
                attention_mask=attention_mask, sliding_window=self.sliding_window,
                is_sliding_window=True, is_causal=False
            )

        self_attn_mask_mapping = {
            "full_attention": full_attn_mask,
            "sliding_attention": sliding_attn_mask,
        }

        for layer_module in self.layers:
            layer_outputs = layer_module(
                hidden_states, position_embeddings,
                attention_mask=self_attn_mask_mapping[layer_module.attention_type],
                **flash_attn_kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
        hidden_states = self.proj_out(hidden_states)
        return rearrange(hidden_states, "(b t) p c -> b (t p) c", b=B, p=self.pool_window_size)


class AceStepTokenizer(nn.Module):
    """Container for AceStepAudioTokenizer + AudioTokenDetokenizer.

    Provides encode/decode convenience methods for VAE latent discretization.
    Used in cover song mode to convert source audio latents to discrete tokens
    and back to continuous conditioning hints.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        audio_acoustic_hidden_dim: int = 64,
        pool_window_size: int = 5,
        fsq_dim: int = 2048,
        fsq_input_levels: list = None,
        fsq_input_num_quantizers: int = 1,
        num_attention_pooler_hidden_layers: int = 2,
        num_audio_decoder_hidden_layers: int = 24,
        **kwargs,
    ):
        super().__init__()
        # Default layer_types matches target library config (24 alternating entries).
        # Sub-modules (pooler/detokenizer) slice first N entries for their own layer count.
        if layer_types is None:
            layer_types = ["sliding_attention", "full_attention"] * 12
        self.tokenizer = AceStepAudioTokenizer(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            use_sliding_window=use_sliding_window,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            audio_acoustic_hidden_dim=audio_acoustic_hidden_dim,
            pool_window_size=pool_window_size,
            fsq_dim=fsq_dim,
            fsq_input_levels=fsq_input_levels,
            fsq_input_num_quantizers=fsq_input_num_quantizers,
            num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
            **kwargs,
        )
        self.detokenizer = AudioTokenDetokenizer(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            use_sliding_window=use_sliding_window,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            pool_window_size=pool_window_size,
            audio_acoustic_hidden_dim=audio_acoustic_hidden_dim,
            num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
            **kwargs,
        )

    def encode(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """VAE latent [B, T, 64] → discrete tokens."""
        return self.tokenizer(hidden_states)

    def decode(self, quantized: torch.Tensor) -> torch.Tensor:
        """Discrete tokens [B, T/5, hidden_size] → continuous [B, T, 64]."""
        return self.detokenizer(quantized)

    def tokenize(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Convenience: [B, T, 64] → quantized + indices via patch rearrangement."""
        return self.tokenizer.tokenize(x)
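A hedged round-trip sketch for the container above; the modules are randomly initialized here, so this demonstrates shapes only (they follow the docstrings):

import torch
from diffsynth.models.ace_step_tokenizer import AceStepTokenizer

tok = AceStepTokenizer().eval()
latents = torch.randn(1, 100, 64)                # [B, T, 64], T divisible by pool_window_size=5
with torch.no_grad():
    quantized, indices = tok.tokenize(latents)   # [1, 20, 2048], [1, 20, 1]
    hints = tok.decode(quantized)                # [1, 100, 64] continuous conditioning hint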
287
diffsynth/models/ace_step_vae.py
Normal file
@@ -0,0 +1,287 @@
# Copyright 2025 The ACE-Step Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ACE-Step Audio VAE (AutoencoderOobleck CNN architecture).

This is a CNN-based VAE for audio waveform encoding/decoding.
It uses weight-normalized convolutions and Snake1d activations.
Does NOT depend on diffusers — pure nn.Module implementation.
"""
import math
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm, remove_weight_norm


class Snake1d(nn.Module):
    """Snake activation: x + 1/(beta+eps) * sin(alpha*x)^2."""

    def __init__(self, hidden_dim: int, logscale: bool = True):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.logscale = logscale

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        shape = hidden_states.shape
        alpha = torch.exp(self.alpha) if self.logscale else self.alpha
        beta = torch.exp(self.beta) if self.logscale else self.beta
        hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
        hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2)
        return hidden_states.reshape(shape)
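# Behavior at initialization (illustrative): alpha and beta start at zero, so
# with logscale=True both become exp(0) = 1 and the layer computes
# x + sin(x)^2 / (1 + 1e-9), i.e. approximately x + sin(x)^2; training then
# adjusts the per-channel frequency (alpha) and amplitude (1/beta) of the
# periodic term.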


class OobleckResidualUnit(nn.Module):
    """Residual unit: Snake1d → Conv1d(dilated) → Snake1d → Conv1d(1×1) + skip."""

    def __init__(self, dimension: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.snake1 = Snake1d(dimension)
        self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad))
        self.snake2 = Snake1d(dimension)
        self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1))

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        output = self.conv1(self.snake1(hidden_state))
        output = self.conv2(self.snake2(output))
        padding = (hidden_state.shape[-1] - output.shape[-1]) // 2
        if padding > 0:
            hidden_state = hidden_state[..., padding:-padding]
        return hidden_state + output


class OobleckEncoderBlock(nn.Module):
    """Encoder block: 3 residual units + downsampling conv."""

    def __init__(self, input_dim: int, output_dim: int, stride: int = 1):
        super().__init__()
        self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1)
        self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3)
        self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9)
        self.snake1 = Snake1d(input_dim)
        self.conv1 = weight_norm(
            nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.res_unit1(hidden_state)
        hidden_state = self.res_unit2(hidden_state)
        hidden_state = self.snake1(self.res_unit3(hidden_state))
        return self.conv1(hidden_state)


class OobleckDecoderBlock(nn.Module):
    """Decoder block: upsampling conv + 3 residual units."""

    def __init__(self, input_dim: int, output_dim: int, stride: int = 1):
        super().__init__()
        self.snake1 = Snake1d(input_dim)
        self.conv_t1 = weight_norm(
            nn.ConvTranspose1d(
                input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2),
            )
        )
        self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1)
        self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3)
        self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.snake1(hidden_state)
        hidden_state = self.conv_t1(hidden_state)
        hidden_state = self.res_unit1(hidden_state)
        hidden_state = self.res_unit2(hidden_state)
        return self.res_unit3(hidden_state)


class OobleckEncoder(nn.Module):
    """Full encoder: audio → latent representation [B, encoder_hidden_size, T'].

    conv1 → [blocks] → snake1 → conv2
    """

    def __init__(
        self,
        encoder_hidden_size: int = 128,
        audio_channels: int = 2,
        downsampling_ratios: list = None,
        channel_multiples: list = None,
    ):
        super().__init__()
        downsampling_ratios = downsampling_ratios or [2, 4, 4, 6, 10]
        channel_multiples = channel_multiples or [1, 2, 4, 8, 16]
        channel_multiples = [1] + channel_multiples

        self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3))

        self.block = nn.ModuleList()
        for stride_index, stride in enumerate(downsampling_ratios):
            self.block.append(
                OobleckEncoderBlock(
                    input_dim=encoder_hidden_size * channel_multiples[stride_index],
                    output_dim=encoder_hidden_size * channel_multiples[stride_index + 1],
                    stride=stride,
                )
            )

        d_model = encoder_hidden_size * channel_multiples[-1]
        self.snake1 = Snake1d(d_model)
        self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1))

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.conv1(hidden_state)
        for block in self.block:
            hidden_state = block(hidden_state)
        hidden_state = self.snake1(hidden_state)
        return self.conv2(hidden_state)
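# Stride arithmetic (illustrative): with downsampling_ratios [2, 4, 4, 6, 10]
# the encoder reduces time by 2*4*4*6*10 = 1920x, so 48 kHz audio maps to
# roughly 48000/1920 = 25 latent frames per second. The final conv2 emits
# encoder_hidden_size=128 channels, which the diagonal Gaussian below splits
# into 64 mean + 64 scale parameters for the 64-dim latent.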


class OobleckDecoder(nn.Module):
    """Full decoder: latent → audio waveform [B, audio_channels, T].

    conv1 → [blocks] → snake1 → conv2(no bias)
    """

    def __init__(
        self,
        channels: int = 128,
        input_channels: int = 64,
        audio_channels: int = 2,
        upsampling_ratios: list = None,
        channel_multiples: list = None,
    ):
        super().__init__()
        upsampling_ratios = upsampling_ratios or [10, 6, 4, 4, 2]
        channel_multiples = channel_multiples or [1, 2, 4, 8, 16]
        channel_multiples = [1] + channel_multiples

        self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3))

        self.block = nn.ModuleList()
        for stride_index, stride in enumerate(upsampling_ratios):
            self.block.append(
                OobleckDecoderBlock(
                    input_dim=channels * channel_multiples[len(upsampling_ratios) - stride_index],
                    output_dim=channels * channel_multiples[len(upsampling_ratios) - stride_index - 1],
                    stride=stride,
                )
            )

        self.snake1 = Snake1d(channels)
        # conv2 has no bias (matches checkpoint: only weight_g/weight_v, no bias key)
        self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False))

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.conv1(hidden_state)
        for block in self.block:
            hidden_state = block(hidden_state)
        hidden_state = self.snake1(hidden_state)
        return self.conv2(hidden_state)


class OobleckDiagonalGaussianDistribution(object):
    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        self.mean, self.scale = parameters.chunk(2, dim=1)
        self.std = nn.functional.softplus(self.scale) + 1e-4
        self.var = self.std * self.std
        self.logvar = torch.log(self.var)
        self.deterministic = deterministic

    def sample(self, generator: torch.Generator | None = None) -> torch.Tensor:
        # make sure the sample is on the same device as the parameters and has the same dtype
        sample = torch.randn(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
            else:
                normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
                var_ratio = self.var / other.var
                logvar_diff = self.logvar - other.logvar

                kl = normalized_diff + var_ratio + logvar_diff - 1

                kl = kl.sum(1).mean()
                return kl


class AceStepVAE(nn.Module):
    """Audio VAE for ACE-Step (AutoencoderOobleck architecture).

    Encodes audio waveform → latent, decodes latent → audio waveform.
    Uses Snake1d activations and weight-normalized convolutions.
    """

    def __init__(
        self,
        encoder_hidden_size: int = 128,
        downsampling_ratios: list = None,
        channel_multiples: list = None,
        decoder_channels: int = 128,
        decoder_input_channels: int = 64,
        audio_channels: int = 2,
        sampling_rate: int = 48000,
    ):
        super().__init__()
        downsampling_ratios = downsampling_ratios or [2, 4, 4, 6, 10]
        channel_multiples = channel_multiples or [1, 2, 4, 8, 16]
        upsampling_ratios = downsampling_ratios[::-1]

        self.encoder = OobleckEncoder(
            encoder_hidden_size=encoder_hidden_size,
            audio_channels=audio_channels,
            downsampling_ratios=downsampling_ratios,
            channel_multiples=channel_multiples,
        )
        self.decoder = OobleckDecoder(
            channels=decoder_channels,
            input_channels=decoder_input_channels,
            audio_channels=audio_channels,
            upsampling_ratios=upsampling_ratios,
            channel_multiples=channel_multiples,
        )
        self.sampling_rate = sampling_rate

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Audio waveform [B, audio_channels, T] → latent [B, decoder_input_channels, T']."""
        h = self.encoder(x)
        output = OobleckDiagonalGaussianDistribution(h).sample()
        return output

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Latent [B, decoder_input_channels, T] → audio waveform [B, audio_channels, T']."""
        return self.decoder(z)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """Full round-trip: encode → decode."""
        z = self.encode(sample)
        return self.decode(z)

    def remove_weight_norm(self):
        """Remove weight normalization from all conv layers (for export/inference)."""
        for module in self.modules():
            if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
                remove_weight_norm(module)
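A minimal usage sketch (shapes are approximate assumptions; one second of 48 kHz stereo, randomly initialized weights):

import torch
from diffsynth.models.ace_step_vae import AceStepVAE

vae = AceStepVAE().eval()
waveform = torch.randn(1, 2, 48000)      # [B, audio_channels, T]
with torch.no_grad():
    latent = vae.encode(waveform)        # roughly [1, 64, 25] after 1920x downsampling
    audio = vae.decode(latent)           # roughly [1, 2, 48000], up to conv edge effects
print(latent.shape, audio.shape)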
1307
diffsynth/models/anima_dit.py
Normal file
File diff suppressed because it is too large
362
diffsynth/models/ernie_image_dit.py
Normal file
@@ -0,0 +1,362 @@
"""
Ernie-Image DiT for DiffSynth-Studio.

Refactored from diffusers ErnieImageTransformer2DModel to use DiffSynth core modules.
Default parameters from actual checkpoint config.json (PaddlePaddle/ERNIE-Image transformer).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward
from .flux2_dit import Timesteps, TimestepEmbedding


def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta ** scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    return out.float()
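# Frequency layout (illustrative): rope(pos, dim, theta) returns the angles
# pos * theta**(-2k/dim) for k = 0 .. dim/2 - 1, with shape [..., n, dim/2].
# ErnieImageEmbedND3 below concatenates these angles across three position
# axes and duplicates each entry (stack + reshape) so the resulting table
# covers the full rotary dimension consumed by apply_rotary_emb.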


class ErnieImageEmbedND3(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = list(axes_dim)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
        emb = emb.unsqueeze(2)
        return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1)


class ErnieImagePatchEmbedDynamic(nn.Module):
    def __init__(self, in_channels: int, embed_dim: int, patch_size: int):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        batch_size, dim, height, width = x.shape
        return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous()


class ErnieImageSingleStreamAttnProcessor:
    def __call__(
        self,
        attn: "ErnieImageAttention",
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
            rot_dim = freqs_cis.shape[-1]
            x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
            cos_ = torch.cos(freqs_cis).to(x.dtype)
            sin_ = torch.sin(freqs_cis).to(x.dtype)
            x1, x2 = x.chunk(2, dim=-1)
            x_rotated = torch.cat((-x2, x1), dim=-1)
            return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)

        if freqs_cis is not None:
            query = apply_rotary_emb(query, freqs_cis)
            key = apply_rotary_emb(key, freqs_cis)

        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = attention_mask[:, None, None, :]

        hidden_states = attention_forward(
            query, key, value,
            q_pattern="b s n d",
            k_pattern="b s n d",
            v_pattern="b s n d",
            out_pattern="b s n d",
            attn_mask=attention_mask,
        )

        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)
        output = attn.to_out[0](hidden_states)

        return output


class ErnieImageAttention(nn.Module):
    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        qk_norm: str = "rms_norm",
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        elementwise_affine: bool = True,
    ):
        super().__init__()

        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.heads = out_dim // dim_head if out_dim is not None else heads

        self.use_bias = bias
        self.dropout = dropout

        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)

        if qk_norm == "layer_norm":
|
||||||
|
self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||||
|
self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||||
|
elif qk_norm == "rms_norm":
|
||||||
|
self.norm_q = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||||
|
self.norm_k = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'rms_norm'."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.to_out = nn.ModuleList([])
|
||||||
|
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||||
|
|
||||||
|
self.processor = ErnieImageSingleStreamAttnProcessor()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.processor(self, hidden_states, attention_mask, image_rotary_emb)
|
||||||
|
|
||||||
|
|
||||||
|
class ErnieImageFeedForward(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, ffn_hidden_size: int):
|
||||||
|
super().__init__()
|
||||||
|
self.gate_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
|
||||||
|
self.up_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
|
||||||
|
self.linear_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class ErnieImageRMSNorm(nn.Module):
|
||||||
|
def __init__(self, dim: int, eps: float = 1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||||
|
hidden_states = hidden_states * self.weight
|
||||||
|
return hidden_states.to(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class ErnieImageSharedAdaLNBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
ffn_hidden_size: int,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
qk_layernorm: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.adaLN_sa_ln = ErnieImageRMSNorm(hidden_size, eps=eps)
|
||||||
|
self.self_attention = ErnieImageAttention(
|
||||||
|
query_dim=hidden_size,
|
||||||
|
dim_head=hidden_size // num_heads,
|
||||||
|
heads=num_heads,
|
||||||
|
qk_norm="rms_norm" if qk_layernorm else None,
|
||||||
|
eps=eps,
|
||||||
|
bias=False,
|
||||||
|
out_bias=False,
|
||||||
|
)
|
||||||
|
self.adaLN_mlp_ln = ErnieImageRMSNorm(hidden_size, eps=eps)
|
||||||
|
self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
rotary_pos_emb: torch.Tensor,
|
||||||
|
temb: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb
|
||||||
|
residual = x
|
||||||
|
x = self.adaLN_sa_ln(x)
|
||||||
|
x = (x.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
|
||||||
|
x_bsh = x.permute(1, 0, 2)
|
||||||
|
attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
|
||||||
|
attn_out = attn_out.permute(1, 0, 2)
|
||||||
|
x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
|
||||||
|
residual = x
|
||||||
|
x = self.adaLN_mlp_ln(x)
|
||||||
|
x = (x.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
|
||||||
|
return residual + (gate_mlp.float() * self.mlp(x).float()).to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class ErnieImageAdaLNContinuous(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=eps)
|
||||||
|
self.linear = nn.Linear(hidden_size, hidden_size * 2)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
|
||||||
|
scale, shift = self.linear(conditioning).chunk(2, dim=-1)
|
||||||
|
x = self.norm(x)
|
||||||
|
x = x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ErnieImageDiT(nn.Module):
|
||||||
|
"""
|
||||||
|
Ernie-Image DiT model for DiffSynth-Studio.
|
||||||
|
|
||||||
|
Architecture: SharedAdaLN + RoPE 3D + Joint Image-Text Attention.
|
||||||
|
Internal format: [S, B, H] for transformer blocks, [B, S, H] for attention.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int = 4096,
|
||||||
|
num_attention_heads: int = 32,
|
||||||
|
num_layers: int = 36,
|
||||||
|
ffn_hidden_size: int = 12288,
|
||||||
|
in_channels: int = 128,
|
||||||
|
out_channels: int = 128,
|
||||||
|
patch_size: int = 1,
|
||||||
|
text_in_dim: int = 3072,
|
||||||
|
rope_theta: int = 256,
|
||||||
|
rope_axes_dim: Tuple[int, int, int] = (32, 48, 48),
|
||||||
|
eps: float = 1e-6,
|
||||||
|
qk_layernorm: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_heads = num_attention_heads
|
||||||
|
self.head_dim = hidden_size // num_attention_heads
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.text_in_dim = text_in_dim
|
||||||
|
|
||||||
|
self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size)
|
||||||
|
self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None
|
||||||
|
self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0)
|
||||||
|
self.time_embedding = TimestepEmbedding(hidden_size, hidden_size)
|
||||||
|
self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)
|
||||||
|
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size))
|
||||||
|
nn.init.zeros_(self.adaLN_modulation[-1].weight)
|
||||||
|
nn.init.zeros_(self.adaLN_modulation[-1].bias)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
])
|
||||||
|
self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps)
|
||||||
|
self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels)
|
||||||
|
nn.init.zeros_(self.final_linear.weight)
|
||||||
|
nn.init.zeros_(self.final_linear.bias)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
timestep: torch.Tensor,
|
||||||
|
text_bth: torch.Tensor,
|
||||||
|
text_lens: torch.Tensor,
|
||||||
|
use_gradient_checkpointing: bool = False,
|
||||||
|
use_gradient_checkpointing_offload: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
device, dtype = hidden_states.device, hidden_states.dtype
|
||||||
|
B, C, H, W = hidden_states.shape
|
||||||
|
p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size
|
||||||
|
N_img = Hp * Wp
|
||||||
|
|
||||||
|
img_sbh = self.x_embedder(hidden_states).transpose(0, 1).contiguous()
|
||||||
|
|
||||||
|
if self.text_proj is not None and text_bth.numel() > 0:
|
||||||
|
text_bth = self.text_proj(text_bth)
|
||||||
|
Tmax = text_bth.shape[1]
|
||||||
|
text_sbh = text_bth.transpose(0, 1).contiguous()
|
||||||
|
|
||||||
|
x = torch.cat([img_sbh, text_sbh], dim=0)
|
||||||
|
S = x.shape[0]
|
||||||
|
|
||||||
|
text_ids = torch.cat([
|
||||||
|
torch.arange(Tmax, device=device, dtype=torch.float32).view(1, Tmax, 1).expand(B, -1, -1),
|
||||||
|
torch.zeros((B, Tmax, 2), device=device)
|
||||||
|
], dim=-1) if Tmax > 0 else torch.zeros((B, 0, 3), device=device)
|
||||||
|
grid_yx = torch.stack(
|
||||||
|
torch.meshgrid(torch.arange(Hp, device=device, dtype=torch.float32),
|
||||||
|
torch.arange(Wp, device=device, dtype=torch.float32), indexing="ij"),
|
||||||
|
dim=-1
|
||||||
|
).reshape(-1, 2)
|
||||||
|
image_ids = torch.cat([
|
||||||
|
text_lens.float().view(B, 1, 1).expand(-1, N_img, -1),
|
||||||
|
grid_yx.view(1, N_img, 2).expand(B, -1, -1)
|
||||||
|
], dim=-1)
|
||||||
|
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))
|
||||||
|
|
||||||
|
valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) if Tmax > 0 else torch.zeros((B, 0), device=device, dtype=torch.bool)
|
||||||
|
attention_mask = torch.cat([
|
||||||
|
torch.ones((B, N_img), device=device, dtype=torch.bool),
|
||||||
|
valid_text
|
||||||
|
], dim=1)[:, None, None, :]
|
||||||
|
|
||||||
|
sample = self.time_proj(timestep.to(dtype))
|
||||||
|
sample = sample.to(self.time_embedding.linear_1.weight.dtype)
|
||||||
|
c = self.time_embedding(sample)
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
|
||||||
|
t.unsqueeze(0).expand(S, -1, -1).contiguous()
|
||||||
|
for t in self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||||
|
]
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
|
||||||
|
if torch.is_grad_enabled() and use_gradient_checkpointing:
|
||||||
|
x = gradient_checkpoint_forward(
|
||||||
|
layer,
|
||||||
|
use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload,
|
||||||
|
x,
|
||||||
|
rotary_pos_emb,
|
||||||
|
temb,
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = layer(x, rotary_pos_emb, temb, attention_mask)
|
||||||
|
|
||||||
|
x = self.final_norm(x, c).type_as(x)
|
||||||
|
patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous()
|
||||||
|
output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W)
|
||||||
|
|
||||||
|
return output
|
||||||
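
For orientation, a minimal smoke-test sketch of the new DiT's calling convention. The tiny hyperparameters are illustrative, not the checkpoint defaults, and the sketch assumes `attention_forward` accepts the boolean mask built in `forward`:

import torch
from diffsynth.models.ernie_image_dit import ErnieImageDiT

# Toy config; real checkpoints use hidden_size=4096, num_layers=36, etc.
dit = ErnieImageDiT(hidden_size=64, num_attention_heads=4, num_layers=2,
                    ffn_hidden_size=128, in_channels=8, out_channels=8,
                    text_in_dim=32, rope_axes_dim=(8, 4, 4))
latents = torch.randn(1, 8, 16, 16)   # [B, C, H, W] latent image
text = torch.randn(1, 7, 32)          # [B, Tmax, text_in_dim]
text_lens = torch.tensor([7])         # valid tokens per sample
out = dit(latents, timestep=torch.tensor([500.0]), text_bth=text, text_lens=text_lens)
print(out.shape)  # torch.Size([1, 8, 16, 16]), same spatial shape as the input latents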

76  diffsynth/models/ernie_image_text_encoder.py  Normal file
@@ -0,0 +1,76 @@
"""
Ernie-Image TextEncoder for DiffSynth-Studio.

Wraps transformers Ministral3Model to output text embeddings.
Pattern: lazy import + manual config dict + torch.nn.Module wrapper.
Only loads the text (language) model, ignoring vision components.
"""

import torch


class ErnieImageTextEncoder(torch.nn.Module):
    """
    Text encoder using Ministral3Model (transformers).
    Only the text_config portion of the full Mistral3Model checkpoint.
    Uses the base model (no lm_head) since the checkpoint only has embeddings.
    """

    def __init__(self):
        super().__init__()
        from transformers import Ministral3Config, Ministral3Model

        text_config = {
            "attention_dropout": 0.0,
            "bos_token_id": 1,
            "dtype": "bfloat16",
            "eos_token_id": 2,
            "head_dim": 128,
            "hidden_act": "silu",
            "hidden_size": 3072,
            "initializer_range": 0.02,
            "intermediate_size": 9216,
            "max_position_embeddings": 262144,
            "model_type": "ministral3",
            "num_attention_heads": 32,
            "num_hidden_layers": 26,
            "num_key_value_heads": 8,
            "pad_token_id": 11,
            "rms_norm_eps": 1e-05,
            "rope_parameters": {
                "beta_fast": 32.0,
                "beta_slow": 1.0,
                "factor": 16.0,
                "llama_4_scaling_beta": 0.1,
                "mscale": 1.0,
                "mscale_all_dim": 1.0,
                "original_max_position_embeddings": 16384,
                "rope_theta": 1000000.0,
                "rope_type": "yarn",
                "type": "yarn",
            },
            "sliding_window": None,
            "tie_word_embeddings": True,
            "use_cache": True,
            "vocab_size": 131072,
        }
        config = Ministral3Config(**text_config)
        self.model = Ministral3Model(config)
        self.config = config

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        **kwargs,
    ):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=True,
            return_dict=True,
            **kwargs,
        )
        return (outputs.hidden_states,)
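
A note on the return value: `outputs.hidden_states` is a tuple with one entry per layer (the embeddings plus each of the 26 blocks), so the caller picks which layer to condition on. A hypothetical call (it assumes the Ministral3 classes exist in the installed transformers):

encoder = ErnieImageTextEncoder()
input_ids = torch.randint(0, 131072, (1, 16))        # dummy token ids
attention_mask = torch.ones(1, 16, dtype=torch.long)
(hidden_states,) = encoder(input_ids=input_ids, attention_mask=attention_mask)
# hidden_states is a tuple of [1, 16, 3072] tensors; e.g. take the last layer:
text_features = hidden_states[-1]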

diffsynth/models/flux2_dit.py
@@ -879,6 +879,9 @@ class Flux2Modulation(nn.Module):
 
 
 class Flux2DiT(torch.nn.Module):
+
+    _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
+
     def __init__(
         self,
         patch_size: int = 1,
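
The only change to `Flux2DiT` here (and to `FluxDiT` in the next hunk) is a `_repeated_blocks` class attribute listing the transformer-block class names. The diff does not show its consumer; as an illustration only, a hypothetical helper (not DiffSynth API) shows how such a registry can locate every repeated block by class name:

def iter_repeated_blocks(model: torch.nn.Module):
    # Hypothetical helper: yield (name, module) for every submodule whose
    # class name appears in the model's _repeated_blocks registry.
    names = set(getattr(model, "_repeated_blocks", []))
    for name, module in model.named_modules():
        if type(module).__name__ in names:
            yield name, module  # e.g. hand these to offloading or checkpointing logic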

diffsynth/models/flux_dit.py
@@ -275,6 +275,9 @@ class AdaLayerNormContinuous(torch.nn.Module):
 
 
 class FluxDiT(torch.nn.Module):
+
+    _repeated_blocks = ["FluxJointTransformerBlock", "FluxSingleTransformerBlock"]
+
     def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
         super().__init__()
         self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])

636  diffsynth/models/joyai_image_dit.py  Normal file
@@ -0,0 +1,636 @@
import math
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)
    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]
    emb = scale * emb
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )


class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None
        self.act = nn.SiLU()
        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
        self.post_act = nn.SiLU() if post_act_fn == "silu" else None

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)
        if self.act is not None:
            sample = self.act(sample)
        sample = self.linear_2(sample)
        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class PixArtAlphaTextProjection(nn.Module):
    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
        super().__init__()
        if out_features is None:
            out_features = hidden_size
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        if act_fn == "gelu_tanh":
            self.act_1 = nn.GELU(approximate="tanh")
        elif act_fn == "silu":
            self.act_1 = nn.SiLU()
        else:
            self.act_1 = nn.GELU(approximate="tanh")
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class GELU(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
        self.approximate = approximate

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.proj(hidden_states)
        hidden_states = F.gelu(hidden_states, approximate=self.approximate)
        return hidden_states


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # Build activation + projection matching diffusers pattern
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        else:
            act_fn = GELU(dim, inner_dim, bias=bias)

        self.net = nn.ModuleList([])
        self.net.append(act_fn)
        self.net.append(nn.Dropout(dropout))
        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


def _to_tuple(x, dim=2):
    if isinstance(x, int):
        return (x,) * dim
    elif len(x) == dim:
        return x
    else:
        raise ValueError(f"Expected length {dim} or int, but got {x}")


def get_meshgrid_nd(start, *args, dim=2):
    if len(args) == 0:
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [stop[i] - start[i] for i in range(dim)]
    elif len(args) == 2:
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = _to_tuple(args[1], dim=dim)
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
    axis_grid = []
    for i in range(dim):
        a, b, n = start[i], stop[i], num[i]
        g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
        axis_grid.append(g)
    grid = torch.meshgrid(*axis_grid, indexing="ij")
    grid = torch.stack(grid, dim=0)
    return grid


def reshape_for_broadcast(freqs_cis, x, head_first=False):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    if isinstance(freqs_cis, tuple):
        if head_first:
            assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1])
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis[0].shape == (x.shape[1], x.shape[-1])
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    else:
        if head_first:
            assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis.shape == (x.shape[1], x.shape[-1])
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)


def rotate_half(x):
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)


def apply_rotary_emb(xq, xk, freqs_cis, head_first=False):
    cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
    cos, sin = cos.to(xq.device), sin.to(xq.device)
    xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
    xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
    return xq_out, xk_out


def get_1d_rotary_pos_embed(dim, pos, theta=10000.0, use_real=False, theta_rescale_factor=1.0, interpolation_factor=1.0):
    if isinstance(pos, int):
        pos = torch.arange(pos).float()
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 2))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    freqs = torch.outer(pos * interpolation_factor, freqs)
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)
        return freqs_cos, freqs_sin
    else:
        return torch.polar(torch.ones_like(freqs), freqs)


def get_nd_rotary_pos_embed(rope_dim_list, start, *args, theta=10000.0, use_real=False,
                            txt_rope_size=None, theta_rescale_factor=1.0, interpolation_factor=1.0):
    grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))
    if isinstance(theta_rescale_factor, (int, float)):
        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    if isinstance(interpolation_factor, (int, float)):
        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    embs = []
    for i in range(len(rope_dim_list)):
        emb = get_1d_rotary_pos_embed(
            rope_dim_list[i], grid[i].reshape(-1), theta,
            use_real=use_real, theta_rescale_factor=theta_rescale_factor[i],
            interpolation_factor=interpolation_factor[i],
        )
        embs.append(emb)
    if use_real:
        vis_emb = (torch.cat([emb[0] for emb in embs], dim=1), torch.cat([emb[1] for emb in embs], dim=1))
    else:
        vis_emb = torch.cat(embs, dim=1)
    if txt_rope_size is not None:
        embs_txt = []
        vis_max_ids = grid.view(-1).max().item()
        grid_txt = torch.arange(txt_rope_size) + vis_max_ids + 1
        for i in range(len(rope_dim_list)):
            emb = get_1d_rotary_pos_embed(
                rope_dim_list[i], grid_txt, theta,
                use_real=use_real, theta_rescale_factor=theta_rescale_factor[i],
                interpolation_factor=interpolation_factor[i],
            )
            embs_txt.append(emb)
        if use_real:
            txt_emb = (torch.cat([emb[0] for emb in embs_txt], dim=1), torch.cat([emb[1] for emb in embs_txt], dim=1))
        else:
            txt_emb = torch.cat(embs_txt, dim=1)
    else:
        txt_emb = None
    return vis_emb, txt_emb


class ModulateWan(nn.Module):
    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
        super().__init__()
        self.factor = factor
        factory_kwargs = {"dtype": dtype, "device": device}
        self.modulate_table = nn.Parameter(
            torch.zeros(1, factor, hidden_size, **factory_kwargs) / hidden_size**0.5,
            requires_grad=True
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(x.shape) != 3:
            x = x.unsqueeze(1)
        return [o.squeeze(1) for o in (self.modulate_table + x).chunk(self.factor, dim=1)]


def modulate(x, shift=None, scale=None):
    if scale is None and shift is None:
        return x
    elif shift is None:
        return x * (1 + scale.unsqueeze(1))
    elif scale is None:
        return x + shift.unsqueeze(1)
    else:
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def apply_gate(x, gate=None, tanh=False):
    if gate is None:
        return x
    if tanh:
        return x * gate.unsqueeze(1).tanh()
    else:
        return x * gate.unsqueeze(1)


def load_modulation(modulate_type: str, hidden_size: int, factor: int, dtype=None, device=None):
    factory_kwargs = {"dtype": dtype, "device": device}
    if modulate_type == 'wanx':
        return ModulateWan(hidden_size, factor, **factory_kwargs)
    raise ValueError(f"Unknown modulation type: {modulate_type}. Only 'wanx' is supported.")


class RMSNorm(nn.Module):
    def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output


class MMDoubleStreamBlock(nn.Module):
    """
    A multimodal dit block with separate modulation for
    text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
    (Flux.1): https://github.com/black-forest-labs/flux
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        dit_modulation_type: Optional[str] = "wanx",
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.dit_modulation_type = dit_modulation_type
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        self.img_mod = load_modulation(
            modulate_type=self.dit_modulation_type,
            hidden_size=hidden_size, factor=6, **factory_kwargs,
        )
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
        self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")

        self.txt_mod = load_modulation(
            modulate_type=self.dit_modulation_type,
            hidden_size=hidden_size, factor=6, **factory_kwargs,
        )
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
        self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        vis_freqs_cis: tuple = None,
        txt_freqs_cis: tuple = None,
        attn_kwargs: Optional[dict] = {},
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        (
            img_mod1_shift, img_mod1_scale, img_mod1_gate,
            img_mod2_shift, img_mod2_scale, img_mod2_gate,
        ) = self.img_mod(vec)
        (
            txt_mod1_shift, txt_mod1_scale, txt_mod1_gate,
            txt_mod2_shift, txt_mod2_scale, txt_mod2_gate,
        ) = self.txt_mod(vec)

        img_modulated = self.img_norm1(img)
        img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
        img_qkv = self.img_attn_qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        img_q = self.img_attn_q_norm(img_q).to(img_v)
        img_k = self.img_attn_k_norm(img_k).to(img_v)

        if vis_freqs_cis is not None:
            img_qq, img_kk = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)
            img_q, img_k = img_qq, img_kk

        txt_modulated = self.txt_norm1(txt)
        txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
        txt_qkv = self.txt_attn_qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)

        if txt_freqs_cis is not None:
            raise NotImplementedError("RoPE text is not supported for inference")

        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)
        v = torch.cat((img_v, txt_v), dim=1)

        # Use DiffSynth unified attention
        attn_out = attention_forward(
            q, k, v,
            q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
        )

        attn_out = attn_out.flatten(2, 3)
        img_attn, txt_attn = attn_out[:, : img.shape[1]], attn_out[:, img.shape[1]:]

        img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
        img = img + apply_gate(
            self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
            gate=img_mod2_gate,
        )

        txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
        txt = txt + apply_gate(
            self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
            gate=txt_mod2_gate,
        )

        return img, txt


class WanTimeTextImageEmbedding(nn.Module):
    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
        image_embed_dim: Optional[int] = None,
        pos_embed_seq_len: Optional[int] = None,
    ):
        super().__init__()
        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")

    def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor):
        timestep = self.timesteps_proj(timestep)
        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
            timestep = timestep.to(time_embedder_dtype)
        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))
        encoder_hidden_states = self.text_embedder(encoder_hidden_states)
        return temb, timestep_proj, encoder_hidden_states


class JoyAIImageDiT(nn.Module):
    _supports_gradient_checkpointing = True

    def __init__(
        self,
        patch_size: list = [1, 2, 2],
        in_channels: int = 16,
        out_channels: int = 16,
        hidden_size: int = 4096,
        heads_num: int = 32,
        text_states_dim: int = 4096,
        mlp_width_ratio: float = 4.0,
        mm_double_blocks_depth: int = 40,
        rope_dim_list: List[int] = [16, 56, 56],
        rope_type: str = 'rope',
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        dit_modulation_type: str = "wanx",
        theta: int = 10000,
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        self.rope_dim_list = rope_dim_list
        self.dit_modulation_type = dit_modulation_type
        self.mm_double_blocks_depth = mm_double_blocks_depth
        self.rope_type = rope_type
        self.theta = theta

        factory_kwargs = {"device": device, "dtype": dtype}

        if hidden_size % heads_num != 0:
            raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")

        self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=hidden_size,
            time_freq_dim=256,
            time_proj_dim=hidden_size * 6,
            text_embed_dim=text_states_dim,
        )

        self.double_blocks = nn.ModuleList([
            MMDoubleStreamBlock(
                self.hidden_size, self.heads_num,
                mlp_width_ratio=mlp_width_ratio,
                dit_modulation_type=self.dit_modulation_type,
                **factory_kwargs,
            )
            for _ in range(mm_double_blocks_depth)
        ])

        self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(hidden_size, self.out_channels * math.prod(patch_size), **factory_kwargs)

    def get_rotary_pos_embed(self, vis_rope_size, txt_rope_size=None):
        target_ndim = 3
        if len(vis_rope_size) != target_ndim:
            vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size
        head_dim = self.hidden_size // self.heads_num
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim
        vis_freqs, txt_freqs = get_nd_rotary_pos_embed(
            rope_dim_list, vis_rope_size,
            txt_rope_size=txt_rope_size if self.rope_type == 'mrope' else None,
            theta=self.theta, use_real=True, theta_rescale_factor=1,
        )
        return vis_freqs, txt_freqs

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        encoder_hidden_states_mask: torch.Tensor = None,
        return_dict: bool = True,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        is_multi_item = (len(hidden_states.shape) == 6)
        num_items = 0
        if is_multi_item:
            num_items = hidden_states.shape[1]
            if num_items > 1:
                assert self.patch_size[0] == 1, "For multi-item input, patch_size[0] must be 1"
                hidden_states = torch.cat([hidden_states[:, -1:], hidden_states[:, :-1]], dim=1)
            hidden_states = rearrange(hidden_states, 'b n c t h w -> b c (n t) h w')

        batch_size, _, ot, oh, ow = hidden_states.shape
        tt, th, tw = ot // self.patch_size[0], oh // self.patch_size[1], ow // self.patch_size[2]

        if encoder_hidden_states_mask is None:
            encoder_hidden_states_mask = torch.ones(
                (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]),
                dtype=torch.bool,
            ).to(encoder_hidden_states.device)

        img = self.img_in(hidden_states).flatten(2).transpose(1, 2)
        temb, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
        if vec.shape[-1] > self.hidden_size:
            vec = vec.unflatten(1, (6, -1))

        txt_seq_len = txt.shape[1]
        img_seq_len = img.shape[1]

        vis_freqs_cis, txt_freqs_cis = self.get_rotary_pos_embed(
            vis_rope_size=(tt, th, tw),
            txt_rope_size=txt_seq_len if self.rope_type == 'mrope' else None,
        )

        for block in self.double_blocks:
            img, txt = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                img=img, txt=txt, vec=vec,
                vis_freqs_cis=vis_freqs_cis, txt_freqs_cis=txt_freqs_cis,
                attn_kwargs={},
            )

        img_len = img.shape[1]
        x = torch.cat((img, txt), 1)
        img = x[:, :img_len, ...]

        img = self.proj_out(self.norm_out(img))
        img = self.unpatchify(img, tt, th, tw)

        if is_multi_item:
            img = rearrange(img, 'b c (n t) h w -> b n c t h w', n=num_items)
            if num_items > 1:
                img = torch.cat([img[:, 1:], img[:, :1]], dim=1)

        return img

    def unpatchify(self, x, t, h, w):
        c = self.out_channels
        pt, ph, pw = self.patch_size
        assert t * h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
        x = torch.einsum("nthwopqc->nctohpwq", x)
        return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
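
For orientation, a minimal smoke-test sketch of the calling convention, with toy hyperparameters (not the checkpoint defaults). It assumes `gradient_checkpoint_forward` is a plain passthrough when checkpointing is off:

import torch
from diffsynth.models.joyai_image_dit import JoyAIImageDiT

# Toy config; real checkpoints use hidden_size=4096, depth=40, etc.
dit = JoyAIImageDiT(in_channels=4, out_channels=4, hidden_size=64, heads_num=4,
                    text_states_dim=32, mm_double_blocks_depth=2,
                    rope_dim_list=[4, 6, 6])
latents = torch.randn(1, 4, 1, 8, 8)   # [B, C, T, H, W]; patch_size [1, 2, 2] gives 16 tokens
text = torch.randn(1, 5, 32)           # [B, L, text_states_dim]
out = dit(latents, timestep=torch.tensor([500.0]), encoder_hidden_states=text)
print(out.shape)  # torch.Size([1, 4, 1, 8, 8]); unpatchify restores the latent shape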

82  diffsynth/models/joyai_image_text_encoder.py  Normal file
@@ -0,0 +1,82 @@
import torch
from typing import Optional


class JoyAIImageTextEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration

        config = Qwen3VLConfig(
            text_config={
                "attention_bias": False,
                "attention_dropout": 0.0,
                "bos_token_id": 151643,
                "eos_token_id": 151645,
                "head_dim": 128,
                "hidden_act": "silu",
                "hidden_size": 4096,
                "initializer_range": 0.02,
                "intermediate_size": 12288,
                "max_position_embeddings": 262144,
                "model_type": "qwen3_vl_text",
                "num_attention_heads": 32,
                "num_hidden_layers": 36,
                "num_key_value_heads": 8,
                "rms_norm_eps": 1e-6,
                "rope_scaling": {
                    "mrope_interleaved": True,
                    "mrope_section": [24, 20, 20],
                    "rope_type": "default",
                },
                "rope_theta": 5000000,
                "use_cache": True,
                "vocab_size": 151936,
            },
            vision_config={
                "deepstack_visual_indexes": [8, 16, 24],
                "depth": 27,
                "hidden_act": "gelu_pytorch_tanh",
                "hidden_size": 1152,
                "in_channels": 3,
                "initializer_range": 0.02,
                "intermediate_size": 4304,
                "model_type": "qwen3_vl",
                "num_heads": 16,
                "num_position_embeddings": 2304,
                "out_hidden_size": 4096,
                "patch_size": 16,
                "spatial_merge_size": 2,
                "temporal_patch_size": 2,
            },
            image_token_id=151655,
            video_token_id=151656,
            vision_start_token_id=151652,
            vision_end_token_id=151653,
            tie_word_embeddings=False,
        )

        self.model = Qwen3VLForConditionalGeneration(config)
        self.config = config

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        pre_norm_output = [None]

        def hook_fn(module, args, kwargs_output=None):
            pre_norm_output[0] = args[0]

        self.model.model.language_model.norm.register_forward_hook(hook_fn)
        _ = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs,
        )
        return pre_norm_output[0]
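
The wrapper captures the hidden states entering the final RMSNorm via a forward hook instead of reading `output_hidden_states`. Note that the hook is registered anew on every `forward` call and never removed, so handles accumulate across calls; a minimal sketch of the same pattern with cleanup (illustrative, not the diff's code):

def capture_pre_norm(model: torch.nn.Module, norm: torch.nn.Module, **model_inputs):
    """Run `model` once and return the tensor that entered `norm`."""
    captured = []
    handle = norm.register_forward_hook(lambda mod, inputs, output: captured.append(inputs[0]))
    try:
        model(**model_inputs)
    finally:
        handle.remove()  # unlike the wrapper above, detach the hook after use
    return captured[0]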

@@ -5,8 +5,65 @@ import einops
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torchaudio
 from .ltx2_common import VideoLatentShape, AudioLatentShape, Patchifier, NormType, build_normalization_layer
+
+
+class AudioProcessor(nn.Module):
+    """Converts audio waveforms to log-mel spectrograms with optional resampling."""
+
+    def __init__(
+        self,
+        sample_rate: int = 16000,
+        mel_bins: int = 64,
+        mel_hop_length: int = 160,
+        n_fft: int = 1024,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.mel_transform = torchaudio.transforms.MelSpectrogram(
+            sample_rate=sample_rate,
+            n_fft=n_fft,
+            win_length=n_fft,
+            hop_length=mel_hop_length,
+            f_min=0.0,
+            f_max=sample_rate / 2.0,
+            n_mels=mel_bins,
+            window_fn=torch.hann_window,
+            center=True,
+            pad_mode="reflect",
+            power=1.0,
+            mel_scale="slaney",
+            norm="slaney",
+        )
+
+    def resample_waveform(
+        self,
+        waveform: torch.Tensor,
+        source_rate: int,
+        target_rate: int,
+    ) -> torch.Tensor:
+        """Resample waveform to target sample rate if needed."""
+        if source_rate == target_rate:
+            return waveform
+        resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
+        return resampled.to(device=waveform.device, dtype=waveform.dtype)
+
+    def waveform_to_mel(
+        self,
+        waveform: torch.Tensor,
+        waveform_sample_rate: int,
+    ) -> torch.Tensor:
+        """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
+        waveform = self.resample_waveform(waveform, waveform_sample_rate, self.sample_rate)
+
+        mel = self.mel_transform(waveform)
+        mel = torch.log(torch.clamp(mel, min=1e-5))
+
+        mel = mel.to(device=waveform.device, dtype=waveform.dtype)
+        return mel.permute(0, 1, 3, 2).contiguous()
+
+
 class AudioPatchifier(Patchifier):
     def __init__(
         self,
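
For orientation, a shape check of the new `AudioProcessor` (random input, shapes only; assumes `torchaudio` is installed):

processor = AudioProcessor(sample_rate=16000, mel_bins=64, mel_hop_length=160, n_fft=1024)
waveform = torch.randn(1, 2, 48000)  # [B, channels, T], three seconds at 16 kHz
mel = processor.waveform_to_mel(waveform, waveform_sample_rate=16000)
print(mel.shape)  # torch.Size([1, 2, 301, 64]): [B, channels, time_frames, n_mels], 48000/160 + 1 = 301 frames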
@@ -1222,9 +1279,268 @@ class LTX2AudioDecoder(torch.nn.Module):
|
|||||||
return torch.tanh(h) if self.tanh_out else h
|
return torch.tanh(h) if self.tanh_out else h
|
||||||
|
|
||||||
|
|
||||||
|
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
||||||
|
return int((kernel_size * dilation - dilation) / 2)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
|
||||||
|
# Adopted from https://github.com/NVIDIA/BigVGAN
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _sinc(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return torch.where(
|
||||||
|
x == 0,
|
||||||
|
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
||||||
|
torch.sin(math.pi * x) / math.pi / x,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
|
||||||
|
even = kernel_size % 2 == 0
|
||||||
|
half_size = kernel_size // 2
|
||||||
|
delta_f = 4 * half_width
|
||||||
|
amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||||
|
if amplitude > 50.0:
|
||||||
|
beta = 0.1102 * (amplitude - 8.7)
|
||||||
|
elif amplitude >= 21.0:
|
||||||
|
beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
|
||||||
|
else:
|
||||||
|
beta = 0.0
|
||||||
|
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||||
|
time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size
|
||||||
|
if cutoff == 0:
|
||||||
|
filter_ = torch.zeros_like(time)
|
||||||
|
else:
|
||||||
|
filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
|
||||||
|
filter_ /= filter_.sum()
|
||||||
|
return filter_.view(1, 1, kernel_size)
|
||||||
|
|
||||||
|
|
||||||
|
class LowPassFilter1d(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cutoff: float = 0.5,
|
||||||
|
half_width: float = 0.6,
|
||||||
|
stride: int = 1,
|
||||||
|
padding: bool = True,
|
||||||
|
padding_mode: str = "replicate",
|
||||||
|
kernel_size: int = 12,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
if cutoff < -0.0:
|
||||||
|
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||||
|
if cutoff > 0.5:
|
||||||
|
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.even = kernel_size % 2 == 0
|
||||||
|
self.pad_left = kernel_size // 2 - int(self.even)
|
||||||
|
self.pad_right = kernel_size // 2
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
self.padding_mode = padding_mode
|
||||||
|
self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
_, n_channels, _ = x.shape
|
||||||
|
if self.padding:
|
||||||
|
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||||
|
return F.conv1d(x, self.filter.expand(n_channels, -1, -1), stride=self.stride, groups=n_channels)
|
||||||
|
|
||||||
|
|
||||||

class UpSample1d(nn.Module):
    def __init__(
        self,
        ratio: int = 2,
        kernel_size: int | None = None,
        persistent: bool = True,
        window_type: str = "kaiser",
    ) -> None:
        super().__init__()
        self.ratio = ratio
        self.stride = ratio

        if window_type == "hann":
            # Hann-windowed sinc filter equivalent to torchaudio.functional.resample
            rolloff = 0.99
            lowpass_filter_width = 6
            width = math.ceil(lowpass_filter_width / rolloff)
            self.kernel_size = 2 * width * ratio + 1
            self.pad = width
            self.pad_left = 2 * width * ratio
            self.pad_right = self.kernel_size - ratio
            time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
            time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
            window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
            sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
        else:
            # Kaiser-windowed sinc filter (BigVGAN default).
            self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
            self.pad = self.kernel_size // ratio - 1
            self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
            self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
            sinc_filter = kaiser_sinc_filter1d(
                cutoff=0.5 / ratio,
                half_width=0.6 / ratio,
                kernel_size=self.kernel_size,
            )

        self.register_buffer("filter", sinc_filter, persistent=persistent)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, n_channels, _ = x.shape
        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1)
        x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels)
        return x[..., self.pad_left : -self.pad_right]
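
# --- Editor's sketch (not part of the original module) ---
# The "hann" branch mirrors torchaudio.functional.resample with
# lowpass_filter_width=6 and rolloff=0.99; the padding arithmetic guarantees
# the output is exactly ratio times longer than the input.
def _demo_upsample() -> None:
    up = UpSample1d(ratio=3, window_type="hann")
    x = torch.randn(2, 1, 100)
    assert up(x).shape == (2, 1, 300)
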

class DownSample1d(nn.Module):
    def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None:
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lowpass(x)

class Activation1d(nn.Module):
    def __init__(
        self,
        activation: nn.Module,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ) -> None:
        super().__init__()
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.upsample(x)
        x = self.act(x)
        return self.downsample(x)
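
# --- Editor's sketch (not part of the original module) ---
# Nonlinearities create harmonics above Nyquist. Activation1d applies the
# nonlinearity at 2x the sample rate and band-limits on the way back down,
# the anti-aliased activation trick from BigVGAN; shapes are preserved.
def _demo_activation1d() -> None:
    act = Activation1d(nn.Tanh())
    x = torch.randn(1, 8, 128)
    assert act(x).shape == x.shape
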

class Snake(nn.Module):
    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = True,
    ) -> None:
        super().__init__()
        self.alpha_logscale = alpha_logscale
        self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.eps = 1e-9

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        return x + (1.0 / (alpha + self.eps)) * torch.sin(x * alpha).pow(2)
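
# Editor's note (not part of the original module): Snake computes
#     snake(x) = x + (1 / alpha) * sin^2(alpha * x)
# per channel. With alpha_logscale=True the parameter is stored as log(alpha),
# so the zero-initialised parameter corresponds to alpha = exp(0) = 1.
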

class SnakeBeta(nn.Module):
    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = True,
    ) -> None:
        super().__init__()
        self.alpha_logscale = alpha_logscale
        self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.beta = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
        self.beta.requires_grad = alpha_trainable
        self.eps = 1e-9

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        return x + (1.0 / (beta + self.eps)) * torch.sin(x * alpha).pow(2)

class AMPBlock1(nn.Module):
    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple[int, int, int] = (1, 3, 5),
        activation: str = "snake",
    ) -> None:
        super().__init__()
        act_cls = SnakeBeta if activation == "snakebeta" else Snake
        self.convs1 = nn.ModuleList(
            [
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[0],
                    padding=get_padding(kernel_size, dilation[0]),
                ),
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[1],
                    padding=get_padding(kernel_size, dilation[1]),
                ),
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[2],
                    padding=get_padding(kernel_size, dilation[2]),
                ),
            ]
        )

        self.convs2 = nn.ModuleList(
            [
                nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
                nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
                nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
            ]
        )

        self.acts1 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs1))])
        self.acts2 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs2))])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x = x + xt
        return x
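
# --- Editor's sketch (not part of the original module) ---
# It assumes the usual HiFi-GAN helper get_padding(kernel_size, dilation)
# = (kernel_size * dilation - dilation) // 2 defined earlier in this file,
# which keeps every conv (and thus the whole residual block) length-preserving.
def _demo_amp_block() -> None:
    block = AMPBlock1(channels=32, activation="snakebeta")
    x = torch.randn(1, 32, 200)
    assert block(x).shape == x.shape
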
 class LTX2Vocoder(torch.nn.Module):
     """
-    Vocoder model for synthesizing audio from Mel spectrograms.
+    LTX2Vocoder model for synthesizing audio from Mel spectrograms.

     Args:
         resblock_kernel_sizes: List of kernel sizes for the residual blocks.
             This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`.
@@ -1236,28 +1552,33 @@ class LTX2Vocoder(torch.nn.Module):
             This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`.
         upsample_initial_channel: Initial number of channels for the upsampling layers.
             This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`.
-        stereo: Whether to use stereo output.
-            This value is read from the checkpoint at `config.vocoder.stereo`.
-        resblock: Type of residual block to use.
+        resblock: Type of residual block to use ("1", "2", or "AMP1").
             This value is read from the checkpoint at `config.vocoder.resblock`.
-        output_sample_rate: Waveform sample rate.
-            This value is read from the checkpoint at `config.vocoder.output_sample_rate`.
+        output_sampling_rate: Waveform sample rate.
+            This value is read from the checkpoint at `config.vocoder.output_sampling_rate`.
+        activation: Activation type for BigVGAN v2 ("snake" or "snakebeta"). Only used when resblock="AMP1".
+        use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True).
+        apply_final_activation: Whether to apply the final tanh/clamp activation.
+        use_bias_at_final: Whether to use bias in the final conv layer.
     """

-    def __init__(
+    def __init__(  # noqa: PLR0913
         self,
         resblock_kernel_sizes: List[int] | None = [3, 7, 11],
         upsample_rates: List[int] | None = [6, 5, 2, 2, 2],
         upsample_kernel_sizes: List[int] | None = [16, 15, 8, 4, 4],
         resblock_dilation_sizes: List[List[int]] | None = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
         upsample_initial_channel: int = 1024,
-        stereo: bool = True,
         resblock: str = "1",
-        output_sample_rate: int = 24000,
-    ):
+        output_sampling_rate: int = 24000,
+        activation: str = "snake",
+        use_tanh_at_final: bool = True,
+        apply_final_activation: bool = True,
+        use_bias_at_final: bool = True,
+    ) -> None:
         super().__init__()

-        # Initialize default values if not provided. Note that mutable default values are not supported.
+        # Mutable default values are not supported as default arguments.
         if resblock_kernel_sizes is None:
             resblock_kernel_sizes = [3, 7, 11]
         if upsample_rates is None:
@@ -1267,16 +1588,25 @@ class LTX2Vocoder(torch.nn.Module):
         if resblock_dilation_sizes is None:
             resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]

-        self.output_sample_rate = output_sample_rate
+        self.output_sampling_rate = output_sampling_rate
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_upsamples = len(upsample_rates)
-        in_channels = 128 if stereo else 64
-        self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
-        resblock_class = ResBlock1 if resblock == "1" else ResBlock2
-
-        self.ups = nn.ModuleList()
-        for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True)):
-            self.ups.append(
+        self.use_tanh_at_final = use_tanh_at_final
+        self.apply_final_activation = apply_final_activation
+        self.is_amp = resblock == "AMP1"
+
+        # All production checkpoints are stereo: 128 input channels (2 stereo channels x 64 mel
+        # bins each), 2 output channels.
+        self.conv_pre = nn.Conv1d(
+            in_channels=128,
+            out_channels=upsample_initial_channel,
+            kernel_size=7,
+            stride=1,
+            padding=3,
+        )
+        resblock_cls = ResBlock1 if resblock == "1" else AMPBlock1
+
+        self.ups = nn.ModuleList(
             nn.ConvTranspose1d(
                 upsample_initial_channel // (2**i),
                 upsample_initial_channel // (2 ** (i + 1)),
@@ -1284,19 +1614,34 @@ class LTX2Vocoder(torch.nn.Module):
                 stride,
                 padding=(kernel_size - stride) // 2,
             )
+            for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True))
         )

+        final_channels = upsample_initial_channel // (2 ** len(upsample_rates))
         self.resblocks = nn.ModuleList()
-        for i, _ in enumerate(self.ups):
+        for i in range(len(upsample_rates)):
             ch = upsample_initial_channel // (2 ** (i + 1))
             for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):
-                self.resblocks.append(resblock_class(ch, kernel_size, dilations))
+                if self.is_amp:
+                    self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation))
+                else:
+                    self.resblocks.append(resblock_cls(ch, kernel_size, dilations))

-        out_channels = 2 if stereo else 1
-        final_channels = upsample_initial_channel // (2**self.num_upsamples)
-        self.conv_post = nn.Conv1d(final_channels, out_channels, 7, 1, padding=3)
-        self.upsample_factor = math.prod(layer.stride[0] for layer in self.ups)
+        if self.is_amp:
+            self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels))
+        else:
+            self.act_post = nn.LeakyReLU()
+
+        # All production checkpoints are stereo: this final conv maps `final_channels` to 2 output channels (stereo).
+        self.conv_post = nn.Conv1d(
+            in_channels=final_channels,
+            out_channels=2,
+            kernel_size=7,
+            stride=1,
+            padding=3,
+            bias=use_bias_at_final,
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
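
# Editor's note: a construction sketch for the new signature (not part of the
# diff). resblock="AMP1" swaps ResBlock1 plus LeakyReLU for AMPBlock1 plus an
# anti-aliased SnakeBeta post-activation, as wired up in __init__ above:
#
#     vocoder = LTX2Vocoder(
#         upsample_initial_channel=1536,
#         resblock="AMP1",
#         activation="snakebeta",
#         use_tanh_at_final=False,
#         use_bias_at_final=False,
#     )
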
@@ -1317,6 +1662,7 @@ class LTX2Vocoder(torch.nn.Module):
         x = self.conv_pre(x)

         for i in range(self.num_upsamples):
+            if not self.is_amp:
                 x = F.leaky_relu(x, LRELU_SLOPE)
             x = self.ups[i](x)
             start = i * self.num_kernels
@@ -1329,23 +1675,198 @@ class LTX2Vocoder(torch.nn.Module):
                 [self.resblocks[idx](x) for idx in range(start, end)],
                 dim=0,
             )

             x = block_outputs.mean(dim=0)

-        x = self.conv_post(F.leaky_relu(x))
-        return torch.tanh(x)
+        x = self.act_post(x)
+        x = self.conv_post(x)
+
+        if self.apply_final_activation:
+            x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1)
+
+        return x


-def decode_audio(latent: torch.Tensor, audio_decoder: "LTX2AudioDecoder", vocoder: "LTX2Vocoder") -> torch.Tensor:
-    """
-    Decode an audio latent representation using the provided audio decoder and vocoder.
-
-    Args:
-        latent: Input audio latent tensor.
-        audio_decoder: Model to decode the latent to waveform features.
-        vocoder: Model to convert decoded features to audio waveform.
-
-    Returns:
-        Decoded audio as a float tensor.
-    """
-    decoded_audio = audio_decoder(latent)
-    decoded_audio = vocoder(decoded_audio).squeeze(0).float()
-    return decoded_audio
+class _STFTFn(nn.Module):
+    """Implements STFT as a convolution with precomputed DFT x Hann-window bases.
+
+    The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
+    Hann window are stored as buffers and loaded from the checkpoint. Using the exact
+    bfloat16 bases from training ensures the mel values fed to the BWE generator are
+    bit-identical to what it was trained on.
+    """
+
+    def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
+        super().__init__()
+        self.hop_length = hop_length
+        self.win_length = win_length
+        n_freqs = filter_length // 2 + 1
+        self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
+        self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
+
+    def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Compute magnitude and phase spectrogram from a batch of waveforms.
+
+        Applies causal (left-only) padding of win_length - hop_length samples so that
+        each output frame depends only on past and present input — no lookahead.
+
+        Args:
+            y: Waveform tensor of shape (B, T).
+
+        Returns:
+            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
+            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
+        """
+        if y.dim() == 2:
+            y = y.unsqueeze(1)  # (B, 1, T)
+        left_pad = max(0, self.win_length - self.hop_length)  # causal: left-only
+        y = F.pad(y, (left_pad, 0))
+        spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
+        n_freqs = spec.shape[1] // 2
+        real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
+        magnitude = torch.sqrt(real**2 + imag**2)
+        phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
+        return magnitude, phase
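
# --- Editor's sketch (not part of the diff) ---
# One way the forward_basis buffer can be materialized for testing; the
# checkpoint ships its own (bfloat16) version, whose window and scaling may
# differ. Rows are the real parts of the DFT atoms followed by the imaginary
# parts, each multiplied by the analysis window:
def _make_forward_basis(filter_length: int = 512) -> torch.Tensor:
    n_freqs = filter_length // 2 + 1
    n = torch.arange(filter_length, dtype=torch.float64)
    k = torch.arange(n_freqs, dtype=torch.float64).unsqueeze(1)
    angle = 2 * math.pi * k * n / filter_length
    window = torch.hann_window(filter_length, dtype=torch.float64)
    basis = torch.cat([torch.cos(angle), -torch.sin(angle)], dim=0) * window
    return basis.unsqueeze(1).float()  # (2 * n_freqs, 1, filter_length)
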
+class MelSTFT(nn.Module):
+    """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
+
+    Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
+    waveform and projecting the linear magnitude spectrum onto the mel filterbank.
+    The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
+    (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
+    """
+
+    def __init__(
+        self,
+        filter_length: int,
+        hop_length: int,
+        win_length: int,
+        n_mel_channels: int,
+    ) -> None:
+        super().__init__()
+        self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
+
+        # Initialized to zeros; load_state_dict overwrites with the checkpoint's
+        # exact bfloat16 filterbank (vocoder.mel_stft.mel_basis, shape [n_mels, n_freqs]).
+        n_freqs = filter_length // 2 + 1
+        self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
+
+    def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Compute log-mel spectrogram and auxiliary spectral quantities.
+
+        Args:
+            y: Waveform tensor of shape (B, T).
+
+        Returns:
+            log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
+            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
+            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
+            energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
+        """
+        magnitude, phase = self.stft_fn(y)
+        energy = torch.norm(magnitude, dim=1)
+        mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
+        log_mel = torch.log(torch.clamp(mel, min=1e-5))
+        return log_mel, magnitude, phase, energy
+
+
+class LTX2VocoderWithBWE(nn.Module):
+    """LTX2Vocoder with bandwidth extension (BWE) upsampling.
+
+    Chains a mel-to-wav vocoder with a BWE module that upsamples the output
+    to a higher sample rate. The BWE computes a mel spectrogram from the
+    vocoder output, runs it through a second generator to predict a residual,
+    and adds it to a sinc-resampled skip connection.
+    """
+
+    def __init__(
+        self,
+        input_sampling_rate: int = 16000,
+        output_sampling_rate: int = 48000,
+        hop_length: int = 80,
+    ) -> None:
+        super().__init__()
+        self.vocoder = LTX2Vocoder(
+            resblock_kernel_sizes=[3, 7, 11],
+            upsample_rates=[5, 2, 2, 2, 2, 2],
+            upsample_kernel_sizes=[11, 4, 4, 4, 4, 4],
+            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+            upsample_initial_channel=1536,
+            resblock="AMP1",
+            activation="snakebeta",
+            use_tanh_at_final=False,
+            apply_final_activation=True,
+            use_bias_at_final=False,
+            output_sampling_rate=input_sampling_rate,
+        )
+        self.bwe_generator = LTX2Vocoder(
+            resblock_kernel_sizes=[3, 7, 11],
+            upsample_rates=[6, 5, 2, 2, 2],
+            upsample_kernel_sizes=[12, 11, 4, 4, 4],
+            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+            upsample_initial_channel=512,
+            resblock="AMP1",
+            activation="snakebeta",
+            use_tanh_at_final=False,
+            apply_final_activation=False,
+            use_bias_at_final=False,
+            output_sampling_rate=output_sampling_rate,
+        )
+
+        self.mel_stft = MelSTFT(
+            filter_length=512,
+            hop_length=hop_length,
+            win_length=512,
+            n_mel_channels=64,
+        )
+        self.input_sampling_rate = input_sampling_rate
+        self.output_sampling_rate = output_sampling_rate
+        self.hop_length = hop_length
+        # Compute the resampler on CPU so the sinc filter is materialized even when
+        # the model is constructed on meta device (SingleGPUModelBuilder pattern).
+        # The filter is not stored in the checkpoint (persistent=False).
+        with torch.device("cpu"):
+            self.resampler = UpSample1d(
+                ratio=output_sampling_rate // input_sampling_rate, persistent=False, window_type="hann"
+            )
+
+    @property
+    def conv_pre(self) -> nn.Conv1d:
+        return self.vocoder.conv_pre
+
+    @property
+    def conv_post(self) -> nn.Conv1d:
+        return self.vocoder.conv_post
+
+    def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor:
+        """Compute log-mel spectrogram from waveform using causal STFT bases.
+
+        Args:
+            audio: Waveform tensor of shape (B, C, T).
+
+        Returns:
+            mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames).
+        """
+        batch, n_channels, _ = audio.shape
+        flat = audio.reshape(batch * n_channels, -1)  # (B*C, T)
+        mel, _, _, _ = self.mel_stft.mel_spectrogram(flat)  # (B*C, n_mels, T_frames)
+        return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2])  # (B, C, n_mels, T_frames)
+
+    def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
+        """Run the full vocoder + BWE forward pass.
+
+        Args:
+            mel_spec: Mel spectrogram of shape (B, 2, T, mel_bins) for stereo
+                or (B, T, mel_bins) for mono. Same format as LTX2Vocoder.forward.
+
+        Returns:
+            Waveform tensor of shape (B, out_channels, T_out) clipped to [-1, 1].
+        """
+        x = self.vocoder(mel_spec)
+        _, _, length_low_rate = x.shape
+        output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate

+        # Pad to multiple of hop_length for exact mel frame count
+        remainder = length_low_rate % self.hop_length
+        if remainder != 0:
+            x = F.pad(x, (0, self.hop_length - remainder))
+
+        # Compute mel spectrogram from vocoder output: (B, C, n_mels, T_frames)
+        mel = self._compute_mel(x)
+
+        # LTX2Vocoder.forward expects (B, C, T, mel_bins) — transpose before calling bwe_generator
+        mel_for_bwe = mel.transpose(2, 3)  # (B, C, T_frames, mel_bins)
+        residual = self.bwe_generator(mel_for_bwe)
+        skip = self.resampler(x)
+        assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
+
+        return torch.clamp(residual + skip, -1, 1)[..., :output_length]
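
# --- Editor's sketch (not part of the diff) ---
# End-to-end shape arithmetic, assuming the base vocoder flattens stereo mels
# to 128 channels as documented above. Its upsample rates multiply to
# 5*2*2*2*2*2 = 160 samples per mel frame at 16 kHz, and the BWE stage
# resamples 16 kHz to 48 kHz, so T frames become T * 160 * 3 output samples.
def _demo_bwe_shapes() -> None:
    model = LTX2VocoderWithBWE()
    mel = torch.randn(1, 2, 10, 64)  # (B, stereo, T_frames, mel_bins)
    assert model(mel).shape == (1, 2, 10 * 160 * 3)
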
@@ -251,11 +251,27 @@ class Modality:
     Input data for a single modality (video or audio) in the transformer.

     Bundles the latent tokens, timestep embeddings, positional information,
     and text conditioning context for processing by the diffusion transformer.
+
+    Attributes:
+        latent: Patchified latent tokens, shape ``(B, T, D)`` where *B* is
+            the batch size, *T* is the total number of tokens (noisy +
+            conditioning), and *D* is the input dimension.
+        timesteps: Per-token timestep embeddings, shape ``(B, T)``.
+        positions: Positional coordinates, shape ``(B, 3, T)`` for video
+            (time, height, width) or ``(B, 1, T)`` for audio.
+        context: Text conditioning embeddings from the prompt encoder.
+        enabled: Whether this modality is active in the current forward pass.
+        context_mask: Optional mask for the text context tokens.
+        attention_mask: Optional 2-D self-attention mask, shape ``(B, T, T)``.
+            Values in ``[0, 1]`` where ``1`` = full attention and ``0`` = no
+            attention. ``None`` means unrestricted (full) attention between
+            all tokens. Built incrementally by conditioning items; see
+            :class:`~ltx_core.conditioning.types.attention_strength_wrapper.ConditioningItemAttentionStrengthWrapper`.
     """

     latent: (
         torch.Tensor
     )  # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
+    sigma: torch.Tensor  # Shape: (B,). Current sigma value, used for cross-attention timestep calculation.
     timesteps: torch.Tensor  # Shape: (B, T) where T is the number of timesteps
     positions: (
         torch.Tensor
@@ -263,6 +279,7 @@ class Modality:
     context: torch.Tensor
     enabled: bool = True
     context_mask: torch.Tensor | None = None
+    attention_mask: torch.Tensor | None = None

     def to_denoised(
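
# Editor's note (not part of the diff): the new attention_mask field stores
# attention *strengths* in [0, 1] rather than a boolean mask; the transformer
# later converts it to an additive log-space bias (see
# _prepare_self_attention_mask below), so 1.0 becomes 0.0 bias and 0.0 is
# fully masked. A hypothetical mask attending to token 0 at half strength:
#
#     mask = torch.ones(1, num_tokens, num_tokens)
#     mask[:, :, 0] = 0.5
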
@@ -8,6 +8,7 @@ import torch
 from einops import rearrange

 from .ltx2_common import rms_norm, Modality
 from ..core.attention.attention import attention_forward
+from ..core import gradient_checkpoint_forward


 def get_timestep_embedding(
@@ -224,6 +225,17 @@ class BatchedPerturbationConfig:
         return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)])


+
+ADALN_NUM_BASE_PARAMS = 6
+# Cross-attention AdaLN adds 3 more (scale, shift, gate) for the CA norm.
+ADALN_NUM_CROSS_ATTN_PARAMS = 3
+
+
+def adaln_embedding_coefficient(cross_attention_adaln: bool) -> int:
+    """Total number of AdaLN parameters per block."""
+    return ADALN_NUM_BASE_PARAMS + (ADALN_NUM_CROSS_ATTN_PARAMS if cross_attention_adaln else 0)
+
+
 class AdaLayerNormSingle(torch.nn.Module):
     r"""
     Norm layer adaptive layer norm single (adaLN-single).
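
# --- Editor's sketch (not part of the diff) ---
# A quick check of the AdaLN parameter count: 6 base parameters per block,
# plus 3 (scale, shift, gate) when cross-attention AdaLN is enabled.
def _demo_adaln_count() -> None:
    assert adaln_embedding_coefficient(False) == 6
    assert adaln_embedding_coefficient(True) == 9
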
@@ -459,6 +471,7 @@ class Attention(torch.nn.Module):
         dim_head: int = 64,
         norm_eps: float = 1e-6,
         rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
+        apply_gated_attention: bool = False,
     ) -> None:
         super().__init__()
         self.rope_type = rope_type
@@ -476,6 +489,12 @@ class Attention(torch.nn.Module):
         self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True)
         self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True)

+        # Optional per-head gating
+        if apply_gated_attention:
+            self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
+        else:
+            self.to_gate_logits = None
+
         self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity())

     def forward(
@@ -485,6 +504,8 @@ class Attention(torch.nn.Module):
         mask: torch.Tensor | None = None,
         pe: torch.Tensor | None = None,
         k_pe: torch.Tensor | None = None,
+        perturbation_mask: torch.Tensor | None = None,
+        all_perturbed: bool = False,
     ) -> torch.Tensor:
         q = self.to_q(x)
         context = x if context is None else context
@@ -516,6 +537,19 @@ class Attention(torch.nn.Module):

         # Reshape back to original format
         out = out.flatten(2, 3)
+
+        # Apply per-head gating if enabled
+        if self.to_gate_logits is not None:
+            gate_logits = self.to_gate_logits(x)  # (B, T, H)
+            b, t, _ = out.shape
+            # Reshape to (B, T, H, D) for per-head gating
+            out = out.view(b, t, self.heads, self.dim_head)
+            # Apply gating: 2 * sigmoid(x) so that zero-init gives identity (2 * 0.5 = 1.0)
+            gates = 2.0 * torch.sigmoid(gate_logits)  # (B, T, H)
+            out = out * gates.unsqueeze(-1)  # (B, T, H, D) * (B, T, H, 1)
+            # Reshape back to (B, T, H*D)
+            out = out.view(b, t, self.heads * self.dim_head)
+
         return self.to_out(out)
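
# --- Editor's sketch (not part of the diff) ---
# The gating is identity at initialisation: per the comment above, the gate
# projection is meant to start at zero, and 2 * sigmoid(0) = 1, so a fresh
# gate leaves the attention output unchanged.
def _demo_gate_identity() -> None:
    logits = torch.zeros(2, 5, 8)  # (B, T, heads)
    gates = 2.0 * torch.sigmoid(logits)
    assert torch.allclose(gates, torch.ones_like(gates))
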
@@ -544,7 +578,6 @@ class PixArtAlphaTextProjection(torch.nn.Module):
         hidden_states = self.linear_2(hidden_states)
         return hidden_states

-
 @dataclass(frozen=True)
 class TransformerArgs:
     x: torch.Tensor
@@ -557,7 +590,10 @@ class TransformerArgs:
     cross_scale_shift_timestep: torch.Tensor | None
     cross_gate_timestep: torch.Tensor | None
     enabled: bool
+    prompt_timestep: torch.Tensor | None = None
+    self_attention_mask: torch.Tensor | None = (
+        None  # Additive log-space self-attention bias (B, 1, T, T), None = full attention
+    )


 class TransformerArgsPreprocessor:
@@ -565,7 +601,6 @@ class TransformerArgsPreprocessor:
         self,
         patchify_proj: torch.nn.Linear,
         adaln: AdaLayerNormSingle,
-        caption_projection: PixArtAlphaTextProjection,
         inner_dim: int,
         max_pos: list[int],
         num_attention_heads: int,
@@ -574,10 +609,11 @@ class TransformerArgsPreprocessor:
         double_precision_rope: bool,
         positional_embedding_theta: float,
         rope_type: LTXRopeType,
+        caption_projection: torch.nn.Module | None = None,
+        prompt_adaln: AdaLayerNormSingle | None = None,
     ) -> None:
         self.patchify_proj = patchify_proj
         self.adaln = adaln
-        self.caption_projection = caption_projection
         self.inner_dim = inner_dim
         self.max_pos = max_pos
         self.num_attention_heads = num_attention_heads
@@ -586,18 +622,18 @@ class TransformerArgsPreprocessor:
         self.double_precision_rope = double_precision_rope
         self.positional_embedding_theta = positional_embedding_theta
         self.rope_type = rope_type
+        self.caption_projection = caption_projection
+        self.prompt_adaln = prompt_adaln

     def _prepare_timestep(
-        self, timestep: torch.Tensor, batch_size: int, hidden_dtype: torch.dtype
+        self, timestep: torch.Tensor, adaln: AdaLayerNormSingle, batch_size: int, hidden_dtype: torch.dtype
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """Prepare timestep embeddings."""
-        timestep = timestep * self.timestep_scale_multiplier
-        timestep, embedded_timestep = self.adaln(
-            timestep.flatten(),
+        timestep_scaled = timestep * self.timestep_scale_multiplier
+        timestep, embedded_timestep = adaln(
+            timestep_scaled.flatten(),
             hidden_dtype=hidden_dtype,
         )

         # Second dimension is 1 or number of tokens (if timestep_per_token)
         timestep = timestep.view(batch_size, -1, timestep.shape[-1])
         embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
@@ -607,14 +643,12 @@ class TransformerArgsPreprocessor:
         self,
         context: torch.Tensor,
         x: torch.Tensor,
-        attention_mask: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+    ) -> torch.Tensor:
         """Prepare context for transformer blocks."""
-        batch_size = x.shape[0]
-        context = self.caption_projection(context)
-        context = context.view(batch_size, -1, x.shape[-1])
-        return context, attention_mask
+        if self.caption_projection is not None:
+            context = self.caption_projection(context)
+        batch_size = x.shape[0]
+        return context.view(batch_size, -1, x.shape[-1])

     def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None:
         """Prepare attention mask."""
@@ -625,6 +659,34 @@ class TransformerArgsPreprocessor:
             (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
         ) * torch.finfo(x_dtype).max

+    def _prepare_self_attention_mask(
+        self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype
+    ) -> torch.Tensor | None:
+        """Prepare self-attention mask by converting [0,1] values to additive log-space bias.
+
+        Input shape: (B, T, T) with values in [0, 1].
+        Output shape: (B, 1, T, T) with 0.0 for full attention and a large negative value
+        for masked positions.
+
+        Positions with attention_mask <= 0 are fully masked (mapped to the dtype's minimum
+        representable value). Strictly positive entries are converted via log-space for
+        smooth attenuation, with small values clamped for numerical stability.
+
+        Returns None if input is None (no masking).
+        """
+        if attention_mask is None:
+            return None
+
+        # Convert [0, 1] attention mask to additive log-space bias:
+        #   1.0 -> log(1.0) = 0.0 (no bias, full attention)
+        #   0.0 -> finfo.min (fully masked)
+        finfo = torch.finfo(x_dtype)
+        eps = finfo.tiny
+
+        bias = torch.full_like(attention_mask, finfo.min, dtype=x_dtype)
+        positive = attention_mask > 0
+        if positive.any():
+            bias[positive] = torch.log(attention_mask[positive].clamp(min=eps)).to(x_dtype)
+
+        return bias.unsqueeze(1)  # (B, 1, T, T) for head broadcast
+
     def _prepare_positional_embeddings(
         self,
         positions: torch.Tensor,
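
# --- Editor's sketch (not part of the diff) ---
# Numeric behaviour of the conversion above: strength 1.0 maps to bias 0.0,
# strength 0.5 to log(0.5) (about -0.693), and strength 0.0 to the dtype
# minimum, which acts like -inf under softmax.
def _demo_mask_bias() -> None:
    strengths = torch.tensor([[[1.0, 0.5], [0.0, 1.0]]])
    finfo = torch.finfo(torch.float32)
    bias = torch.full_like(strengths, finfo.min)
    positive = strengths > 0
    bias[positive] = torch.log(strengths[positive].clamp(min=finfo.tiny))
    assert bias[0, 0, 0] == 0.0 and bias[0, 1, 0] == finfo.min
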
@@ -652,11 +714,20 @@ class TransformerArgsPreprocessor:
     def prepare(
         self,
         modality: Modality,
+        cross_modality: Modality | None = None,  # noqa: ARG002
     ) -> TransformerArgs:
         x = self.patchify_proj(modality.latent)
-        timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], modality.latent.dtype)
-        context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
-        attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
+        batch_size = x.shape[0]
+        timestep, embedded_timestep = self._prepare_timestep(
+            modality.timesteps, self.adaln, batch_size, modality.latent.dtype
+        )
+        prompt_timestep = None
+        if self.prompt_adaln is not None:
+            prompt_timestep, _ = self._prepare_timestep(
+                modality.sigma, self.prompt_adaln, batch_size, modality.latent.dtype
+            )
+        context = self._prepare_context(modality.context, x)
+        attention_mask = self._prepare_attention_mask(modality.context_mask, modality.latent.dtype)
         pe = self._prepare_positional_embeddings(
             positions=modality.positions,
             inner_dim=self.inner_dim,
@@ -665,6 +736,7 @@ class TransformerArgsPreprocessor:
             num_attention_heads=self.num_attention_heads,
             x_dtype=modality.latent.dtype,
         )
+        self_attention_mask = self._prepare_self_attention_mask(modality.attention_mask, modality.latent.dtype)
         return TransformerArgs(
             x=x,
             context=context,
@@ -676,6 +748,8 @@ class TransformerArgsPreprocessor:
             cross_scale_shift_timestep=None,
             cross_gate_timestep=None,
             enabled=modality.enabled,
+            prompt_timestep=prompt_timestep,
+            self_attention_mask=self_attention_mask,
         )
@@ -684,7 +758,6 @@ class MultiModalTransformerArgsPreprocessor:
         self,
         patchify_proj: torch.nn.Linear,
         adaln: AdaLayerNormSingle,
-        caption_projection: PixArtAlphaTextProjection,
         cross_scale_shift_adaln: AdaLayerNormSingle,
         cross_gate_adaln: AdaLayerNormSingle,
         inner_dim: int,
@@ -698,11 +771,12 @@ class MultiModalTransformerArgsPreprocessor:
         positional_embedding_theta: float,
         rope_type: LTXRopeType,
         av_ca_timestep_scale_multiplier: int,
+        caption_projection: torch.nn.Module | None = None,
+        prompt_adaln: AdaLayerNormSingle | None = None,
     ) -> None:
         self.simple_preprocessor = TransformerArgsPreprocessor(
             patchify_proj=patchify_proj,
             adaln=adaln,
-            caption_projection=caption_projection,
             inner_dim=inner_dim,
             max_pos=max_pos,
             num_attention_heads=num_attention_heads,
@@ -711,6 +785,8 @@ class MultiModalTransformerArgsPreprocessor:
             double_precision_rope=double_precision_rope,
             positional_embedding_theta=positional_embedding_theta,
             rope_type=rope_type,
+            caption_projection=caption_projection,
+            prompt_adaln=prompt_adaln,
         )
         self.cross_scale_shift_adaln = cross_scale_shift_adaln
         self.cross_gate_adaln = cross_gate_adaln
@@ -721,8 +797,22 @@ class MultiModalTransformerArgsPreprocessor:
     def prepare(
         self,
         modality: Modality,
+        cross_modality: Modality | None = None,
     ) -> TransformerArgs:
         transformer_args = self.simple_preprocessor.prepare(modality)
+        if cross_modality is None:
+            return transformer_args
+
+        if cross_modality.sigma.numel() > 1:
+            if cross_modality.sigma.shape[0] != modality.timesteps.shape[0]:
+                raise ValueError("Cross modality sigma must have the same batch size as the modality")
+            if cross_modality.sigma.ndim != 1:
+                raise ValueError("Cross modality sigma must be a 1D tensor")
+
+        cross_timestep = cross_modality.sigma.view(
+            modality.timesteps.shape[0], 1, *[1] * len(modality.timesteps.shape[2:])
+        )
+
         cross_pe = self.simple_preprocessor._prepare_positional_embeddings(
             positions=modality.positions[:, 0:1, :],
             inner_dim=self.audio_cross_attention_dim,
@@ -733,7 +823,7 @@ class MultiModalTransformerArgsPreprocessor:
         )

         cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
-            timestep=modality.timesteps,
+            timestep=cross_timestep,
             timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
             batch_size=transformer_args.x.shape[0],
             hidden_dtype=modality.latent.dtype,
@@ -748,7 +838,7 @@ class MultiModalTransformerArgsPreprocessor:

     def _prepare_cross_attention_timestep(
         self,
-        timestep: torch.Tensor,
+        timestep: torch.Tensor | None,
         timestep_scale_multiplier: int,
         batch_size: int,
         hidden_dtype: torch.dtype,
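
# --- Editor's sketch (not part of the diff) ---
# Shape of the cross-modality timestep: a per-sample sigma (B,) from the
# other modality is reshaped to broadcast against this modality's tokens, so
# the partner's noise level conditions the A/V cross-attention AdaLN.
def _demo_cross_sigma() -> None:
    sigma = torch.tensor([0.2, 0.7])  # (B,), from the other modality
    timesteps = torch.zeros(2, 100)  # (B, T), this modality
    cross = sigma.view(timesteps.shape[0], 1, *[1] * len(timesteps.shape[2:]))
    assert cross.shape == (2, 1)
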
@@ -778,6 +868,8 @@ class TransformerConfig:
     heads: int
     d_head: int
     context_dim: int
+    apply_gated_attention: bool = False
+    cross_attention_adaln: bool = False


 class BasicAVTransformerBlock(torch.nn.Module):
@@ -800,6 +892,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
             context_dim=None,
             rope_type=rope_type,
             norm_eps=norm_eps,
+            apply_gated_attention=video.apply_gated_attention,
         )
         self.attn2 = Attention(
             query_dim=video.dim,
@@ -808,9 +901,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
             dim_head=video.d_head,
             rope_type=rope_type,
             norm_eps=norm_eps,
+            apply_gated_attention=video.apply_gated_attention,
         )
         self.ff = FeedForward(video.dim, dim_out=video.dim)
-        self.scale_shift_table = torch.nn.Parameter(torch.empty(6, video.dim))
+        video_sst_size = adaln_embedding_coefficient(video.cross_attention_adaln)
+        self.scale_shift_table = torch.nn.Parameter(torch.empty(video_sst_size, video.dim))

         if audio is not None:
             self.audio_attn1 = Attention(
@@ -820,6 +915,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
                 context_dim=None,
                 rope_type=rope_type,
                 norm_eps=norm_eps,
+                apply_gated_attention=audio.apply_gated_attention,
             )
             self.audio_attn2 = Attention(
                 query_dim=audio.dim,
@@ -828,9 +924,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
                 dim_head=audio.d_head,
                 rope_type=rope_type,
                 norm_eps=norm_eps,
+                apply_gated_attention=audio.apply_gated_attention,
             )
             self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
-            self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(6, audio.dim))
+            audio_sst_size = adaln_embedding_coefficient(audio.cross_attention_adaln)
+            self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(audio_sst_size, audio.dim))

         if audio is not None and video is not None:
             # Q: Video, K,V: Audio
@@ -841,6 +939,7 @@ class BasicAVTransformerBlock(torch.nn.Module):
                 dim_head=audio.d_head,
                 rope_type=rope_type,
                 norm_eps=norm_eps,
+                apply_gated_attention=video.apply_gated_attention,
             )

             # Q: Audio, K,V: Video
@@ -851,11 +950,21 @@ class BasicAVTransformerBlock(torch.nn.Module):
                 dim_head=audio.d_head,
                 rope_type=rope_type,
                 norm_eps=norm_eps,
+                apply_gated_attention=audio.apply_gated_attention,
             )

             self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim))
             self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim))

+        self.cross_attention_adaln = (video is not None and video.cross_attention_adaln) or (
+            audio is not None and audio.cross_attention_adaln
+        )
+
+        if self.cross_attention_adaln and video is not None:
+            self.prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, video.dim))
+        if self.cross_attention_adaln and audio is not None:
+            self.audio_prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, audio.dim))
+
         self.norm_eps = norm_eps

     def get_ada_values(
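
# Editor's note (not part of the diff): with cross_attention_adaln=True the
# per-block scale/shift table grows from 6 to 9 rows via
# adaln_embedding_coefficient, and each enabled modality gains a 2-row
# (scale, shift) prompt table:
#
#     torch.empty(adaln_embedding_coefficient(True), dim)  # (9, dim)
#     torch.empty(2, dim)                                  # prompt table
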
@@ -875,19 +984,49 @@ class BasicAVTransformerBlock(torch.nn.Module):
         batch_size: int,
         scale_shift_timestep: torch.Tensor,
         gate_timestep: torch.Tensor,
+        scale_shift_indices: slice,
         num_scale_shift_values: int = 4,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         scale_shift_ada_values = self.get_ada_values(
-            scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, slice(None, None)
+            scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, scale_shift_indices
         )
         gate_ada_values = self.get_ada_values(
             scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None)
         )

-        scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
-        gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
+        scale, shift = (t.squeeze(2) for t in scale_shift_ada_values)
+        (gate,) = (t.squeeze(2) for t in gate_ada_values)

-        return (*scale_shift_chunks, *gate_ada_values)
+        return scale, shift, gate
+
+    def _apply_text_cross_attention(
+        self,
+        x: torch.Tensor,
+        context: torch.Tensor,
+        attn: Attention,
+        scale_shift_table: torch.Tensor,
+        prompt_scale_shift_table: torch.Tensor | None,
+        timestep: torch.Tensor,
+        prompt_timestep: torch.Tensor | None,
+        context_mask: torch.Tensor | None,
+        cross_attention_adaln: bool = False,
+    ) -> torch.Tensor:
+        """Apply text cross-attention, with optional AdaLN modulation."""
+        if cross_attention_adaln:
+            shift_q, scale_q, gate = self.get_ada_values(scale_shift_table, x.shape[0], timestep, slice(6, 9))
+            return apply_cross_attention_adaln(
+                x,
+                context,
+                attn,
+                shift_q,
+                scale_q,
+                gate,
+                prompt_scale_shift_table,
+                prompt_timestep,
+                context_mask,
+                self.norm_eps,
+            )
+        return attn(rms_norm(x, eps=self.norm_eps), context=context, mask=context_mask)

     def forward(  # noqa: PLR0915
         self,
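
# Editor's note (not part of the diff): apply_cross_attention_adaln is defined
# elsewhere; judging from the arguments passed above it plausibly modulates the
# query stream with the timestep-conditioned (shift_q, scale_q) pair, runs the
# text cross-attention, and gates the residual, roughly:
#
#     x + gate * attn(rms_norm(x) * (1 + scale_q) + shift_q, context, mask)
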
@@ -895,7 +1034,11 @@ class BasicAVTransformerBlock(torch.nn.Module):
         audio: TransformerArgs | None,
         perturbations: BatchedPerturbationConfig | None = None,
     ) -> tuple[TransformerArgs | None, TransformerArgs | None]:
-        batch_size = video.x.shape[0]
+        if video is None and audio is None:
+            raise ValueError("At least one of video or audio must be provided")
+
+        batch_size = (video or audio).x.shape[0]
+
         if perturbations is None:
             perturbations = BatchedPerturbationConfig.empty(batch_size)

@@ -912,63 +1055,103 @@ class BasicAVTransformerBlock(torch.nn.Module):
             vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
                 self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
             )
-            if not perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx):
-                norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
-                v_mask = perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx)
-                vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa * v_mask
-
-            vx = vx + self.attn2(rms_norm(vx, eps=self.norm_eps), context=video.context, mask=video.context_mask)
-
-            del vshift_msa, vscale_msa, vgate_msa
+            norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
+            del vshift_msa, vscale_msa
+
+            all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)
+            none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)
+            v_mask = (
+                perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx)
+                if not all_perturbed and not none_perturbed
+                else None
+            )
+            vx = (
+                vx
+                + self.attn1(
+                    norm_vx,
+                    pe=video.positional_embeddings,
+                    mask=video.self_attention_mask,
+                    perturbation_mask=v_mask,
+                    all_perturbed=all_perturbed,
+                )
+                * vgate_msa
+            )
+            del vgate_msa, norm_vx, v_mask
+            vx = vx + self._apply_text_cross_attention(
+                vx,
+                video.context,
+                self.attn2,
+                self.scale_shift_table,
+                getattr(self, "prompt_scale_shift_table", None),
+                video.timesteps,
+                video.prompt_timestep,
+                video.context_mask,
+                cross_attention_adaln=self.cross_attention_adaln,
+            )

         if run_ax:
             ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
                 self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
             )

-            if not perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx):
-                norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
-                a_mask = perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax)
-                ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa * a_mask
-
-            ax = ax + self.audio_attn2(rms_norm(ax, eps=self.norm_eps), context=audio.context, mask=audio.context_mask)
-
-            del ashift_msa, ascale_msa, agate_msa
+            norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
+            del ashift_msa, ascale_msa
+
+            all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)
+            none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)
+            a_mask = (
+                perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax)
+                if not all_perturbed and not none_perturbed
+                else None
+            )
+            ax = (
+                ax
+                + self.audio_attn1(
+                    norm_ax,
+                    pe=audio.positional_embeddings,
+                    mask=audio.self_attention_mask,
+                    perturbation_mask=a_mask,
+                    all_perturbed=all_perturbed,
+                )
+                * agate_msa
+            )
+            del agate_msa, norm_ax, a_mask
+            ax = ax + self._apply_text_cross_attention(
+                ax,
+                audio.context,
+                self.audio_attn2,
+                self.audio_scale_shift_table,
+                getattr(self, "audio_prompt_scale_shift_table", None),
+                audio.timesteps,
+                audio.prompt_timestep,
+                audio.context_mask,
+                cross_attention_adaln=self.cross_attention_adaln,
+            )

         # Audio - Video cross attention.
         if run_a2v or run_v2a:
             vx_norm3 = rms_norm(vx, eps=self.norm_eps)
             ax_norm3 = rms_norm(ax, eps=self.norm_eps)

-            (
-                scale_ca_audio_hidden_states_a2v,
-                shift_ca_audio_hidden_states_a2v,
-                scale_ca_audio_hidden_states_v2a,
-                shift_ca_audio_hidden_states_v2a,
-                gate_out_v2a,
-            ) = self.get_av_ca_ada_values(
-                self.scale_shift_table_a2v_ca_audio,
-                ax.shape[0],
-                audio.cross_scale_shift_timestep,
-                audio.cross_gate_timestep,
-            )
-
-            (
-                scale_ca_video_hidden_states_a2v,
-                shift_ca_video_hidden_states_a2v,
-                scale_ca_video_hidden_states_v2a,
-                shift_ca_video_hidden_states_v2a,
-                gate_out_a2v,
-            ) = self.get_av_ca_ada_values(
+        if run_a2v and not perturbations.all_in_batch(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx):
+            scale_ca_video_a2v, shift_ca_video_a2v, gate_out_a2v = self.get_av_ca_ada_values(
                 self.scale_shift_table_a2v_ca_video,
                 vx.shape[0],
                 video.cross_scale_shift_timestep,
                 video.cross_gate_timestep,
+                slice(0, 2),
             )
+            vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v
+            del scale_ca_video_a2v, shift_ca_video_a2v

-            if run_a2v:
-                vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) + shift_ca_video_hidden_states_a2v
-                ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
+            scale_ca_audio_a2v, shift_ca_audio_a2v, _ = self.get_av_ca_ada_values(
+                self.scale_shift_table_a2v_ca_audio,
+                ax.shape[0],
+                audio.cross_scale_shift_timestep,
+                audio.cross_gate_timestep,
+                slice(0, 2),
+            )
+            ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v
+            del scale_ca_audio_a2v, shift_ca_audio_a2v
             a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx)
             vx = vx + (
                 self.audio_to_video_attn(
@@ -980,10 +1163,27 @@ class BasicAVTransformerBlock(torch.nn.Module):
                 * gate_out_a2v
                 * a2v_mask
             )
+            del gate_out_a2v, a2v_mask, vx_scaled, ax_scaled

-        if run_v2a:
-            ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
-            vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
+        if run_v2a and not perturbations.all_in_batch(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx):
+            scale_ca_audio_v2a, shift_ca_audio_v2a, gate_out_v2a = self.get_av_ca_ada_values(
|
self.scale_shift_table_a2v_ca_audio,
|
||||||
|
ax.shape[0],
|
||||||
|
audio.cross_scale_shift_timestep,
|
||||||
|
audio.cross_gate_timestep,
|
||||||
|
slice(2, 4),
|
||||||
|
)
|
||||||
|
ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a
|
||||||
|
del scale_ca_audio_v2a, shift_ca_audio_v2a
|
||||||
|
scale_ca_video_v2a, shift_ca_video_v2a, _ = self.get_av_ca_ada_values(
|
||||||
|
self.scale_shift_table_a2v_ca_video,
|
||||||
|
vx.shape[0],
|
||||||
|
video.cross_scale_shift_timestep,
|
||||||
|
video.cross_gate_timestep,
|
||||||
|
slice(2, 4),
|
||||||
|
)
|
||||||
|
vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a
|
||||||
|
del scale_ca_video_v2a, shift_ca_video_v2a
|
||||||
v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax)
|
v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax)
|
||||||
ax = ax + (
|
ax = ax + (
|
||||||
self.video_to_audio_attn(
|
self.video_to_audio_attn(
|
||||||
@@ -995,40 +1195,53 @@ class BasicAVTransformerBlock(torch.nn.Module):
|
|||||||
* gate_out_v2a
|
* gate_out_v2a
|
||||||
* v2a_mask
|
* v2a_mask
|
||||||
)
|
)
|
||||||
|
del gate_out_v2a, v2a_mask, ax_scaled, vx_scaled
|
||||||
|
|
||||||
del gate_out_a2v, gate_out_v2a
|
del vx_norm3, ax_norm3
|
||||||
del (
|
|
||||||
scale_ca_video_hidden_states_a2v,
|
|
||||||
shift_ca_video_hidden_states_a2v,
|
|
||||||
scale_ca_audio_hidden_states_a2v,
|
|
||||||
shift_ca_audio_hidden_states_a2v,
|
|
||||||
scale_ca_video_hidden_states_v2a,
|
|
||||||
shift_ca_video_hidden_states_v2a,
|
|
||||||
scale_ca_audio_hidden_states_v2a,
|
|
||||||
shift_ca_audio_hidden_states_v2a,
|
|
||||||
)
|
|
||||||
|
|
||||||
if run_vx:
|
if run_vx:
|
||||||
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
|
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
|
||||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None)
|
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6)
|
||||||
)
|
)
|
||||||
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
|
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
|
||||||
vx = vx + self.ff(vx_scaled) * vgate_mlp
|
vx = vx + self.ff(vx_scaled) * vgate_mlp
|
||||||
|
|
||||||
del vshift_mlp, vscale_mlp, vgate_mlp
|
del vshift_mlp, vscale_mlp, vgate_mlp, vx_scaled
|
||||||
|
|
||||||
if run_ax:
|
if run_ax:
|
||||||
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
|
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
|
||||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None)
|
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6)
|
||||||
)
|
)
|
||||||
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
|
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
|
||||||
ax = ax + self.audio_ff(ax_scaled) * agate_mlp
|
ax = ax + self.audio_ff(ax_scaled) * agate_mlp
|
||||||
|
|
||||||
del ashift_mlp, ascale_mlp, agate_mlp
|
del ashift_mlp, ascale_mlp, agate_mlp, ax_scaled
|
||||||
|
|
||||||
return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None
|
return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None
|
||||||
|
|
||||||
|
|
||||||
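The rewritten self-attention branches drop the old all-or-nothing `all_in_batch` early exit: the perturbation mask is only materialized when a batch mixes perturbed and unperturbed samples, and `attn1`/`audio_attn1` receive `all_perturbed` so they can skip the branch entirely. A minimal runnable sketch of the per-sample gating idea (shapes and the `skip` vector are invented for illustration):

```python
import torch

# Batch of 2 samples, 4 tokens, hidden size 8 (made-up sizes).
x = torch.randn(2, 4, 8)
attn_out = torch.randn(2, 4, 8)

# mask_like-style gate: 0.0 for samples whose self-attention is skipped.
skip = torch.tensor([True, False])
mask = (~skip).float().view(-1, 1, 1)  # broadcast over tokens and channels

x = x + attn_out * mask  # the residual update reaches only the unperturbed sample
```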
def apply_cross_attention_adaln(
    x: torch.Tensor,
    context: torch.Tensor,
    attn: Attention,
    q_shift: torch.Tensor,
    q_scale: torch.Tensor,
    q_gate: torch.Tensor,
    prompt_scale_shift_table: torch.Tensor,
    prompt_timestep: torch.Tensor,
    context_mask: torch.Tensor | None = None,
    norm_eps: float = 1e-6,
) -> torch.Tensor:
    batch_size = x.shape[0]
    shift_kv, scale_kv = (
        prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
        + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
    ).unbind(dim=2)
    attn_input = rms_norm(x, eps=norm_eps) * (1 + q_scale) + q_shift
    encoder_hidden_states = context * (1 + scale_kv) + shift_kv
    return attn(attn_input, context=encoder_hidden_states, mask=context_mask) * q_gate

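The new module-level helper modulates both the query stream (via `q_shift`/`q_scale`/`q_gate`) and the text context (via the prompt table plus a per-token timestep embedding). A small self-contained sketch of the shift/scale unbinding, with made-up sizes:

```python
import torch

batch, tokens, dim = 2, 10, 16
prompt_scale_shift_table = torch.zeros(2, dim)         # learned rows: shift, scale
prompt_timestep = torch.randn(batch, tokens, 2 * dim)  # per-token timestep embedding

shift_kv, scale_kv = (
    prompt_scale_shift_table[None, None]
    + prompt_timestep.reshape(batch, tokens, 2, -1)
).unbind(dim=2)

context = torch.randn(batch, tokens, dim)
modulated = context * (1 + scale_kv) + shift_kv
assert modulated.shape == (batch, tokens, dim)
```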
class GELUApprox(torch.nn.Module):
    def __init__(self, dim_in: int, dim_out: int) -> None:
        super().__init__()
@@ -1067,6 +1280,7 @@ class LTXModel(torch.nn.Module):
    LTX model transformer implementation.
    This class implements the transformer blocks for the LTX model.
    """

    _repeated_blocks = ["BasicAVTransformerBlock"]

    def __init__(  # noqa: PLR0913
        self,
@@ -1093,6 +1307,8 @@ class LTXModel(torch.nn.Module):
        av_ca_timestep_scale_multiplier: int = 1000,
        rope_type: LTXRopeType = LTXRopeType.SPLIT,
        double_precision_rope: bool = True,
        apply_gated_attention: bool = False,
        cross_attention_adaln: bool = False,
    ):
        super().__init__()
        self._enable_gradient_checkpointing = False
@@ -1102,6 +1318,7 @@ class LTXModel(torch.nn.Module):
        self.timestep_scale_multiplier = timestep_scale_multiplier
        self.positional_embedding_theta = positional_embedding_theta
        self.model_type = model_type
        self.cross_attention_adaln = cross_attention_adaln
        cross_pe_max_pos = None
        if model_type.is_video_enabled():
            if positional_embedding_max_pos is None:
@@ -1144,8 +1361,13 @@ class LTXModel(torch.nn.Module):
            audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0,
            audio_cross_attention_dim=audio_cross_attention_dim,
            norm_eps=norm_eps,
            apply_gated_attention=apply_gated_attention,
        )

    @property
    def _adaln_embedding_coefficient(self) -> int:
        return adaln_embedding_coefficient(self.cross_attention_adaln)

    def _init_video(
        self,
        in_channels: int,
@@ -1156,10 +1378,11 @@ class LTXModel(torch.nn.Module):
        """Initialize video-specific components."""
        # Video input components
        self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True)
        self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=self._adaln_embedding_coefficient)
        self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None

        # Video caption projection
        if caption_channels is not None:
            self.caption_projection = PixArtAlphaTextProjection(
                in_features=caption_channels,
                hidden_size=self.inner_dim,
@@ -1182,11 +1405,11 @@ class LTXModel(torch.nn.Module):
        # Audio input components
        self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True)

        self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=self._adaln_embedding_coefficient)
        self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None

        # Audio caption projection
        if caption_channels is not None:
            self.audio_caption_projection = PixArtAlphaTextProjection(
                in_features=caption_channels,
                hidden_size=self.audio_inner_dim,
@@ -1232,7 +1455,6 @@ class LTXModel(torch.nn.Module):
            self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
                patchify_proj=self.patchify_proj,
                adaln=self.adaln_single,
                cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
                cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
                inner_dim=self.inner_dim,
@@ -1246,11 +1468,12 @@ class LTXModel(torch.nn.Module):
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
                av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
                caption_projection=getattr(self, "caption_projection", None),
                prompt_adaln=getattr(self, "prompt_adaln_single", None),
            )
            self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
                patchify_proj=self.audio_patchify_proj,
                adaln=self.audio_adaln_single,
                cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
                cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
                inner_dim=self.audio_inner_dim,
@@ -1264,12 +1487,13 @@ class LTXModel(torch.nn.Module):
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
                av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
                caption_projection=getattr(self, "audio_caption_projection", None),
                prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
            )
        elif self.model_type.is_video_enabled():
            self.video_args_preprocessor = TransformerArgsPreprocessor(
                patchify_proj=self.patchify_proj,
                adaln=self.adaln_single,
                inner_dim=self.inner_dim,
                max_pos=self.positional_embedding_max_pos,
                num_attention_heads=self.num_attention_heads,
@@ -1278,12 +1502,13 @@ class LTXModel(torch.nn.Module):
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
                caption_projection=getattr(self, "caption_projection", None),
                prompt_adaln=getattr(self, "prompt_adaln_single", None),
            )
        elif self.model_type.is_audio_enabled():
            self.audio_args_preprocessor = TransformerArgsPreprocessor(
                patchify_proj=self.audio_patchify_proj,
                adaln=self.audio_adaln_single,
                inner_dim=self.audio_inner_dim,
                max_pos=self.audio_positional_embedding_max_pos,
                num_attention_heads=self.audio_num_attention_heads,
@@ -1292,6 +1517,8 @@ class LTXModel(torch.nn.Module):
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
                caption_projection=getattr(self, "audio_caption_projection", None),
                prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
            )

    def _init_transformer_blocks(
@@ -1302,6 +1529,7 @@ class LTXModel(torch.nn.Module):
        audio_attention_head_dim: int,
        audio_cross_attention_dim: int,
        norm_eps: float,
        apply_gated_attention: bool,
    ) -> None:
        """Initialize transformer blocks for LTX."""
        video_config = (
@@ -1310,6 +1538,8 @@ class LTXModel(torch.nn.Module):
                heads=self.num_attention_heads,
                d_head=attention_head_dim,
                context_dim=cross_attention_dim,
                apply_gated_attention=apply_gated_attention,
                cross_attention_adaln=self.cross_attention_adaln,
            )
            if self.model_type.is_video_enabled()
            else None
@@ -1320,6 +1550,8 @@ class LTXModel(torch.nn.Module):
                heads=self.audio_num_attention_heads,
                d_head=audio_attention_head_dim,
                context_dim=audio_cross_attention_dim,
                apply_gated_attention=apply_gated_attention,
                cross_attention_adaln=self.cross_attention_adaln,
            )
            if self.model_type.is_audio_enabled()
            else None
@@ -1352,24 +1584,17 @@ class LTXModel(torch.nn.Module):
        video: TransformerArgs | None,
        audio: TransformerArgs | None,
        perturbations: BatchedPerturbationConfig,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
    ) -> tuple[TransformerArgs, TransformerArgs]:
        """Process transformer blocks for LTXAV."""

        # Process transformer blocks
        for block in self.transformer_blocks:
            video, audio = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing,
                use_gradient_checkpointing_offload,
                video=video,
                audio=audio,
                perturbations=perturbations,
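`gradient_checkpoint_forward` replaces the inline `torch.utils.checkpoint.checkpoint` call; its definition is not part of this hunk, so the following is only a plausible sketch of such a wrapper (the offload behaviour and exact signature are assumptions, not confirmed by this diff):

```python
import torch

def gradient_checkpoint_forward(block, use_gradient_checkpointing,
                                use_gradient_checkpointing_offload, *args, **kwargs):
    # Plain call when checkpointing is disabled or autograd is off.
    if not (use_gradient_checkpointing and torch.is_grad_enabled()):
        return block(*args, **kwargs)
    if use_gradient_checkpointing_offload:
        # Hypothetical offload path: stage saved activations in CPU memory.
        with torch.autograd.graph.save_on_cpu():
            return torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False, **kwargs)
    return torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False, **kwargs)
```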
@@ -1398,7 +1623,12 @@ class LTXModel(torch.nn.Module):
        return x

    def _forward(
        self,
        video: Modality | None,
        audio: Modality | None,
        perturbations: BatchedPerturbationConfig,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for LTX models.
@@ -1410,13 +1640,15 @@ class LTXModel(torch.nn.Module):
        if not self.model_type.is_audio_enabled() and audio is not None:
            raise ValueError("Audio is not enabled for this model")

        video_args = self.video_args_preprocessor.prepare(video, audio) if video is not None else None
        audio_args = self.audio_args_preprocessor.prepare(audio, video) if audio is not None else None
        # Process transformer blocks
        video_out, audio_out = self._process_transformer_blocks(
            video=video_args,
            audio=audio_args,
            perturbations=perturbations,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
        )

        # Process output
@@ -1440,12 +1672,12 @@ class LTXModel(torch.nn.Module):
        )
        return vx, ax

    def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps, sigma, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
        cross_pe_max_pos = None
        if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():
            cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
        self._init_preprocessors(cross_pe_max_pos)
        video = Modality(video_latents, sigma, video_timesteps, video_positions, video_context)
        audio = Modality(audio_latents, sigma, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None
        vx, ax = self._forward(video=video, audio=audio, perturbations=None, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload)
        return vx, ax
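`Modality` now carries the shared `sigma` as its second field, and the audio modality is optional. A tiny runnable illustration with a simplified stand-in dataclass (the real `Modality` in this repo carries more fields):

```python
import torch
from dataclasses import dataclass

@dataclass
class Modality:  # simplified stand-in; the real dataclass has more fields
    x: torch.Tensor
    sigma: torch.Tensor

sigma = torch.full((2,), 0.7)                    # one noise level per batch element
video = Modality(torch.randn(2, 16, 32), sigma)
audio_latents = None
audio = Modality(audio_latents, sigma) if audio_latents is not None else None
assert audio is None                             # video-only batches skip the audio branch
```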
@@ -1,4 +1,7 @@
import math
import torch
import torch.nn as nn
from einops import rearrange
from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer
from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention,
                       FeedForward)
@@ -147,14 +150,14 @@ class LTXVGemmaTokenizer:
        return out


class GemmaFeaturesExtractorProjLinear(nn.Module):
    """
    Feature extractor module for Gemma models.
    This module applies a single linear projection to the input tensor.
    It expects a flattened feature tensor of shape (batch_size, 3840*49).
    The linear layer maps this to a (batch_size, 3840) embedding.
    Attributes:
        aggregate_embed (nn.Linear): Linear projection layer.
    """

    def __init__(self) -> None:
@@ -163,26 +166,65 @@ class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
        The input dimension is expected to be 3840 * 49, and the output is 3840.
        """
        super().__init__()
        self.aggregate_embed = nn.Linear(3840 * 49, 3840, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        padding_side: str = "left",
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        encoded = torch.stack(hidden_states, dim=-1) if isinstance(hidden_states, (list, tuple)) else hidden_states
        dtype = encoded.dtype
        sequence_lengths = attention_mask.sum(dim=-1)
        normed = _norm_and_concat_padded_batch(encoded, sequence_lengths, padding_side)
        features = self.aggregate_embed(normed.to(dtype))
        return features, features

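The projection now consumes the raw per-layer Gemma hidden states instead of a pre-flattened tensor. The shape bookkeeping, as a runnable check (layer count and widths match the defaults above):

```python
import torch

batch, tokens, hidden, layers = 2, 5, 3840, 49
hidden_states = [torch.randn(batch, tokens, hidden) for _ in range(layers)]

# torch.stack(..., dim=-1) arranges per-layer states as [B, T, D, L], the
# layout _norm_and_concat_padded_batch expects.
encoded = torch.stack(hidden_states, dim=-1)
assert encoded.shape == (batch, tokens, hidden, layers)
# Normalization then flattens the layer axis to [B, T, D * L] = [2, 5, 188160],
# matching the nn.Linear(3840 * 49, 3840) projection above.
```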
class GemmaSeperatedFeaturesExtractorProjLinear(nn.Module):
    """22B: per-token RMS norm → rescale → dual aggregate embeds"""

    def __init__(
        self,
        num_layers: int,
        embedding_dim: int,
        video_inner_dim: int,
        audio_inner_dim: int,
    ):
        super().__init__()
        in_dim = embedding_dim * num_layers
        self.video_aggregate_embed = torch.nn.Linear(in_dim, video_inner_dim, bias=True)
        self.audio_aggregate_embed = torch.nn.Linear(in_dim, audio_inner_dim, bias=True)
        self.embedding_dim = embedding_dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        padding_side: str = "left",  # noqa: ARG002
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        encoded = torch.stack(hidden_states, dim=-1) if isinstance(hidden_states, (list, tuple)) else hidden_states
        normed = norm_and_concat_per_token_rms(encoded, attention_mask)
        normed = normed.to(encoded.dtype)
        v_dim = self.video_aggregate_embed.out_features
        video = self.video_aggregate_embed(_rescale_norm(normed, v_dim, self.embedding_dim))
        audio = None
        if self.audio_aggregate_embed is not None:
            a_dim = self.audio_aggregate_embed.out_features
            audio = self.audio_aggregate_embed(_rescale_norm(normed, a_dim, self.embedding_dim))
        return video, audio

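A hypothetical instantiation of the separated extractor, shrunk to toy sizes so it runs quickly; it relies on `norm_and_concat_per_token_rms` and `_rescale_norm`, which are defined later in this file:

```python
import torch

# Toy sizes for illustration; the real defaults are 49 layers x 3840 dims.
proj = GemmaSeperatedFeaturesExtractorProjLinear(
    num_layers=3, embedding_dim=8, video_inner_dim=16, audio_inner_dim=4
)
hidden = [torch.randn(1, 7, 8) for _ in range(3)]
mask = torch.ones(1, 7, dtype=torch.long)
video_emb, audio_emb = proj(hidden, mask)
assert video_emb.shape == (1, 7, 16) and audio_emb.shape == (1, 7, 4)
```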
class _BasicTransformerBlock1D(nn.Module):
    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
        apply_gated_attention: bool = False,
    ):
        super().__init__()

@@ -191,6 +233,7 @@ class _BasicTransformerBlock1D(torch.nn.Module):
            heads=heads,
            dim_head=dim_head,
            rope_type=rope_type,
            apply_gated_attention=apply_gated_attention,
        )

        self.ff = FeedForward(
@@ -231,7 +274,7 @@ class _BasicTransformerBlock1D(torch.nn.Module):
        return hidden_states


class Embeddings1DConnector(nn.Module):
    """
    Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
    other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can
@@ -263,6 +306,7 @@ class Embeddings1DConnector(torch.nn.Module):
        num_learnable_registers: int | None = 128,
        rope_type: LTXRopeType = LTXRopeType.SPLIT,
        double_precision_rope: bool = True,
        apply_gated_attention: bool = False,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
@@ -274,13 +318,14 @@ class Embeddings1DConnector(torch.nn.Module):
        )
        self.rope_type = rope_type
        self.double_precision_rope = double_precision_rope
        self.transformer_1d_blocks = nn.ModuleList(
            [
                _BasicTransformerBlock1D(
                    dim=self.inner_dim,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    rope_type=rope_type,
                    apply_gated_attention=apply_gated_attention,
                )
                for _ in range(num_layers)
            ]
@@ -288,7 +333,7 @@ class Embeddings1DConnector(torch.nn.Module):

        self.num_learnable_registers = num_learnable_registers
        if self.num_learnable_registers:
            self.learnable_registers = nn.Parameter(
                torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
            )

@@ -307,7 +352,7 @@ class Embeddings1DConnector(torch.nn.Module):
        non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
        non_zero_nums = non_zero_hidden_states.shape[1]
        pad_length = hidden_states.shape[1] - non_zero_nums
        adjusted_hidden_states = nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
        flipped_mask = torch.flip(attention_mask_binary, dims=[1])
        hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers

@@ -358,9 +403,147 @@ class Embeddings1DConnector(torch.nn.Module):
        return hidden_states, attention_mask


class LTX2TextEncoderPostModules(nn.Module):
    def __init__(
        self,
        separated_audio_video: bool = False,
        embedding_dim_gemma: int = 3840,
        num_layers_gemma: int = 49,
        video_attention_heads: int = 32,
        video_attention_head_dim: int = 128,
        audio_attention_heads: int = 32,
        audio_attention_head_dim: int = 64,
        num_connector_layers: int = 2,
        apply_gated_attention: bool = False,
    ):
        super().__init__()
        if not separated_audio_video:
            self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
            self.embeddings_connector = Embeddings1DConnector()
            self.audio_embeddings_connector = Embeddings1DConnector()
        else:
            # LTX-2.3
            self.feature_extractor_linear = GemmaSeperatedFeaturesExtractorProjLinear(
                num_layers_gemma, embedding_dim_gemma, video_attention_heads * video_attention_head_dim,
                audio_attention_heads * audio_attention_head_dim)
            self.embeddings_connector = Embeddings1DConnector(
                attention_head_dim=video_attention_head_dim,
                num_attention_heads=video_attention_heads,
                num_layers=num_connector_layers,
                apply_gated_attention=apply_gated_attention,
            )
            self.audio_embeddings_connector = Embeddings1DConnector(
                attention_head_dim=audio_attention_head_dim,
                num_attention_heads=audio_attention_heads,
                num_layers=num_connector_layers,
                apply_gated_attention=apply_gated_attention,
            )

    def create_embeddings(
        self,
        video_features: torch.Tensor,
        audio_features: torch.Tensor | None,
        additive_attention_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
        video_encoded, video_mask = self.embeddings_connector(video_features, additive_attention_mask)
        video_encoded, binary_mask = _to_binary_mask(video_encoded, video_mask)
        audio_encoded, _ = self.audio_embeddings_connector(audio_features, additive_attention_mask)

        return video_encoded, audio_encoded, binary_mask

    def process_hidden_states(
        self,
        hidden_states: tuple[torch.Tensor, ...],
        attention_mask: torch.Tensor,
        padding_side: str = "left",
    ):
        video_feats, audio_feats = self.feature_extractor_linear(hidden_states, attention_mask, padding_side)
        additive_mask = _convert_to_additive_mask(attention_mask, video_feats.dtype)
        video_enc, audio_enc, binary_mask = self.create_embeddings(video_feats, audio_feats, additive_mask)
        return video_enc, audio_enc, binary_mask

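End to end, the post-modules stack the Gemma layer outputs, normalize them, project to video/audio widths, and run each stream through its 1D connector. A toy-sized sketch of the LTX-2.3 path (hyper-parameters shrunk for speed; assumes the default connector settings remain valid at these sizes):

```python
import torch

# Toy hyper-parameters; the LTX-2.3 defaults are 49 layers x 3840 dims.
post = LTX2TextEncoderPostModules(
    separated_audio_video=True,
    embedding_dim_gemma=8, num_layers_gemma=3,
    video_attention_heads=2, video_attention_head_dim=4,
    audio_attention_heads=2, audio_attention_head_dim=2,
)
hidden_states = tuple(torch.randn(1, 7, 8) for _ in range(3))
attention_mask = torch.ones(1, 7, dtype=torch.long)
video_enc, audio_enc, binary_mask = post.process_hidden_states(hidden_states, attention_mask)
```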
def _norm_and_concat_padded_batch(
    encoded_text: torch.Tensor,
    sequence_lengths: torch.Tensor,
    padding_side: str = "right",
) -> torch.Tensor:
    """Normalize and flatten multi-layer hidden states, respecting padding.

    Performs per-batch, per-layer normalization using masked mean and range,
    then concatenates across the layer dimension.

    Args:
        encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
        sequence_lengths: Number of valid (non-padded) tokens per batch item.
        padding_side: Whether padding is on "left" or "right".

    Returns:
        Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
        with padded positions zeroed out.
    """
    b, t, d, l = encoded_text.shape  # noqa: E741
    device = encoded_text.device
    # Build mask: [B, T, 1, 1]
    token_indices = torch.arange(t, device=device)[None, :]  # [1, T]
    if padding_side == "right":
        # For right padding, valid tokens are from 0 to sequence_length-1
        mask = token_indices < sequence_lengths[:, None]  # [B, T]
    elif padding_side == "left":
        # For left padding, valid tokens are from (T - sequence_length) to T-1
        start_indices = t - sequence_lengths[:, None]  # [B, 1]
        mask = token_indices >= start_indices  # [B, T]
    else:
        raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
    mask = rearrange(mask, "b t -> b t 1 1")
    eps = 1e-6
    # Compute masked mean: [B, 1, 1, L]
    masked = encoded_text.masked_fill(~mask, 0.0)
    denom = (sequence_lengths * d).view(b, 1, 1, 1)
    mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
    # Compute masked min/max: [B, 1, 1, L]
    x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
    x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
    range_ = x_max - x_min
    # Normalize only the valid tokens
    normed = 8 * (encoded_text - mean) / (range_ + eps)
    # concat to be [Batch, T, D * L] - this preserves the original structure
    normed = normed.reshape(b, t, -1)  # [B, T, D * L]
    # Apply mask to preserve original padding (set padded positions to 0)
    mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
    normed = normed.masked_fill(~mask_flattened, 0.0)

    return normed

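A worked example of the masked mean/range normalization: padded positions never contribute to the statistics and are zeroed in the output.

```python
import torch

# Batch of 1, three tokens with the last right-padded, hidden dim 2, one layer.
enc = torch.tensor([[[[0.0], [4.0]], [[8.0], [4.0]], [[99.0], [99.0]]]])  # [1, 3, 2, 1]
lengths = torch.tensor([2])
out = _norm_and_concat_padded_batch(enc, lengths, padding_side="right")
assert out.shape == (1, 3, 2)           # [B, T, D * L]
assert torch.all(out[0, 2] == 0.0)      # the 99s never leak into the output
# Valid tokens: centred by the masked mean (4.0), scaled by 8 / range (8.0).
assert torch.allclose(out[0, :2], torch.tensor([[-4.0, 0.0], [4.0, 0.0]]), atol=1e-4)
```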
def _convert_to_additive_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    return (attention_mask - 1).to(dtype).reshape(
        (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max


def _to_binary_mask(encoded: torch.Tensor, encoded_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert connector output mask to binary mask and apply to encoded tensor."""
    binary_mask = (encoded_mask < 0.000001).to(torch.int64)
    binary_mask = binary_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
    encoded = encoded * binary_mask
    return encoded, binary_mask


def norm_and_concat_per_token_rms(
    encoded_text: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Per-token RMSNorm normalization for V2 models.

    Args:
        encoded_text: [B, T, D, L]
        attention_mask: [B, T] binary mask

    Returns:
        [B, T, D*L] normalized tensor with padding zeroed out.
    """
    B, T, D, L = encoded_text.shape  # noqa: N806
    variance = torch.mean(encoded_text**2, dim=2, keepdim=True)  # [B,T,1,L]
    normed = encoded_text * torch.rsqrt(variance + 1e-6)
    normed = normed.reshape(B, T, D * L)
    mask_3d = attention_mask.bool().unsqueeze(-1)  # [B, T, 1]
    return torch.where(mask_3d, normed, torch.zeros_like(normed))


def _rescale_norm(x: torch.Tensor, target_dim: int, source_dim: int) -> torch.Tensor:
    """Rescale normalization: x * sqrt(target_dim / source_dim)."""
    return x * math.sqrt(target_dim / source_dim)
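The per-token RMS variant used by the separated (LTX-2.3) path normalizes along the hidden dimension independently for every token and layer; `_rescale_norm` then compensates for the width of the target projection. A quick shape check with toy sizes:

```python
import torch

enc = torch.randn(2, 7, 16, 3)                    # [B, T, D, L], toy sizes
mask = torch.ones(2, 7, dtype=torch.long)
normed = norm_and_concat_per_token_rms(enc, mask)
assert normed.shape == (2, 7, 16 * 3)
# Each token now has unit RMS per layer along D; _rescale_norm(normed,
# target_dim, source_dim) sets the scale expected by the aggregate projections.
```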
@@ -555,9 +555,6 @@ class PerChannelStatistics(nn.Module):
        super().__init__()
        self.register_buffer("std-of-means", torch.empty(latent_channels))
        self.register_buffer("mean-of-means", torch.empty(latent_channels))

    def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(
@@ -1335,27 +1332,34 @@ class LTX2VideoEncoder(nn.Module):
        norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
        latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
        encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
        encoder_version: str = "ltx-2",
    ):
        super().__init__()
        if encoder_version == "ltx-2":
            encoder_blocks = [
                ['res_x', {'num_layers': 4}],
                ['compress_space_res', {'multiplier': 2}],
                ['res_x', {'num_layers': 6}],
                ['compress_time_res', {'multiplier': 2}],
                ['res_x', {'num_layers': 6}],
                ['compress_all_res', {'multiplier': 2}],
                ['res_x', {'num_layers': 2}],
                ['compress_all_res', {'multiplier': 2}],
                ['res_x', {'num_layers': 2}]
            ]
        else:
            # LTX-2.3
            encoder_blocks = [
                ["res_x", {"num_layers": 4}],
                ["compress_space_res", {"multiplier": 2}],
                ["res_x", {"num_layers": 6}],
                ["compress_time_res", {"multiplier": 2}],
                ["res_x", {"num_layers": 4}],
                ["compress_all_res", {"multiplier": 2}],
                ["res_x", {"num_layers": 2}],
                ["compress_all_res", {"multiplier": 1}],
                ["res_x", {"num_layers": 2}]
            ]
        self.patch_size = patch_size
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
@@ -1435,8 +1439,8 @@ class LTX2VideoEncoder(nn.Module):
        # Validate frame count
        frames_count = sample.shape[2]
        if ((frames_count - 1) % 8) != 0:
            frames_to_crop = (frames_count - 1) % 8
            sample = sample[:, :, :-frames_to_crop, ...]

        # Initial spatial compression: trade spatial resolution for channel depth
        # This reduces H,W by patch_size and increases channels, making convolutions more efficient
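Instead of rejecting inputs whose frame count is not of the form 1 + 8k, the encoder now crops the excess trailing frames. The arithmetic:

```python
# The encoder expects 1 + 8 * k frames; excess trailing frames are now cropped.
frames_count = 22
frames_to_crop = (frames_count - 1) % 8   # 21 % 8 = 5
kept = frames_count - frames_to_crop      # 17 = 1 + 8 * 2
assert (kept - 1) % 8 == 0
```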
@@ -1712,17 +1716,21 @@ def _make_decoder_block(
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_time":
        out_channels = in_channels // block_config.get("multiplier", 1)
        block = DepthToSpaceUpsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            stride=(2, 1, 1),
            out_channels_reduction_factor=block_config.get("multiplier", 1),
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_space":
        out_channels = in_channels // block_config.get("multiplier", 1)
        block = DepthToSpaceUpsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            stride=(1, 2, 2),
            out_channels_reduction_factor=block_config.get("multiplier", 1),
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_all":
@@ -1782,6 +1790,8 @@ class LTX2VideoDecoder(nn.Module):
        causal: bool = False,
        timestep_conditioning: bool = False,
        decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
        decoder_version: str = "ltx-2",
        base_channels: int = 128,
    ):
        super().__init__()

@@ -1790,28 +1800,29 @@ class LTX2VideoDecoder(nn.Module):
        # video inputs by a factor of 8 in the temporal dimension and 32 in
        # each spatial dimension (height and width). This parameter determines how
        # many video frames and pixels correspond to a single latent cell.
        if decoder_version == "ltx-2":
            decoder_blocks = [
                ['res_x', {'num_layers': 5, 'inject_noise': False}],
                ['compress_all', {'residual': True, 'multiplier': 2}],
                ['res_x', {'num_layers': 5, 'inject_noise': False}],
                ['compress_all', {'residual': True, 'multiplier': 2}],
                ['res_x', {'num_layers': 5, 'inject_noise': False}],
                ['compress_all', {'residual': True, 'multiplier': 2}],
                ['res_x', {'num_layers': 5, 'inject_noise': False}]
            ]
        else:
            # LTX-2.3
            decoder_blocks = [
                ["res_x", {"num_layers": 4}],
                ["compress_space", {"multiplier": 2}],
                ["res_x", {"num_layers": 6}],
                ["compress_time", {"multiplier": 2}],
                ["res_x", {"num_layers": 4}],
                ["compress_all", {"multiplier": 1}],
                ["res_x", {"num_layers": 2}],
                ["compress_all", {"multiplier": 2}],
                ["res_x", {"num_layers": 2}]
            ]
        self.video_downscale_factors = SpatioTemporalScaleFactors(
            time=8,
            width=32,
@@ -1831,15 +1842,9 @@ class LTX2VideoDecoder(nn.Module):
        self.decode_noise_scale = 0.025
        self.decode_timestep = 0.05

        # LTX VAE decoder architecture uses 3 upsampler blocks with multiplier equal to 2.
        # Hence the total feature_channels is multiplied by 8 (2^3).
        feature_channels = base_channels * 8

        self.conv_in = make_conv_nd(
            dims=convolution_dimensions,
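The reverse walk over `decoder_blocks` is gone: the starting width is now fixed from `base_channels`. With the default value this reproduces the old result for the ltx-2 layout:

```python
# Three multiplier-2 upsampler blocks halve the width three times, so the
# decoder must start at base_channels * 2**3.
base_channels = 128
feature_channels = base_channels * 8      # 1024
assert feature_channels // 2 // 2 // 2 == base_channels
```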
@@ -1,6 +1,6 @@
from ..core.loader import load_model, hash_model_file
from ..core.vram import AutoWrappedModule
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS
import importlib, json, torch


@@ -22,7 +22,8 @@ class ModelPool:
    def fetch_module_map(self, model_class, vram_config):
        if self.need_to_enable_vram_management(vram_config):
            if model_class in VRAM_MANAGEMENT_MODULE_MAPS:
                vram_module_map = VRAM_MANAGEMENT_MODULE_MAPS[model_class] if model_class not in VERSION_CHECKER_MAPS else VERSION_CHECKER_MAPS[model_class]()
                module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in vram_module_map.items()}
            else:
                module_map = {self.import_model_class(model_class): AutoWrappedModule}
        else:
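`VERSION_CHECKER_MAPS` lets a model class resolve its VRAM module map at runtime instead of statically. The registry's actual contents live in `..configs` and are not shown in this hunk; the following is only a hypothetical sketch of the expected callable shape (all names in the map body are placeholders):

```python
# Hypothetical sketch only: a version checker is a zero-argument callable that
# returns the module map appropriate for the detected model variant.
def ltx2_version_checker():
    # e.g. inspect loaded weights/config and pick the matching map
    return {"some.module.Class": "diffsynth.core.vram.AutoWrappedModule"}

VERSION_CHECKER_MAPS = {"LTXModel": ltx2_version_checker}  # assumed registry shape
```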
57 diffsynth/models/mova_audio_dit.py Normal file
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
from .wan_video_dit import WanModel, precompute_freqs_cis, sinusoidal_embedding_1d
from einops import rearrange
from ..core import gradient_checkpoint_forward


def precompute_freqs_cis_1d(dim: int, end: int = 16384, theta: float = 10000.0):
    f_freqs_cis = precompute_freqs_cis(dim, end, theta)
    return f_freqs_cis.chunk(3, dim=-1)


class MovaAudioDit(WanModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        head_dim = kwargs.get("dim", 1536) // kwargs.get("num_heads", 12)
        self.freqs = precompute_freqs_cis_1d(head_dim)
        self.patch_embedding = nn.Conv1d(
            kwargs.get("in_dim", 128), kwargs.get("dim", 1536), kernel_size=[1], stride=[1]
        )

    def precompute_freqs_cis(self, dim: int, end: int = 16384, theta: float = 10000.0):
        self.f_freqs_cis = precompute_freqs_cis_1d(dim, end, theta)

    def forward(self,
                x: torch.Tensor,
                timestep: torch.Tensor,
                context: torch.Tensor,
                use_gradient_checkpointing: bool = False,
                use_gradient_checkpointing_offload: bool = False,
                **kwargs,
                ):
        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        context = self.text_embedding(context)
        x, (f, ) = self.patchify(x)
        freqs = torch.cat([
            self.freqs[0][:f].view(f, -1).expand(f, -1),
            self.freqs[1][:f].view(f, -1).expand(f, -1),
            self.freqs[2][:f].view(f, -1).expand(f, -1),
        ], dim=-1).reshape(f, 1, -1).to(x.device)

        for block in self.blocks:
            x = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing,
                use_gradient_checkpointing_offload,
                x, context, t_mod, freqs,
            )
        x = self.head(x, t)
        x = self.unpatchify(x, (f, ))
        return x

    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
        return rearrange(
            x, 'b f (p c) -> b c (f p)',
            f=grid_size[0],
            p=self.patch_size[0]
        )
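With a patch size of 1, patchify/unpatchify in this audio DiT reduce to a transpose-style reshape between `[B, C, F]` latents and `[B, F, C]` tokens. A runnable round-trip check of the einops pattern used above:

```python
import torch
from einops import rearrange

x = torch.randn(2, 128, 50)                          # [B, C, F] audio latents
tokens = rearrange(x, "b c (f p) -> b f (p c)", p=1)
assert tokens.shape == (2, 50, 128)
restored = rearrange(tokens, "b f (p c) -> b c (f p)", f=50, p=1)
assert torch.equal(x, restored)
```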
796 diffsynth/models/mova_audio_vae.py Normal file
@@ -0,0 +1,796 @@
import math
from typing import List, Union
import numpy as np
import torch
from torch import nn
from torch.nn.utils import weight_norm
import torch.nn.functional as F
from einops import rearrange


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
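The snake activation adds a periodic, non-decaying component: snake(x) = x + sin²(αx)/α, applied channel-wise via the learned `alpha`. A quick numerical check against the scripted implementation above:

```python
import torch

x = torch.linspace(-2, 2, 5).view(1, 1, 5)
alpha = torch.ones(1, 1, 1)
manual = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
assert torch.allclose(manual, snake(x, alpha))
```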
class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
           for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
           improves training stability
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z):
        """Quantized the input tensor using a fixed codebook and returns
        the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries
        Tensor[1]
            Codebook loss to update the codebook
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
        z_e = self.in_proj(z)  # z_e : (B x D x T)
        z_q, indices = self.decode_latents(z_e)

        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])

        z_q = (
            z_e + (z_q - z_e).detach()
        )  # noop in forward pass, straight-through gradient estimator in backward pass

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)
        return z_q, indices


class ResidualVectorQuantize(nn.Module):
    """
|
||||||
|
https://arxiv.org/abs/2107.03312
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim: int = 512,
|
||||||
|
n_codebooks: int = 9,
|
||||||
|
codebook_size: int = 1024,
|
||||||
|
codebook_dim: Union[int, list] = 8,
|
||||||
|
quantizer_dropout: float = 0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if isinstance(codebook_dim, int):
|
||||||
|
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
||||||
|
|
||||||
|
self.n_codebooks = n_codebooks
|
||||||
|
self.codebook_dim = codebook_dim
|
||||||
|
self.codebook_size = codebook_size
|
||||||
|
|
||||||
|
self.quantizers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
VectorQuantize(input_dim, codebook_size, codebook_dim[i])
|
||||||
|
for i in range(n_codebooks)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.quantizer_dropout = quantizer_dropout
|
||||||
|
|
||||||
|
def forward(self, z, n_quantizers: int = None):
|
||||||
|
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
||||||
|
the corresponding codebook vectors
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
z : Tensor[B x D x T]
|
||||||
|
n_quantizers : int, optional
|
||||||
|
No. of quantizers to use
|
||||||
|
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
||||||
|
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
||||||
|
when in training mode, and a random number of quantizers is used.
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict
|
||||||
|
A dictionary with the following keys:
|
||||||
|
|
||||||
|
"z" : Tensor[B x D x T]
|
||||||
|
Quantized continuous representation of input
|
||||||
|
"codes" : Tensor[B x N x T]
|
||||||
|
Codebook indices for each codebook
|
||||||
|
(quantized discrete representation of input)
|
||||||
|
"latents" : Tensor[B x N*D x T]
|
||||||
|
Projected latents (continuous representation of input before quantization)
|
||||||
|
"vq/commitment_loss" : Tensor[1]
|
||||||
|
Commitment loss to train encoder to predict vectors closer to codebook
|
||||||
|
entries
|
||||||
|
"vq/codebook_loss" : Tensor[1]
|
||||||
|
Codebook loss to update the codebook
|
||||||
|
"""
|
||||||
|
z_q = 0
|
||||||
|
residual = z
|
||||||
|
commitment_loss = 0
|
||||||
|
codebook_loss = 0
|
||||||
|
|
||||||
|
codebook_indices = []
|
||||||
|
latents = []
|
||||||
|
|
||||||
|
if n_quantizers is None:
|
||||||
|
n_quantizers = self.n_codebooks
|
||||||
|
if self.training:
|
||||||
|
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
|
||||||
|
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
|
||||||
|
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
||||||
|
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
||||||
|
n_quantizers = n_quantizers.to(z.device)
|
||||||
|
|
||||||
|
for i, quantizer in enumerate(self.quantizers):
|
||||||
|
if self.training is False and i >= n_quantizers:
|
||||||
|
break
|
||||||
|
|
||||||
|
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
||||||
|
residual
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create mask to apply quantizer dropout
|
||||||
|
mask = (
|
||||||
|
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
||||||
|
)
|
||||||
|
z_q = z_q + z_q_i * mask[:, None, None]
|
||||||
|
residual = residual - z_q_i
|
||||||
|
|
||||||
|
# Sum losses
|
||||||
|
commitment_loss += (commitment_loss_i * mask).mean()
|
||||||
|
codebook_loss += (codebook_loss_i * mask).mean()
|
||||||
|
|
||||||
|
codebook_indices.append(indices_i)
|
||||||
|
latents.append(z_e_i)
|
||||||
|
|
||||||
|
codes = torch.stack(codebook_indices, dim=1)
|
||||||
|
latents = torch.cat(latents, dim=1)
|
||||||
|
|
||||||
|
return z_q, codes, latents, commitment_loss, codebook_loss
|
||||||
|
|
||||||
|
def from_codes(self, codes: torch.Tensor):
|
||||||
|
"""Given the quantized codes, reconstruct the continuous representation
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
codes : Tensor[B x N x T]
|
||||||
|
Quantized discrete representation of input
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor[B x D x T]
|
||||||
|
Quantized continuous representation of input
|
||||||
|
"""
|
||||||
|
z_q = 0.0
|
||||||
|
z_p = []
|
||||||
|
n_codebooks = codes.shape[1]
|
||||||
|
for i in range(n_codebooks):
|
||||||
|
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
||||||
|
z_p.append(z_p_i)
|
||||||
|
|
||||||
|
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
||||||
|
z_q = z_q + z_q_i
|
||||||
|
return z_q, torch.cat(z_p, dim=1), codes
|
||||||
|
|
||||||
|
def from_latents(self, latents: torch.Tensor):
|
||||||
|
"""Given the unquantized latents, reconstruct the
|
||||||
|
continuous representation after quantization.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
latents : Tensor[B x N x T]
|
||||||
|
Continuous representation of input after projection
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor[B x D x T]
|
||||||
|
Quantized representation of full-projected space
|
||||||
|
Tensor[B x D x T]
|
||||||
|
Quantized representation of latent space
|
||||||
|
"""
|
||||||
|
z_q = 0
|
||||||
|
z_p = []
|
||||||
|
codes = []
|
||||||
|
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
||||||
|
|
||||||
|
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
for i in range(n_codebooks):
|
||||||
|
j, k = dims[i], dims[i + 1]
|
||||||
|
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
||||||
|
z_p.append(z_p_i)
|
||||||
|
codes.append(codes_i)
|
||||||
|
|
||||||
|
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
||||||
|
z_q = z_q + z_q_i
|
||||||
|
|
||||||
|
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
||||||
|
|
||||||
|
|
||||||
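Note: a minimal sketch of how the residual quantizer composes the `VectorQuantize` layers defined above; the batch size, feature dimension, and sequence length are illustrative assumptions. Each stage quantizes the residual left by the previous stage, so summing the per-stage outputs refines the reconstruction codebook by codebook:

import torch

rvq = ResidualVectorQuantize(input_dim=512, n_codebooks=9, codebook_size=1024, codebook_dim=8)
z = torch.randn(2, 512, 100)                      # [B x D x T]
z_q, codes, latents, commitment_loss, codebook_loss = rvq(z)
print(z_q.shape)                                  # torch.Size([2, 512, 100])
print(codes.shape)                                # torch.Size([2, 9, 100]), one index map per codebook
print(latents.shape)                              # torch.Size([2, 72, 100]), 9 codebooks x 8 dims
z_q_decoded, _, _ = rvq.from_codes(codes)         # reconstruct straight from the indices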
class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.mean(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2],
                )
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2],
                )

    def nll(self, sample, dims=[1, 2]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]

    return 0.5 * (
        -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            Snake1d(dim // 2),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)


class Encoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
    ):
        super().__init__()
        # Create first convolution
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride)]

        # Create last convolution
        self.block += [
            Snake1d(d_model),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
        ]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)


class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)


class Decoder(nn.Module):
    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
    ):
        super().__init__()

        # Add first conv layer
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Add upsampling + MRF blocks
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride)]

        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class DacVAE(nn.Module):

    def __init__(
        self,
        encoder_dim: int = 128,
        encoder_rates: List[int] = [2, 3, 4, 5, 8],
        latent_dim: int = 128,
        decoder_dim: int = 2048,
        decoder_rates: List[int] = [8, 5, 4, 3, 2],
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: bool = False,
        sample_rate: int = 48000,
        continuous: bool = True,
        use_weight_norm: bool = False,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate
        self.continuous = continuous
        self.use_weight_norm = use_weight_norm

        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim

        self.hop_length = np.prod(encoder_rates)
        self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)

        if not continuous:
            self.n_codebooks = n_codebooks
            self.codebook_size = codebook_size
            self.codebook_dim = codebook_dim
            self.quantizer = ResidualVectorQuantize(
                input_dim=latent_dim,
                n_codebooks=n_codebooks,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_dropout=quantizer_dropout,
            )
        else:
            self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
            self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)

        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
        )
        self.sample_rate = sample_rate
        self.apply(init_weights)

        self.delay = self.get_delay()

        if not self.use_weight_norm:
            self.remove_weight_norm()

    def get_delay(self):
        # Any number works here, delay is invariant to input length
        l_out = self.get_output_length(0)
        L = l_out

        layers = []
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                layers.append(layer)

        for layer in reversed(layers):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            if isinstance(layer, nn.ConvTranspose1d):
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                L = (L - 1) * s + d * (k - 1) + 1

            L = math.ceil(L)

        l_in = L

        return (l_in - l_out) // 2

    def get_output_length(self, input_length):
        L = input_length
        # Calculate output length
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                d = layer.dilation[0]
                k = layer.kernel_size[0]
                s = layer.stride[0]

                if isinstance(layer, nn.Conv1d):
                    L = ((L - d * (k - 1) - 1) / s) + 1
                elif isinstance(layer, nn.ConvTranspose1d):
                    L = (L - 1) * s + d * (k - 1) + 1

                L = math.floor(L)
        return L

    @property
    def dtype(self):
        """Get the dtype of the model parameters."""
        # Return the dtype of the first parameter found
        for param in self.parameters():
            return param.dtype
        return torch.float32  # fallback

    @property
    def device(self):
        """Get the device of the model parameters."""
        # Return the device of the first parameter found
        for param in self.parameters():
            return param.device
        return torch.device('cpu')  # fallback

    def preprocess(self, audio_data, sample_rate):
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate

        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data

    def encode(
        self,
        audio_data: torch.Tensor,
        n_quantizers: int = None,
    ):
        """Encode given audio data and return quantized latent codes

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        n_quantizers : int, optional
            Number of quantizers to use, by default None
            If None, all quantizers are used.

        Returns
        -------
        tuple
            (z, codes, latents, commitment_loss, codebook_loss), where:

            z : Tensor[B x D x T] (a DiagonalGaussianDistribution in continuous mode)
                Quantized continuous representation of input
            codes : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            latents : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            commitment_loss : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            codebook_loss : Tensor[1]
                Codebook loss to update the codebook
        """
        z = self.encoder(audio_data)  # [B x D x T]
        if not self.continuous:
            z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
        else:
            z = self.quant_conv(z)  # [B x 2D x T]
            z = DiagonalGaussianDistribution(z)
            codes, latents, commitment_loss, codebook_loss = None, None, 0, 0

        return z, codes, latents, commitment_loss, codebook_loss

    def decode(self, z: torch.Tensor):
        """Decode given latent codes and return audio data

        Parameters
        ----------
        z : Tensor[B x D x T]
            Quantized continuous representation of input

        Returns
        -------
        Tensor[B x 1 x length]
            Decoded audio data.
        """
        if not self.continuous:
            audio = self.decoder(z)
        else:
            z = self.post_quant_conv(z)
            audio = self.decoder(z)

        return audio

    def forward(
        self,
        audio_data: torch.Tensor,
        sample_rate: int = None,
        n_quantizers: int = None,
    ):
        """Model forward pass

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        sample_rate : int, optional
            Sample rate of audio data in Hz, by default None
            If None, defaults to `self.sample_rate`
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        dict
            A dictionary with the following keys:

            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
            "audio" : Tensor[B x 1 x length]
                Decoded audio data.
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data, sample_rate)
        if not self.continuous:
            z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)

            x = self.decode(z)
            return {
                "audio": x[..., :length],
                "z": z,
                "codes": codes,
                "latents": latents,
                "vq/commitment_loss": commitment_loss,
                "vq/codebook_loss": codebook_loss,
            }
        else:
            posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
            z = posterior.sample()
            x = self.decode(z)

            kl_loss = posterior.kl()
            kl_loss = kl_loss.mean()

            return {
                "audio": x[..., :length],
                "z": z,
                "kl_loss": kl_loss,
            }

    def remove_weight_norm(self):
        """
        Remove weight_norm from all modules in the model.
        This fuses the weight_g and weight_v parameters into a single weight parameter.
        Should be called before inference for better performance.

        Returns:
            self: The model with weight_norm removed
        """
        from torch.nn.utils import remove_weight_norm
        num_removed = 0
        for name, module in list(self.named_modules()):
            if hasattr(module, "_forward_pre_hooks"):
                for hook_id, hook in list(module._forward_pre_hooks.items()):
                    if "WeightNorm" in str(type(hook)):
                        try:
                            remove_weight_norm(module)
                            num_removed += 1
                        except ValueError as e:
                            print(f"Failed to remove weight_norm from {name}: {e}")
        if num_removed > 0:
            self.use_weight_norm = False
        else:
            print("No weight_norm found in the model")
        return self
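Note: a hedged usage sketch for the continuous (KL) path of DacVAE above. The one-second 48 kHz mono input is an illustrative assumption; with the default `encoder_rates`, the hop length is 2*3*4*5*8 = 960 samples, so one second maps to 50 latent steps:

import torch

vae = DacVAE(continuous=True).eval()
audio = torch.randn(1, 1, 48000)                  # [B x 1 x T] mono waveform
with torch.no_grad():
    audio = vae.preprocess(audio, vae.sample_rate)
    posterior, _, _, _, _ = vae.encode(audio)
    latent = posterior.mode()                     # [1, 128, 50] deterministic latent
    recon = vae.decode(latent)[..., :48000]       # trim decoder padding back to input length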
595
diffsynth/models/mova_dual_tower_bridge.py
Normal file
@@ -0,0 +1,595 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional
from einops import rearrange
from .wan_video_dit import AttentionModule, RMSNorm
from ..core import gradient_checkpoint_forward


class RotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, base: float, dim: int, device=None):
        super().__init__()
        self.base = base
        self.dim = dim
        self.attention_scaling = 1.0

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


@torch.compile(fullgraph=True)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class PerFrameAttentionPooling(nn.Module):
    """
    Per-frame multi-head attention pooling.

    Given a flattened token sequence [B, L, D] and grid size (T, H, W), perform a
    single-query attention pooling over the H*W tokens for each time frame, producing
    [B, T, D].

    Inspired by SigLIP's Multihead Attention Pooling head (without MLP/residual stack).
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads

        self.probe = nn.Parameter(torch.randn(1, 1, dim))
        nn.init.normal_(self.probe, std=0.02)

        self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(dim, eps=eps)

    def forward(self, x: torch.Tensor, grid_size: Tuple[int, int, int]) -> torch.Tensor:
        """
        Args:
            x: [B, L, D], where L = T*H*W
            grid_size: (T, H, W)
        Returns:
            pooled: [B, T, D]
        """
        B, L, D = x.shape
        T, H, W = grid_size
        assert D == self.dim, f"Channel dimension mismatch: D={D} vs dim={self.dim}"
        assert L == T * H * W, f"Flattened length mismatch: L={L} vs T*H*W={T*H*W}"

        S = H * W
        # Re-arrange tokens grouped by frame.
        x_bt_s_d = x.view(B, T, S, D).contiguous().view(B * T, S, D)  # [B*T, S, D]

        # A learnable probe as the query (one query per frame).
        probe = self.probe.expand(B * T, -1, -1)  # [B*T, 1, D]

        # Attention pooling: query=probe, key/value=H*W tokens within the frame.
        pooled_bt_1_d = self.attention(probe, x_bt_s_d, x_bt_s_d, need_weights=False)[0]  # [B*T, 1, D]
        pooled_bt_d = pooled_bt_1_d.squeeze(1)  # [B*T, D]

        # Restore to [B, T, D].
        pooled = pooled_bt_d.view(B, T, D)
        pooled = self.layernorm(pooled)
        return pooled


class CrossModalInteractionController:
    """
    Strategy class that controls interactions between two towers.
    Manages the interaction mapping between visual DiT (e.g. 30 layers) and audio DiT (e.g. 30 layers).
    """

    def __init__(self, visual_layers: int = 30, audio_layers: int = 30):
        self.visual_layers = visual_layers
        self.audio_layers = audio_layers
        self.min_layers = min(visual_layers, audio_layers)

    def get_interaction_layers(self, strategy: str = "shallow_focus") -> Dict[str, List[Tuple[int, int]]]:
        """
        Get interaction layer mappings.

        Args:
            strategy: interaction strategy
                - "shallow_focus": emphasize shallow layers to avoid deep-layer asymmetry
                - "distributed": distributed interactions across the network
                - "progressive": dense shallow interactions, sparse deeper interactions
                - "custom": custom interaction layers
                - "full": interact at every shared layer

        Returns:
            A dict containing mappings for 'v2a' (visual -> audio) and 'a2v' (audio -> visual).
        """

        if strategy == "shallow_focus":
            # Emphasize the first ~1/3 layers to avoid deep-layer asymmetry.
            num_interact = min(10, self.min_layers // 3)
            interact_layers = list(range(0, num_interact))

        elif strategy == "distributed":
            # Distribute interactions across the network (every few layers).
            step = 3
            interact_layers = list(range(0, self.min_layers, step))

        elif strategy == "progressive":
            # Progressive: dense shallow interactions, sparse deeper interactions.
            shallow = list(range(0, min(8, self.min_layers)))  # Dense for the first 8 layers.
            if self.min_layers > 8:
                deep = list(range(8, self.min_layers, 3))  # Every 3 layers afterwards.
                interact_layers = shallow + deep
            else:
                interact_layers = shallow

        elif strategy == "custom":
            # Custom strategy: adjust as needed.
            interact_layers = [0, 2, 4, 6, 8, 12, 16, 20]  # Explicit layer indices.
            interact_layers = [i for i in interact_layers if i < self.min_layers]

        elif strategy == "full":
            interact_layers = list(range(0, self.min_layers))

        else:
            raise ValueError(f"Unknown interaction strategy: {strategy}")

        # Build bidirectional mapping.
        mapping = {
            'v2a': [(i, i) for i in interact_layers],  # visual layer i -> audio layer i
            'a2v': [(i, i) for i in interact_layers]   # audio layer i -> visual layer i
        }

        return mapping

    def should_interact(self, layer_idx: int, direction: str, interaction_mapping: Dict) -> bool:
        """
        Check whether a given layer should interact.

        Args:
            layer_idx: current layer index
            direction: interaction direction ('v2a' or 'a2v')
            interaction_mapping: interaction mapping table

        Returns:
            bool: whether to interact
        """
        if direction not in interaction_mapping:
            return False

        return any(src == layer_idx for src, _ in interaction_mapping[direction])
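Note: a worked example of the strategy table above, assuming 30 layers per tower. "progressive" interacts densely in layers 0-7 and then every third layer:

controller = CrossModalInteractionController(visual_layers=30, audio_layers=30)
mapping = controller.get_interaction_layers("progressive")
print([src for src, _ in mapping['v2a']])
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 11, 14, 17, 20, 23, 26, 29]
print(controller.should_interact(9, 'v2a', mapping))   # False: layer 9 is skipped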
class ConditionalCrossAttention(nn.Module):
    def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.q_dim = dim
        self.kv_dim = kv_dim
        self.num_heads = num_heads
        self.head_dim = self.q_dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(kv_dim, dim)
        self.v = nn.Linear(kv_dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

    def forward(self, x: torch.Tensor, y: torch.Tensor, x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        ctx = y
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(ctx))
        v = self.v(ctx)
        if x_freqs is not None:
            x_cos, x_sin = x_freqs
            B, L, _ = q.shape
            q_view = rearrange(q, 'b l (h d) -> b l h d', d=self.head_dim)
            x_cos = x_cos.to(q_view.dtype).to(q_view.device)
            x_sin = x_sin.to(q_view.dtype).to(q_view.device)
            # Expect x_cos/x_sin shape: [B or 1, L, head_dim]
            q_view, _ = apply_rotary_pos_emb(q_view, q_view, x_cos, x_sin, unsqueeze_dim=2)
            q = rearrange(q_view, 'b l h d -> b l (h d)')
        if y_freqs is not None:
            y_cos, y_sin = y_freqs
            Bc, Lc, _ = k.shape
            k_view = rearrange(k, 'b l (h d) -> b l h d', d=self.head_dim)
            y_cos = y_cos.to(k_view.dtype).to(k_view.device)
            y_sin = y_sin.to(k_view.dtype).to(k_view.device)
            # Expect y_cos/y_sin shape: [B or 1, L, head_dim]
            _, k_view = apply_rotary_pos_emb(k_view, k_view, y_cos, y_sin, unsqueeze_dim=2)
            k = rearrange(k_view, 'b l h d -> b l (h d)')
        x = self.attn(q, k, v)
        return self.o(x)


# from diffusers.models.attention import AdaLayerNorm
class AdaLayerNorm(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
        output_dim (`int`, *optional*): The output dimension (defaults to `embedding_dim * 2`).
        norm_elementwise_affine (`bool`, defaults to `False`): Whether the LayerNorm has learnable affine parameters.
        norm_eps (`float`, defaults to `1e-5`): Epsilon used by the LayerNorm.
        chunk_dim (`int`, defaults to `0`): The dimension along which `temb` is chunked into scale and shift.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_embeddings: Optional[int] = None,
        output_dim: Optional[int] = None,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
        chunk_dim: int = 0,
    ):
        super().__init__()

        self.chunk_dim = chunk_dim
        output_dim = output_dim or embedding_dim * 2

        if num_embeddings is not None:
            self.emb = nn.Embedding(num_embeddings, embedding_dim)
        else:
            self.emb = None

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)

    def forward(
        self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if self.emb is not None:
            temb = self.emb(timestep)

        temb = self.linear(self.silu(temb))

        if self.chunk_dim == 2:
            scale, shift = temb.chunk(2, dim=2)
        elif self.chunk_dim == 1:
            # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
            # other if-branch. This branch is specific to CogVideoX and OmniGen for now.
            shift, scale = temb.chunk(2, dim=1)
            shift = shift[:, None, :]
            scale = scale[:, None, :]
        else:
            scale, shift = temb.chunk(2, dim=0)

        x = self.norm(x) * (1 + scale) + shift
        return x


class ConditionalCrossAttentionBlock(nn.Module):
    """
    A thin wrapper around ConditionalCrossAttention.
    Applies LayerNorm to the conditioning input `y` before cross-attention.
    """

    def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6, pooled_adaln: bool = False):
        super().__init__()
        self.y_norm = nn.LayerNorm(kv_dim, eps=eps)
        self.inner = ConditionalCrossAttention(dim=dim, kv_dim=kv_dim, num_heads=num_heads, eps=eps)
        self.pooled_adaln = pooled_adaln
        if pooled_adaln:
            self.per_frame_pooling = PerFrameAttentionPooling(kv_dim, num_heads=num_heads, eps=eps)
            self.adaln = AdaLayerNorm(kv_dim, output_dim=dim * 2, chunk_dim=2)

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        video_grid_size: Optional[Tuple[int, int, int]] = None,
    ) -> torch.Tensor:
        if self.pooled_adaln:
            assert video_grid_size is not None, "video_grid_size must not be None"
            pooled_y = self.per_frame_pooling(y, video_grid_size)
            # Interpolate pooled_y along its temporal dimension to match x's sequence length.
            if pooled_y.shape[1] != x.shape[1]:
                pooled_y = F.interpolate(
                    pooled_y.permute(0, 2, 1),  # [B, C, T]
                    size=x.shape[1],
                    mode='linear',
                    align_corners=False,
                ).permute(0, 2, 1)  # [B, T, C]
            x = self.adaln(x, temb=pooled_y)
        y = self.y_norm(y)
        return self.inner(x=x, y=y, x_freqs=x_freqs, y_freqs=y_freqs)


class DualTowerConditionalBridge(nn.Module):
    """
    Dual-tower conditional bridge.
    """

    def __init__(self,
                 visual_layers: int = 40,
                 audio_layers: int = 30,
                 visual_hidden_dim: int = 5120,  # visual DiT hidden state dimension
                 audio_hidden_dim: int = 1536,  # audio DiT hidden state dimension
                 audio_fps: float = 50.0,
                 head_dim: int = 128,  # attention head dimension
                 interaction_strategy: str = "full",
                 apply_cross_rope: bool = True,  # whether to apply RoPE in cross-attention
                 apply_first_frame_bias_in_rope: bool = False,  # whether to account for 1/video_fps bias for the first frame in RoPE alignment
                 trainable_condition_scale: bool = False,
                 pooled_adaln: bool = False,
                 ):
        super().__init__()

        self.visual_hidden_dim = visual_hidden_dim
        self.audio_hidden_dim = audio_hidden_dim
        self.audio_fps = audio_fps
        self.head_dim = head_dim
        self.apply_cross_rope = apply_cross_rope
        self.apply_first_frame_bias_in_rope = apply_first_frame_bias_in_rope
        self.trainable_condition_scale = trainable_condition_scale
        self.pooled_adaln = pooled_adaln
        if self.trainable_condition_scale:
            self.condition_scale = nn.Parameter(torch.tensor([1.0], dtype=torch.float32))
        else:
            self.condition_scale = 1.0

        self.controller = CrossModalInteractionController(visual_layers, audio_layers)
        self.interaction_mapping = self.controller.get_interaction_layers(interaction_strategy)

        # Conditional cross-attention modules operating at the DiT hidden-state level.
        self.audio_to_video_conditioners = nn.ModuleDict()  # audio hidden states -> visual DiT conditioning
        self.video_to_audio_conditioners = nn.ModuleDict()  # visual hidden states -> audio DiT conditioning

        # Build conditioners for layers that should interact.
        # audio hidden states condition the visual DiT
        self.rotary = RotaryEmbedding(base=10000.0, dim=head_dim)
        for v_layer, _ in self.interaction_mapping['a2v']:
            self.audio_to_video_conditioners[str(v_layer)] = ConditionalCrossAttentionBlock(
                dim=visual_hidden_dim,  # visual DiT hidden states
                kv_dim=audio_hidden_dim,  # audio DiT hidden states
                num_heads=visual_hidden_dim // head_dim,  # derive number of heads from hidden dim
                pooled_adaln=False  # a2v typically does not need pooled AdaLN
            )

        # visual hidden states condition the audio DiT
        for a_layer, _ in self.interaction_mapping['v2a']:
            self.video_to_audio_conditioners[str(a_layer)] = ConditionalCrossAttentionBlock(
                dim=audio_hidden_dim,  # audio DiT hidden states
                kv_dim=visual_hidden_dim,  # visual DiT hidden states
                num_heads=audio_hidden_dim // head_dim,  # safe head count derivation
                pooled_adaln=self.pooled_adaln
            )

    @torch.no_grad()
    def build_aligned_freqs(self,
                            video_fps: float,
                            grid_size: Tuple[int, int, int],
                            audio_steps: int,
                            device: Optional[torch.device] = None,
                            dtype: Optional[torch.dtype] = None) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        """
        Build aligned RoPE (cos, sin) pairs based on video fps, video grid size (f_v, h, w),
        and audio sequence length `audio_steps` (with fixed audio fps = 44100/2048).

        Returns:
            visual_freqs: (cos_v, sin_v), shape [1, f_v*h*w, head_dim]
            audio_freqs: (cos_a, sin_a), shape [1, audio_steps, head_dim]
        """
        f_v, h, w = grid_size
        L_v = f_v * h * w
        L_a = int(audio_steps)

        device = device or next(self.parameters()).device
        dtype = dtype or torch.float32

        # Audio positions: 0,1,2,...,L_a-1 (audio as reference).
        audio_pos = torch.arange(L_a, device=device, dtype=torch.float32).unsqueeze(0)

        # Video positions: align video frames to audio-step units.
        # FIXME(dhyu): hard-coded VAE temporal stride = 4
        if self.apply_first_frame_bias_in_rope:
            # Account for the "first frame lasts 1/video_fps" bias.
            video_effective_fps = float(video_fps) / 4.0
            if f_v > 0:
                t_starts = torch.zeros((f_v,), device=device, dtype=torch.float32)
                if f_v > 1:
                    t_starts[1:] = (1.0 / float(video_fps)) + torch.arange(f_v - 1, device=device, dtype=torch.float32) * (1.0 / video_effective_fps)
            else:
                t_starts = torch.zeros((0,), device=device, dtype=torch.float32)
            # Convert to audio-step units.
            video_pos_per_frame = t_starts * float(self.audio_fps)
        else:
            # No first-frame bias: uniform alignment.
            scale = float(self.audio_fps) / float(video_fps / 4.0)
            video_pos_per_frame = torch.arange(f_v, device=device, dtype=torch.float32) * scale
        # Flatten to f*h*w; tokens within the same frame share the same time position.
        video_pos = video_pos_per_frame.repeat_interleave(h * w).unsqueeze(0)

        # Build dummy x to produce cos/sin, dim=head_dim.
        dummy_v = torch.zeros((1, L_v, self.head_dim), device=device, dtype=dtype)
        dummy_a = torch.zeros((1, L_a, self.head_dim), device=device, dtype=dtype)

        cos_v, sin_v = self.rotary(dummy_v, position_ids=video_pos)
        cos_a, sin_a = self.rotary(dummy_a, position_ids=audio_pos)

        return (cos_v, sin_v), (cos_a, sin_a)

    def should_interact(self, layer_idx: int, direction: str) -> bool:
        return self.controller.should_interact(layer_idx, direction, self.interaction_mapping)

    def apply_conditional_control(
        self,
        layer_idx: int,
        direction: str,
        primary_hidden_states: torch.Tensor,
        condition_hidden_states: torch.Tensor,
        x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        condition_scale: Optional[float] = None,
        video_grid_size: Optional[Tuple[int, int, int]] = None,
        use_gradient_checkpointing: Optional[bool] = False,
        use_gradient_checkpointing_offload: Optional[bool] = False,
    ) -> torch.Tensor:
        """
        Apply conditional control (at the DiT hidden-state level).

        Args:
            layer_idx: current layer index
            direction: conditioning direction
                - 'a2v': audio hidden states -> visual DiT
                - 'v2a': visual hidden states -> audio DiT
            primary_hidden_states: primary DiT hidden states [B, L, hidden_dim]
            condition_hidden_states: condition DiT hidden states [B, L, hidden_dim]
            condition_scale: conditioning strength (similar to CFG scale)

        Returns:
            Conditioned primary DiT hidden states [B, L, hidden_dim]
        """

        if not self.controller.should_interact(layer_idx, direction, self.interaction_mapping):
            return primary_hidden_states

        if direction == 'a2v':
            # audio hidden states condition the visual DiT
            conditioner = self.audio_to_video_conditioners[str(layer_idx)]
        elif direction == 'v2a':
            # visual hidden states condition the audio DiT
            conditioner = self.video_to_audio_conditioners[str(layer_idx)]
        else:
            raise ValueError(f"Invalid direction: {direction}")

        conditioned_features = gradient_checkpoint_forward(
            conditioner,
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            x=primary_hidden_states,
            y=condition_hidden_states,
            x_freqs=x_freqs,
            y_freqs=y_freqs,
            video_grid_size=video_grid_size,
        )

        if self.trainable_condition_scale and condition_scale is not None:
            print(
                "[WARN] This model has a trainable condition_scale, but an external "
                f"condition_scale={condition_scale} was provided. The trainable condition_scale "
                "will be ignored in favor of the external value."
            )

        scale = condition_scale if condition_scale is not None else self.condition_scale

        primary_hidden_states = primary_hidden_states + conditioned_features * scale

        return primary_hidden_states

    def forward(
        self,
        layer_idx: int,
        visual_hidden_states: torch.Tensor,
        audio_hidden_states: torch.Tensor,
        *,
        x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        a2v_condition_scale: Optional[float] = None,
        v2a_condition_scale: Optional[float] = None,
        condition_scale: Optional[float] = None,
        video_grid_size: Optional[Tuple[int, int, int]] = None,
        use_gradient_checkpointing: Optional[bool] = False,
        use_gradient_checkpointing_offload: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply bidirectional conditional control to both visual/audio towers.

        Args:
            layer_idx: current layer index
            visual_hidden_states: visual DiT hidden states
            audio_hidden_states: audio DiT hidden states
            x_freqs / y_freqs: cross-modal RoPE (cos, sin) pairs.
                x_freqs corresponds to the visual tower and y_freqs to the audio tower;
                the pair is swapped internally for the v2a direction.
            a2v_condition_scale: audio->visual conditioning strength (overrides global condition_scale)
            v2a_condition_scale: visual->audio conditioning strength (overrides global condition_scale)
            condition_scale: fallback conditioning strength when per-direction scale is None
            video_grid_size: (F, H, W), used on the audio side when pooled_adaln is enabled

        Returns:
            (visual_hidden_states, audio_hidden_states), both conditioned in their respective directions.
        """

        visual_conditioned = self.apply_conditional_control(
            layer_idx=layer_idx,
            direction="a2v",
            primary_hidden_states=visual_hidden_states,
            condition_hidden_states=audio_hidden_states,
            x_freqs=x_freqs,
            y_freqs=y_freqs,
            condition_scale=a2v_condition_scale if a2v_condition_scale is not None else condition_scale,
            video_grid_size=video_grid_size,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
        )

        audio_conditioned = self.apply_conditional_control(
            layer_idx=layer_idx,
            direction="v2a",
            primary_hidden_states=audio_hidden_states,
            condition_hidden_states=visual_hidden_states,
            x_freqs=y_freqs,
            y_freqs=x_freqs,
            condition_scale=v2a_condition_scale if v2a_condition_scale is not None else condition_scale,
            video_grid_size=video_grid_size,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
        )

        return visual_conditioned, audio_conditioned
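Note: a minimal sketch of how the bridge above would be driven from a dual-tower denoising loop. The grid size, audio length, and batch size are illustrative assumptions, and `apply_rotary_pos_emb` is `torch.compile`d, so the first call pays a compilation cost:

import torch

bridge = DualTowerConditionalBridge(visual_layers=30, audio_layers=30).eval()
grid = (3, 4, 4)                                  # (f, h, w) after VAE downsampling
vis = torch.randn(1, 3 * 4 * 4, 5120)             # [B, f*h*w, visual_hidden_dim]
aud = torch.randn(1, 12, 1536)                    # [B, audio steps, audio_hidden_dim]
v_freqs, a_freqs = bridge.build_aligned_freqs(video_fps=16.0, grid_size=grid, audio_steps=12)
with torch.no_grad():
    for layer_idx in range(30):
        # In a real pipeline the two DiT blocks for this layer would run first.
        vis, aud = bridge(layer_idx, vis, aud, x_freqs=v_freqs, y_freqs=a_freqs,
                          video_grid_size=grid)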
@@ -549,6 +549,9 @@ class QwenImageTransformerBlock(nn.Module):


class QwenImageDiT(torch.nn.Module):

    _repeated_blocks = ["QwenImageTransformerBlock"]

    def __init__(
        self,
        num_layers: int = 60,
@@ -6,6 +6,7 @@ from typing import Tuple, Optional
from einops import rearrange
from .wan_video_camera_controller import SimpleAdapter
from ..core.gradient import gradient_checkpoint_forward
from .wantodance import WanToDanceRotaryEmbedding, WanToDanceMusicEncoderLayer

try:
    import flash_attn_interface
@@ -99,17 +100,29 @@ def rope_apply(x, freqs, num_heads):
     return x_out.to(x.dtype)
 
 
+def set_to_torch_norm(models):
+    for model in models:
+        for module in model.modules():
+            if isinstance(module, RMSNorm):
+                module.use_torch_norm = True
+
+
 class RMSNorm(nn.Module):
     def __init__(self, dim, eps=1e-5):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
+        self.use_torch_norm = False
+        self.normalized_shape = (dim,)
 
     def norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
 
     def forward(self, x):
         dtype = x.dtype
+        if self.use_torch_norm:
+            return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
             return self.norm(x.float()).to(dtype) * self.weight
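A quick sanity check that the eager path and the fused `F.rms_norm` path introduced above agree up to floating-point rounding. This is a standalone sketch, assuming only the `RMSNorm` class from this diff (`torch.nn.functional.rms_norm` requires a recent PyTorch release):

```python
import torch
import torch.nn.functional as F

norm = RMSNorm(dim=64)
x = torch.randn(2, 10, 64)
y_eager = norm(x)              # manual x * rsqrt(mean(x^2) + eps) * weight
norm.use_torch_norm = True     # what set_to_torch_norm([model]) flips for every RMSNorm
y_fused = norm(x)              # fused F.rms_norm kernel
assert torch.allclose(y_eager, y_fused, atol=1e-5)
```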
@@ -271,7 +284,61 @@ class Head(nn.Module):
         return x
 
 
+def wantodance_torch_dfs(model: nn.Module, parent_name='root'):
+    module_names, modules = [], []
+    current_name = parent_name if parent_name else 'root'
+    module_names.append(current_name)
+    modules.append(model)
+    for name, child in model.named_children():
+        if parent_name:
+            child_name = f'{parent_name}.{name}'
+        else:
+            child_name = name
+        child_modules, child_names = wantodance_torch_dfs(child, child_name)
+        module_names += child_names
+        modules += child_modules
+    return modules, module_names
+
+
+class WanToDanceInjector(nn.Module):
+    def __init__(self, all_modules, all_modules_names, dim=2048, num_heads=32, inject_layer=[0, 27]):
+        super().__init__()
+        self.injected_block_id = {}
+        injector_id = 0
+        for mod_name, mod in zip(all_modules_names, all_modules):
+            if isinstance(mod, DiTBlock):
+                for inject_id in inject_layer:
+                    if f'root.transformer_blocks.{inject_id}' == mod_name:
+                        self.injected_block_id[inject_id] = injector_id
+                        injector_id += 1
+
+        self.injector = nn.ModuleList(
+            [
+                CrossAttention(
+                    dim=dim,
+                    num_heads=num_heads,
+                )
+                for _ in range(injector_id)
+            ]
+        )
+        self.injector_pre_norm_feat = nn.ModuleList(
+            [
+                nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6,)
+                for _ in range(injector_id)
+            ]
+        )
+        self.injector_pre_norm_vec = nn.ModuleList(
+            [
+                nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6,)
+                for _ in range(injector_id)
+            ]
+        )
+
+
 class WanModel(torch.nn.Module):
+
+    _repeated_blocks = ["DiTBlock"]
+
     def __init__(
         self,
         dim: int,
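A sketch of how `wantodance_torch_dfs` enumerates modules with dotted names, which is what `WanToDanceInjector` matches against `root.transformer_blocks.{i}`. Illustrative only; `model.blocks` stands in for the WanModel transformer block list:

```python
# DFS returns parallel lists: modules and their dotted paths rooted at parent_name.
modules, names = wantodance_torch_dfs(model.blocks, parent_name="root.transformer_blocks")
# names[0] == "root.transformer_blocks"; children follow as
# "root.transformer_blocks.0", "root.transformer_blocks.0.self_attn", ...
for name, mod in zip(names, modules):
    if isinstance(mod, DiTBlock):
        print(name)  # the paths the injector keys its per-layer cross-attention on
```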
@@ -293,6 +360,13 @@ class WanModel(torch.nn.Module):
         require_vae_embedding: bool = True,
         require_clip_embedding: bool = True,
         fuse_vae_embedding_in_latents: bool = False,
+        wantodance_enable_music_inject: bool = False,
+        wantodance_music_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27],
+        wantodance_enable_refimage: bool = False,
+        wantodance_enable_refface: bool = False,
+        wantodance_enable_global: bool = False,
+        wantodance_enable_dynamicfps: bool = False,
+        wantodance_enable_unimodel: bool = False,
     ):
         super().__init__()
         self.dim = dim
@@ -325,6 +399,11 @@ class WanModel(torch.nn.Module):
         ])
         self.head = Head(dim, out_dim, patch_size, eps)
         head_dim = dim // num_heads
+
+        if wantodance_enable_dynamicfps or wantodance_enable_unimodel:
+            end = int(22350 / 8 + 0.5)  # 149f * 30fps * 5s = 22350
+            self.freqs = precompute_freqs_cis_3d(head_dim, end=end)
+        else:
             self.freqs = precompute_freqs_cis_3d(head_dim)
 
         if has_image_input:
@@ -338,7 +417,82 @@ class WanModel(torch.nn.Module):
         else:
             self.control_adapter = None
 
-    def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
+        self.prepare_wantodance(
+            in_dim, dim, num_heads, has_image_pos_emb, out_dim, patch_size, eps,
+            wantodance_enable_music_inject, wantodance_music_inject_layers, wantodance_enable_refimage, wantodance_enable_refface,
+            wantodance_enable_global, wantodance_enable_dynamicfps, wantodance_enable_unimodel)
+
+    def prepare_wantodance(
+        self,
+        in_dim, dim, num_heads, has_image_pos_emb, out_dim, patch_size, eps,
+        wantodance_enable_music_inject: bool = False,
+        wantodance_music_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27],
+        wantodance_enable_refimage: bool = False,
+        wantodance_enable_refface: bool = False,
+        wantodance_enable_global: bool = False,
+        wantodance_enable_dynamicfps: bool = False,
+        wantodance_enable_unimodel: bool = False,
+    ):
+        if wantodance_enable_music_inject:
+            all_modules, all_modules_names = wantodance_torch_dfs(self.blocks, parent_name="root.transformer_blocks")
+            self.music_injector = WanToDanceInjector(all_modules, all_modules_names, dim=dim, num_heads=num_heads, inject_layer=wantodance_music_inject_layers)
+        if wantodance_enable_refimage:
+            self.img_emb_refimage = MLP(1280, dim, has_pos_emb=has_image_pos_emb)  # clip_feature_dim = 1280
+        if wantodance_enable_refface:
+            self.img_emb_refface = MLP(1280, dim, has_pos_emb=has_image_pos_emb)  # clip_feature_dim = 1280
+        if wantodance_enable_global or wantodance_enable_dynamicfps or wantodance_enable_unimodel:
+            music_feature_dim = 35
+            ff_size = 1024
+            dropout = 0.1
+            latent_dim = 256
+            nhead = 4
+            activation = F.gelu
+            rotary = WanToDanceRotaryEmbedding(dim=latent_dim)
+            self.music_projection = nn.Linear(music_feature_dim, latent_dim)
+            self.music_encoder = nn.Sequential()
+            for _ in range(2):
+                self.music_encoder.append(
+                    WanToDanceMusicEncoderLayer(
+                        d_model=latent_dim,
+                        nhead=nhead,
+                        dim_feedforward=ff_size,
+                        dropout=dropout,
+                        activation=activation,
+                        batch_first=True,
+                        rotary=rotary,
+                        device='cuda',
+                    )
+                )
+        if wantodance_enable_unimodel:
+            self.patch_embedding_global = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
+        if wantodance_enable_unimodel:
+            self.head_global = Head(dim, out_dim, patch_size, eps)
+        self.wantodance_enable_music_inject = wantodance_enable_music_inject
+        self.wantodance_enable_refimage = wantodance_enable_refimage
+        self.wantodance_enable_refface = wantodance_enable_refface
+        self.wantodance_enable_global = wantodance_enable_global
+        self.wantodance_enable_dynamicfps = wantodance_enable_dynamicfps
+        self.wantodance_enable_unimodel = wantodance_enable_unimodel
+
+    def wantodance_after_transformer_block(self, block_idx, hidden_states):
+        if self.wantodance_enable_music_inject:
+            if block_idx in self.music_injector.injected_block_id.keys():
+                audio_attn_id = self.music_injector.injected_block_id[block_idx]
+                audio_emb = self.merged_audio_emb  # b f c
+                num_frames = audio_emb.shape[1]
+                input_hidden_states = hidden_states.clone()  # b (f h w) c
+                input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)
+                attn_hidden_states = self.music_injector.injector_pre_norm_feat[audio_attn_id](input_hidden_states)
+                audio_emb = rearrange(audio_emb, "b t c -> (b t) 1 c", t=num_frames)
+                attn_audio_emb = audio_emb
+                residual_out = self.music_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb)
+                residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
+                hidden_states = hidden_states + residual_out
+        return hidden_states
+
+    def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None, enable_wantodance_global=False):
+        if enable_wantodance_global:
+            x = self.patch_embedding_global(x)
+        else:
             x = self.patch_embedding(x)
         if self.control_adapter is not None and control_camera_latents_input is not None:
             y_camera = self.control_adapter(control_camera_latents_input)
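A shape walk-through for `wantodance_after_transformer_block` above, as an assumption-labeled sketch: `hidden_states` is `[b, f*n, c]` (frames times spatial tokens, flattened) and the merged music embedding is `[b, f, c]` (one vector per frame); the concrete sizes are made up.

```python
import torch
from einops import rearrange

b, f, n, c = 1, 16, 880, 2048                      # illustrative sizes only
hidden_states = torch.randn(b, f * n, c)
audio_emb = torch.randn(b, f, c)
tokens = rearrange(hidden_states, "b (t n) c -> (b t) n c", t=f)   # [b*f, n, c]
audio_q = rearrange(audio_emb, "b t c -> (b t) 1 c", t=f)          # [b*f, 1, c]
# Cross-attention then runs per frame: n video tokens attend to a single music
# vector, and the result is folded back via "(b t) n c -> b (t n) c" and added
# to hidden_states as a residual.
```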
@@ -469,7 +469,7 @@ class Down_ResidualBlock(nn.Module):
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         x_copy = x.clone()
         for module in self.downsamples:
-            x = module(x, feat_cache, feat_idx)
+            x, feat_cache, feat_idx = module(x, feat_cache, feat_idx)
 
         return x + self.avg_shortcut(x_copy), feat_cache, feat_idx
 
@@ -506,10 +506,10 @@ class Up_ResidualBlock(nn.Module):
     def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
         x_main = x.clone()
         for module in self.upsamples:
-            x_main = module(x_main, feat_cache, feat_idx)
+            x_main, feat_cache, feat_idx = module(x_main, feat_cache, feat_idx)
         if self.avg_shortcut is not None:
             x_shortcut = self.avg_shortcut(x, first_chunk)
-            return x_main + x_shortcut
+            return x_main + x_shortcut, feat_cache, feat_idx
         else:
             return x_main, feat_cache, feat_idx
 
@@ -1247,6 +1247,22 @@ class WanVideoVAE(nn.Module):
         return videos
 
+    def encode_framewise(self, videos, device):
+        hidden_states = []
+        for i in range(videos.shape[2]):
+            hidden_states.append(self.single_encode(videos[:, :, i:i+1], device))
+        hidden_states = torch.concat(hidden_states, dim=2)
+        return hidden_states
+
+    def decode_framewise(self, hidden_states, device):
+        video = []
+        for i in range(hidden_states.shape[2]):
+            video.append(self.single_decode(hidden_states[:, :, i:i+1], device))
+        video = torch.concat(video, dim=2)
+        return video
+
     @staticmethod
     def state_dict_converter():
         return WanVideoVAEStateDictConverter()
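A sketch of the framewise VAE round trip added above: each frame is encoded and decoded independently (no temporal feature cache), then re-concatenated along the frame axis. Shapes are illustrative, and it assumes `single_encode`/`single_decode` map one frame to one latent frame:

```python
import torch

videos = torch.randn(1, 3, 8, 480, 832)            # [B, C, F, H, W], made-up sizes
latents = vae.encode_framewise(videos, device="cuda")
recon = vae.decode_framewise(latents, device="cuda")
assert recon.shape[2] == videos.shape[2]           # frame count preserved
```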
diffsynth/models/wantodance.py (new file, 209 lines)
@@ -0,0 +1,209 @@
+from inspect import isfunction
+from math import log, pi
+
+import torch
+from einops import rearrange, repeat
+from torch import einsum, nn
+
+from typing import Any, Callable, List, Optional, Union
+from torch import Tensor
+import torch.nn.functional as F
+
+# helper functions
+
+
+def exists(val):
+    return val is not None
+
+
+def broadcat(tensors, dim=-1):
+    num_tensors = len(tensors)
+    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
+    assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+    shape_len = list(shape_lens)[0]
+
+    dim = (dim + shape_len) if dim < 0 else dim
+    dims = list(zip(*map(lambda t: list(t.shape), tensors)))
+
+    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+    assert all(
+        [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
+    ), "invalid dimensions for broadcastable concatenation"
+    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
+    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
+    expanded_dims.insert(dim, (dim, dims[dim]))
+    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
+    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
+    return torch.cat(tensors, dim=dim)
+
+
+# rotary embedding helper functions
+
+
+def rotate_half(x):
+    x = rearrange(x, "... (d r) -> ... d r", r=2)
+    x1, x2 = x.unbind(dim=-1)
+    x = torch.stack((-x2, x1), dim=-1)
+    return rearrange(x, "... d r -> ... (d r)")
+
+
+def apply_rotary_emb(freqs, t, start_index=0):
+    freqs = freqs.to(t)
+    rot_dim = freqs.shape[-1]
+    end_index = start_index + rot_dim
+    assert (
+        rot_dim <= t.shape[-1]
+    ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
+    t_left, t, t_right = (
+        t[..., :start_index],
+        t[..., start_index:end_index],
+        t[..., end_index:],
+    )
+    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
+    return torch.cat((t_left, t, t_right), dim=-1)
+
+
+# learned rotation helpers
+
+
+def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
+    if exists(freq_ranges):
+        rotations = einsum("..., f -> ... f", rotations, freq_ranges)
+        rotations = rearrange(rotations, "... r f -> ... (r f)")
+
+    rotations = repeat(rotations, "... n -> ... (n r)", r=2)
+    return apply_rotary_emb(rotations, t, start_index=start_index)
+
+
+# classes
+
+
+class WanToDanceRotaryEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim,
+        custom_freqs=None,
+        freqs_for="lang",
+        theta=10000,
+        max_freq=10,
+        num_freqs=1,
+        learned_freq=False,
+    ):
+        super().__init__()
+        if exists(custom_freqs):
+            freqs = custom_freqs
+        elif freqs_for == "lang":
+            freqs = 1.0 / (
+                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+            )
+        elif freqs_for == "pixel":
+            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+        elif freqs_for == "constant":
+            freqs = torch.ones(num_freqs).float()
+        else:
+            raise ValueError(f"unknown modality {freqs_for}")
+
+        self.cache = dict()
+
+        if learned_freq:
+            self.freqs = nn.Parameter(freqs)
+        else:
+            self.register_buffer("freqs", freqs, persistent=False)
+
+    def rotate_queries_or_keys(self, t, seq_dim=-2):
+        device = t.device
+        seq_len = t.shape[seq_dim]
+        freqs = self.forward(
+            lambda: torch.arange(seq_len, device=device), cache_key=seq_len
+        )
+        return apply_rotary_emb(freqs, t)
+
+    def forward(self, t, cache_key=None):
+        if exists(cache_key) and cache_key in self.cache:
+            return self.cache[cache_key]
+
+        if isfunction(t):
+            t = t()
+
+        freqs = self.freqs.to(t.device)
+
+        freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
+        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+
+        if exists(cache_key):
+            self.cache[cache_key] = freqs
+
+        return freqs
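A short sketch of rotating a `[batch, seq, dim]` tensor of queries or keys with the module above; `freqs_for="lang"` reproduces the standard `1 / theta^(2i/d)` frequency schedule:

```python
import torch

rotary = WanToDanceRotaryEmbedding(dim=64)
q = torch.randn(2, 100, 64)
q_rot = rotary.rotate_queries_or_keys(q)   # positions 0..99, cached under cache_key=100
assert q_rot.shape == q.shape
```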
+
+
+class WanToDanceMusicEncoderLayer(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+        layer_norm_eps: float = 1e-5,
+        batch_first: bool = False,
+        norm_first: bool = True,
+        device=None,
+        dtype=None,
+        rotary=None,
+    ) -> None:
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(
+            d_model, nhead, dropout=dropout, batch_first=batch_first, device=device, dtype=dtype
+        )
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm_first = norm_first
+        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.activation = activation
+
+        self.rotary = rotary
+        self.use_rotary = rotary is not None
+
+    # self-attention block
+    def _sa_block(
+        self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
+    ) -> Tensor:
+        qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
+        x = self.self_attn(
+            qk,
+            qk,
+            x,
+            attn_mask=attn_mask,
+            key_padding_mask=key_padding_mask,
+            need_weights=False,
+        )[0]
+        return self.dropout1(x)
+
+    # feed forward block
+    def _ff_block(self, x: Tensor) -> Tensor:
+        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+        return self.dropout2(x)
+
+    def forward(
+        self,
+        src: Tensor,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        x = src
+        if self.norm_first:
+            self.norm1.to(device=x.device)
+            self.norm2.to(device=x.device)
+            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
+            x = x + self._ff_block(self.norm2(x))
+        else:
+            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
+            x = self.norm2(x + self._ff_block(x))
+        return x
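A sketch of the music-encoder stack that `prepare_wantodance` builds from this layer: two pre-norm transformer layers with rotary positions applied to Q/K only. The dimensions follow the constants used there (`latent_dim=256`, `nhead=4`, `ff_size=1024`); the input length is illustrative:

```python
import torch
import torch.nn.functional as F

rotary = WanToDanceRotaryEmbedding(dim=256)        # rot_dim matches d_model
layer = WanToDanceMusicEncoderLayer(
    d_model=256, nhead=4, dim_feedforward=1024, dropout=0.1,
    activation=F.gelu, batch_first=True, rotary=rotary,
)
music = torch.randn(1, 150, 256)    # e.g. projected music features, [B, T, C]
out = layer(music)                  # [1, 150, 256]
```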
@@ -326,6 +326,7 @@ class RopeEmbedder:
 class ZImageDiT(nn.Module):
     _supports_gradient_checkpointing = True
     _no_split_modules = ["ZImageTransformerBlock"]
+    _repeated_blocks = ["ZImageTransformerBlock"]
 
     def __init__(
         self,
diffsynth/pipelines/ace_step.py (new file, 584 lines)
@@ -0,0 +1,584 @@
+"""
+ACE-Step Pipeline for DiffSynth-Studio.
+
+Text-to-Music generation pipeline using the ACE-Step 1.5 model.
+"""
+import re, torch
+from typing import Optional, Dict, Any, List, Tuple
+from tqdm import tqdm
+import random, math
+import torch.nn.functional as F
+from einops import rearrange
+
+from ..core.device.npu_compatible_device import get_device_type
+from ..diffusion import FlowMatchScheduler
+from ..core import ModelConfig
+from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
+
+from ..models.ace_step_dit import AceStepDiTModel
+from ..models.ace_step_conditioner import AceStepConditionEncoder
+from ..models.ace_step_text_encoder import AceStepTextEncoder
+from ..models.ace_step_vae import AceStepVAE
+from ..models.ace_step_tokenizer import AceStepTokenizer
+
+
+class AceStepPipeline(BasePipeline):
+    """Pipeline for ACE-Step text-to-music generation."""
+
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
+        super().__init__(
+            device=device,
+            torch_dtype=torch_dtype,
+            height_division_factor=1,
+            width_division_factor=1,
+        )
+        self.scheduler = FlowMatchScheduler("ACE-Step")
+        self.text_encoder: AceStepTextEncoder = None
+        self.conditioner: AceStepConditionEncoder = None
+        self.dit: AceStepDiTModel = None
+        self.vae: AceStepVAE = None
+        self.tokenizer_model: AceStepTokenizer = None
+
+        self.in_iteration_models = ("dit",)
+        self.units = [
+            AceStepUnit_TaskTypeChecker(),
+            AceStepUnit_PromptEmbedder(),
+            AceStepUnit_ReferenceAudioEmbedder(),
+            AceStepUnit_ContextLatentBuilder(),
+            AceStepUnit_ConditionEmbedder(),
+            AceStepUnit_NoiseInitializer(),
+            AceStepUnit_InputAudioEmbedder(),
+        ]
+        self.model_fn = model_fn_ace_step
+        self.compilable_models = ["dit"]
+
+        self.sample_rate = 48000
+
+    @staticmethod
+    def from_pretrained(
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: str = get_device_type(),
+        model_configs: list[ModelConfig] = [],
+        text_tokenizer_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+        silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+        vram_limit: float = None,
+    ):
+        pipe = AceStepPipeline(device=device, torch_dtype=torch_dtype)
+        model_pool = pipe.download_and_load_models(model_configs, vram_limit)
+
+        pipe.text_encoder = model_pool.fetch_model("ace_step_text_encoder")
+        pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
+        pipe.dit = model_pool.fetch_model("ace_step_dit")
+        pipe.vae = model_pool.fetch_model("ace_step_vae")
+        pipe.vae.remove_weight_norm()
+        pipe.tokenizer_model = model_pool.fetch_model("ace_step_tokenizer")
+
+        if text_tokenizer_config is not None:
+            text_tokenizer_config.download_if_necessary()
+            from transformers import AutoTokenizer
+            pipe.tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_config.path)
+        if silence_latent_config is not None:
+            silence_latent_config.download_if_necessary()
+            pipe.silence_latent = torch.load(silence_latent_config.path, weights_only=True).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
+
+        # VRAM Management
+        pipe.vram_management_enabled = pipe.check_vram_management_state()
+        return pipe
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        # Prompt
+        prompt: str,
+        cfg_scale: float = 1.0,
+        # Lyrics
+        lyrics: str = "",
+        # Task type
+        task_type: Optional[str] = "text2music",
+        # Reference audio
+        reference_audios: List[torch.Tensor] = None,
+        # Source audio
+        src_audio: torch.Tensor = None,
+        denoising_strength: float = 1.0,  # denoising_strength = 1 - cover_noise_strength
+        audio_cover_strength: float = 1.0,
+        # Audio codes
+        audio_code_string: Optional[str] = None,
+        # Inpainting
+        repainting_ranges: Optional[List[Tuple[float, float]]] = None,
+        repainting_strength: float = 1.0,
+        # Shape
+        duration: int = 60,
+        # Audio Meta
+        bpm: Optional[int] = 100,
+        keyscale: Optional[str] = "B minor",
+        timesignature: Optional[str] = "4",
+        vocal_language: Optional[str] = "unknown",
+        # Randomness
+        seed: int = None,
+        rand_device: str = "cpu",
+        # Steps
+        num_inference_steps: int = 8,
+        # Scheduler-specific parameters
+        shift: float = 1.0,
+        # Progress
+        progress_bar_cmd=tqdm,
+    ):
+        # Scheduler
+        self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift)
+
+        # Parameters
+        inputs_posi = {"prompt": prompt, "positive": True}
+        inputs_nega = {"positive": False}
+        inputs_shared = {
+            "cfg_scale": cfg_scale,
+            "lyrics": lyrics,
+            "task_type": task_type,
+            "reference_audios": reference_audios,
+            "src_audio": src_audio, "audio_cover_strength": audio_cover_strength, "audio_code_string": audio_code_string,
+            "repainting_ranges": repainting_ranges, "repainting_strength": repainting_strength,
+            "duration": duration,
+            "bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
+            "seed": seed,
+            "rand_device": rand_device,
+            "num_inference_steps": num_inference_steps,
+            "shift": shift,
+        }
+
+        for unit in self.units:
+            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
+                unit, self, inputs_shared, inputs_posi, inputs_nega
+            )
+
+        # Denoise
+        self.load_models_to_device(self.in_iteration_models)
+        models = {name: getattr(self, name) for name in self.in_iteration_models}
+        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+            self.switch_noncover_condition(inputs_shared, inputs_posi, inputs_nega, progress_id)
+            noise_pred = self.cfg_guided_model_fn(
+                self.model_fn, cfg_scale,
+                inputs_shared, inputs_posi, inputs_nega,
+                **models, timestep=timestep, progress_id=progress_id,
+            )
+            inputs_shared["latents"] = self.step(
+                self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
+                progress_id=progress_id, noise_pred=noise_pred, **inputs_shared,
+            )
+
+        # Decode
+        self.load_models_to_device(['vae'])
+        # DiT output is [B, T, 64] (channels-last), VAE expects [B, 64, T] (channels-first)
+        latents = inputs_shared["latents"].transpose(1, 2)
+        vae_output = self.vae.decode(latents)
+        audio_output = self.normalize_audio(vae_output, target_db=-1.0)
+        audio = self.output_audio_format_check(audio_output)
+        self.load_models_to_device([])
+        return audio
+
+    def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch.Tensor:
+        peak = torch.max(torch.abs(audio))
+        if peak < 1e-6:
+            return audio
+        target_amp = 10 ** (target_db / 20.0)
+        gain = target_amp / peak
+        return audio * gain
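Worked numbers for `normalize_audio` above: a -1 dBFS target corresponds to a peak amplitude of `10 ** (-1 / 20) ≈ 0.891`, so a waveform peaking at 0.25 is scaled by about 3.565:

```python
import torch

audio = torch.tensor([0.25, -0.1, 0.05])
peak = audio.abs().max()              # 0.25
target_amp = 10 ** (-1.0 / 20.0)      # ≈ 0.8913
gain = target_amp / peak              # ≈ 3.5651
print((audio * gain).abs().max())     # ≈ 0.8913, i.e. -1 dBFS
```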
+    def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id):
+        if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0:
+            return
+        if inputs_shared.get("shared_noncover", None) is None:
+            return
+        cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"])
+        if progress_id >= cover_steps:
+            inputs_shared.update(inputs_shared.pop("shared_noncover", {}))
+            inputs_posi.update(inputs_shared.pop("posi_noncover", {}))
+            if inputs_shared["cfg_scale"] != 1.0:
+                inputs_nega.update(inputs_shared.pop("nega_noncover", {}))
+
+
+class AceStepUnit_TaskTypeChecker(PipelineUnit):
+    """Validate task-specific inputs before the pipeline runs."""
+    def __init__(self):
+        super().__init__(
+            input_params=("task_type", "src_audio", "repainting_ranges", "audio_code_string"),
+            output_params=("task_type",),
+        )
+
+    def process(self, pipe, task_type, src_audio, repainting_ranges, audio_code_string):
+        assert task_type in ["text2music", "cover", "repaint"], f"Unsupported task_type: {task_type}"
+        if task_type == "cover":
+            assert (src_audio is not None) or (audio_code_string is not None), "For cover task, either src_audio or audio_code_string must be provided."
+        elif task_type == "repaint":
+            assert src_audio is not None, "For repaint task, src_audio must be provided."
+            assert repainting_ranges is not None and len(repainting_ranges) > 0, "For repaint task, repainting_ranges must be provided and non-empty."
+        return {}
+
+
+class AceStepUnit_PromptEmbedder(PipelineUnit):
+    SFT_GEN_PROMPT = "# Instruction\n{}\n\n# Caption\n{}\n\n# Metas\n{}<|endoftext|>\n"
+    INSTRUCTION_MAP = {
+        "text2music": "Fill the audio semantic mask based on the given conditions:",
+        "cover": "Generate audio semantic tokens based on the given conditions:",
+        "repaint": "Repaint the mask area based on the given conditions:",
+        "extract": "Extract the {TRACK_NAME} track from the audio:",
+        "extract_default": "Extract the track from the audio:",
+        "lego": "Generate the {TRACK_NAME} track based on the audio context:",
+        "lego_default": "Generate the track based on the audio context:",
+        "complete": "Complete the input track with {TRACK_CLASSES}:",
+        "complete_default": "Complete the input track:",
+    }
+    LYRIC_PROMPT = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|>"
+
+    def __init__(self):
+        super().__init__(
+            seperate_cfg=True,
+            input_params_posi={"prompt": "prompt", "positive": "positive"},
+            input_params_nega={"prompt": "prompt", "positive": "positive"},
+            input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language", "task_type"),
+            output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
+            onload_model_names=("text_encoder",)
+        )
+
+    def _encode_text(self, pipe, text, max_length=256):
+        """Encode text using Qwen3-Embedding -> [B, T, 1024]."""
+        text_inputs = pipe.tokenizer(
+            text,
+            max_length=max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        input_ids = text_inputs.input_ids.to(pipe.device)
+        attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
+        hidden_states = pipe.text_encoder(input_ids, attention_mask)
+        return hidden_states, attention_mask
+
+    def _encode_lyrics(self, pipe, lyric_text, max_length=2048):
+        text_inputs = pipe.tokenizer(
+            lyric_text,
+            max_length=max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        input_ids = text_inputs.input_ids.to(pipe.device)
+        attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
+        hidden_states = pipe.text_encoder.model.embed_tokens(input_ids)
+        return hidden_states, attention_mask
+
+    def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
+        bpm = meta_dict.get("bpm", "N/A")
+        timesignature = meta_dict.get("timesignature", "N/A")
+        keyscale = meta_dict.get("keyscale", "N/A")
+        duration = meta_dict.get("duration", 30)
+        duration = f"{int(duration)} seconds"
+        return (
+            f"- bpm: {bpm}\n"
+            f"- timesignature: {timesignature}\n"
+            f"- keyscale: {keyscale}\n"
+            f"- duration: {duration}\n"
+        )
+
+    def process(self, pipe, prompt, positive, lyrics, duration, bpm, keyscale, timesignature, vocal_language, task_type):
+        if not positive:
+            return {}
+        pipe.load_models_to_device(['text_encoder'])
+        meta_dict = {"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "duration": duration}
+        INSTRUCTION = self.INSTRUCTION_MAP.get(task_type, self.INSTRUCTION_MAP["text2music"])
+        prompt = self.SFT_GEN_PROMPT.format(INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
+        text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt, max_length=256)
+
+        lyric_text = self.LYRIC_PROMPT.format(vocal_language, lyrics)
+        lyric_hidden_states, lyric_attention_mask = self._encode_lyrics(pipe, lyric_text, max_length=2048)
+
+        # TODO: remove this
+        newtext = prompt + "\n\n" + lyric_text
+        return {
+            "text_hidden_states": text_hidden_states,
+            "text_attention_mask": text_attention_mask,
+            "lyric_hidden_states": lyric_hidden_states,
+            "lyric_attention_mask": lyric_attention_mask,
+        }
+
+
+class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("reference_audios",),
+            output_params=("reference_latents", "refer_audio_order_mask"),
+            onload_model_names=("vae",)
+        )
+
+    def process(self, pipe, reference_audios):
+        if reference_audios is not None:
+            pipe.load_models_to_device(['vae'])
+            reference_audios = [
+                self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device)
+                for reference_audio in reference_audios
+            ]
+            reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
+        else:
+            reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
+            reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
+        return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
+
+    def process_reference_audio(self, audio) -> Optional[torch.Tensor]:
+        if audio.ndim == 3 and audio.shape[0] == 1:
+            audio = audio.squeeze(0)
+        target_frames = 30 * 48000
+        segment_frames = 10 * 48000
+        if audio.shape[-1] < target_frames:
+            repeat_times = math.ceil(target_frames / audio.shape[-1])
+            audio = audio.repeat(1, repeat_times)
+        total_frames = audio.shape[-1]
+        segment_size = total_frames // 3
+        front_start = random.randint(0, max(0, segment_size - segment_frames))
+        front_audio = audio[:, front_start:front_start + segment_frames]
+        middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames))
+        middle_audio = audio[:, middle_start:middle_start + segment_frames]
+        back_start = 2 * segment_size + random.randint(0, max(0, (total_frames - 2 * segment_size) - segment_frames))
+        back_audio = audio[:, back_start:back_start + segment_frames]
+        return torch.cat([front_audio, middle_audio, back_audio], dim=-1).unsqueeze(0)
+
+    def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Infer packed reference-audio latents and order mask."""
+        refer_audio_order_mask = []
+        refer_audio_latents = []
+        for batch_idx, refer_audios in enumerate(refer_audioss):
+            if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
+                refer_audio_latent = pipe.silence_latent[:, :750, :]
+                refer_audio_latents.append(refer_audio_latent)
+                refer_audio_order_mask.append(batch_idx)
+            else:
+                for refer_audio in refer_audios:
+                    refer_audio_latent = pipe.vae.encode(refer_audio).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
+                    refer_audio_latents.append(refer_audio_latent)
+                    refer_audio_order_mask.append(batch_idx)
+        refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
+        refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=pipe.device, dtype=torch.long)
+        return refer_audio_latents, refer_audio_order_mask
+
+
+class AceStepUnit_ConditionEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            take_over=True,
+            output_params=("encoder_hidden_states", "encoder_attention_mask"),
+            onload_model_names=("conditioner",),
+        )
+
+    def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
+        pipe.load_models_to_device(['conditioner'])
+        encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
+            text_hidden_states=inputs_posi.get("text_hidden_states", None),
+            text_attention_mask=inputs_posi.get("text_attention_mask", None),
+            lyric_hidden_states=inputs_posi.get("lyric_hidden_states", None),
+            lyric_attention_mask=inputs_posi.get("lyric_attention_mask", None),
+            reference_latents=inputs_shared.get("reference_latents", None),
+            refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
+        )
+        inputs_posi["encoder_hidden_states"] = encoder_hidden_states
+        inputs_posi["encoder_attention_mask"] = encoder_attention_mask
+        if inputs_shared["cfg_scale"] != 1.0:
+            inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(
+                dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device,
+            )
+            inputs_nega["encoder_attention_mask"] = encoder_attention_mask
+        if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0:
+            hidden_states_noncover = AceStepUnit_PromptEmbedder().process(
+                pipe, inputs_posi["prompt"], True, inputs_shared["lyrics"], inputs_shared["duration"],
+                inputs_shared["bpm"], inputs_shared["keyscale"], inputs_shared["timesignature"],
+                inputs_shared["vocal_language"], "text2music")
+            encoder_hidden_states_noncover, encoder_attention_mask_noncover = pipe.conditioner(
+                **hidden_states_noncover,
+                reference_latents=inputs_shared.get("reference_latents", None),
+                refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
+            )
+            duration = inputs_shared["context_latents"].shape[1] * 1920 / pipe.vae.sampling_rate
+            context_latents_noncover = AceStepUnit_ContextLatentBuilder().process(pipe, duration, None, None)["context_latents"]
+            inputs_shared["shared_noncover"] = {"context_latents": context_latents_noncover}
+            inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover}
+            if inputs_shared["cfg_scale"] != 1.0:
+                inputs_shared["nega_noncover"] = {
+                    "encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(
+                        dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device,
+                    ),
+                    "encoder_attention_mask": encoder_attention_mask_noncover,
+                }
+        return inputs_shared, inputs_posi, inputs_nega
+
+
+class AceStepUnit_ContextLatentBuilder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("duration", "src_audio", "audio_code_string", "task_type", "repainting_ranges", "repainting_strength"),
+            output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
+            onload_model_names=("vae", "tokenizer_model",),
+        )
+
+    def _get_silence_latent_slice(self, pipe, length: int) -> torch.Tensor:
+        available = pipe.silence_latent.shape[1]
+        if length <= available:
+            return pipe.silence_latent[0, :length, :]
+        repeats = (length + available - 1) // available
+        tiled = pipe.silence_latent[0].repeat(repeats, 1)
+        return tiled[:length, :]
+
+    def tokenize(self, tokenizer, x, silence_latent, pool_window_size):
+        if x.shape[1] % pool_window_size != 0:
+            pad_len = pool_window_size - (x.shape[1] % pool_window_size)
+            x = torch.cat([x, silence_latent[:1, :pad_len].repeat(x.shape[0], 1, 1)], dim=1)
+        x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=pool_window_size)
+        quantized, indices = tokenizer(x)
+        return quantized
+
+    @staticmethod
+    def _parse_audio_code_string(code_str: str) -> list:
+        """Extract integer audio codes from tokens like <|audio_code_123|>."""
+        if not code_str:
+            return []
+        try:
+            codes = []
+            max_audio_code = 63999
+            for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
+                code_value = int(x)
+                codes.append(max(0, min(code_value, max_audio_code)))
+        except Exception as e:
+            raise ValueError(f"Invalid audio_code_string format: {e}")
+        return codes
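The regex in `_parse_audio_code_string` above, run on a tiny example string; out-of-range codes are clamped to the codebook size:

```python
import re

code_str = "<|audio_code_12|><|audio_code_999|><|audio_code_70000|>"
codes = [max(0, min(int(x), 63999)) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
print(codes)  # [12, 999, 63999]
```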
+    def pad_src_audio(self, pipe, src_audio, task_type, repainting_ranges):
+        if task_type != "repaint" or repainting_ranges is None:
+            return src_audio, repainting_ranges, None, None
+        min_left = min([start for start, end in repainting_ranges])
+        max_right = max([end for start, end in repainting_ranges])
+        total_length = src_audio.shape[-1] // pipe.vae.sampling_rate
+        pad_left = max(0, -min_left)
+        pad_right = max(0, max_right - total_length)
+        if pad_left > 0 or pad_right > 0:
+            padding_frames_left, padding_frames_right = pad_left * pipe.vae.sampling_rate, pad_right * pipe.vae.sampling_rate
+            src_audio = F.pad(src_audio, (padding_frames_left, padding_frames_right), value=0.0)
+            repainting_ranges = [(start + pad_left, end + pad_left) for start, end in repainting_ranges]
+        return src_audio, repainting_ranges, pad_left, pad_right
+
+    def parse_repaint_masks(self, pipe, src_latents, task_type, repainting_ranges, repainting_strength, pad_left, pad_right):
+        if task_type != "repaint" or repainting_ranges is None:
+            return None, src_latents
+        # Set the repainting area to repainting_strength and the non-repainting area to 0.0,
+        # blending at the boundary with cf_frames.
+        max_latent_length = src_latents.shape[1]
+        denoise_mask = torch.zeros((1, max_latent_length, 1), dtype=pipe.torch_dtype, device=pipe.device)
+        for start, end in repainting_ranges:
+            start_frame = start * pipe.vae.sampling_rate // 1920
+            end_frame = end * pipe.vae.sampling_rate // 1920
+            denoise_mask[:, start_frame:end_frame, :] = repainting_strength
+        # Set padding areas to 1.0 (full repaint) to avoid artifacts at the boundaries caused by padding.
+        pad_left_frames = pad_left * pipe.vae.sampling_rate // 1920
+        pad_right_frames = pad_right * pipe.vae.sampling_rate // 1920
+        denoise_mask[:, :pad_left_frames, :] = 1
+        denoise_mask[:, max_latent_length - pad_right_frames:, :] = 1
+
+        silent_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
+        src_latents = src_latents * (1 - denoise_mask) + silent_latents * denoise_mask
+        return denoise_mask, src_latents
+
+    def process(self, pipe, duration, src_audio, audio_code_string, task_type=None, repainting_ranges=None, repainting_strength=None):
+        # Get src_latents from audio_code_string > src_audio > silence.
+        source_latents = None
+        denoise_mask = None
+        if audio_code_string is not None:
+            # Use audio_code_string to get src_latents.
+            pipe.load_models_to_device(self.onload_model_names)
+            code_ids = self._parse_audio_code_string(audio_code_string)
+            quantizer = pipe.tokenizer_model.tokenizer.quantizer.to(device=pipe.device)
+            indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
+            codes = quantizer.get_codes_from_indices(indices)
+            quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
+            quantized = quantizer.project_out(quantized)
+            src_latents = pipe.tokenizer_model.detokenizer(quantized).to(pipe.device)
+            max_latent_length = src_latents.shape[1]
+        elif src_audio is not None:
+            # Use src_audio to get src_latents.
+            pipe.load_models_to_device(self.onload_model_names)
+            src_audio = src_audio.unsqueeze(0) if src_audio.dim() == 2 else src_audio
+            src_audio = torch.clamp(src_audio, -1.0, 1.0)
+
+            src_audio, repainting_ranges, pad_left, pad_right = self.pad_src_audio(pipe, src_audio, task_type, repainting_ranges)
+
+            src_latents = pipe.vae.encode(src_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
+            source_latents = src_latents  # cache for potential use in audio inpainting tasks
+            denoise_mask, src_latents = self.parse_repaint_masks(pipe, src_latents, task_type, repainting_ranges, repainting_strength, pad_left, pad_right)
+            if task_type == "cover":
+                lm_hints_5Hz = self.tokenize(pipe.tokenizer_model.tokenizer, src_latents, pipe.silence_latent, pipe.tokenizer_model.tokenizer.pool_window_size)
+                src_latents = pipe.tokenizer_model.detokenizer(lm_hints_5Hz)
+            max_latent_length = src_latents.shape[1]
+        else:
+            # Use silence latents.
+            max_latent_length = int(duration * pipe.sample_rate // 1920)
+            src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
+        chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
+        attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
+        context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
+        return {"context_latents": context_latents, "attention_mask": attention_mask, "src_latents": source_latents, "denoise_mask": denoise_mask}
+
+
+class AceStepUnit_NoiseInitializer(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("context_latents", "seed", "rand_device", "src_latents"),
+            output_params=("noise",),
+        )
+
+    def process(self, pipe, context_latents, seed, rand_device, src_latents):
+        src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
+        noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
+        if src_latents is not None:
+            noise = pipe.scheduler.add_noise(src_latents, noise, timestep=pipe.scheduler.timesteps[0])
+        return {"noise": noise}
+
+
+class AceStepUnit_InputAudioEmbedder(PipelineUnit):
+    """Only used for training."""
+    def __init__(self):
+        super().__init__(
+            input_params=("noise", "input_audio"),
+            output_params=("latents", "input_latents"),
+            onload_model_names=("vae",),
+        )
+
+    def process(self, pipe, noise, input_audio):
+        if input_audio is None:
+            return {"latents": noise}
+        if pipe.scheduler.training:
+            pipe.load_models_to_device(self.onload_model_names)
+            input_audio, sample_rate = input_audio
+            input_audio = torch.clamp(input_audio, -1.0, 1.0)
+            if input_audio.dim() == 2:
+                input_audio = input_audio.unsqueeze(0)
+            input_latents = pipe.vae.encode(input_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
+            # Prevent a potential size mismatch between context_latents and input_latents by
+            # cropping input_latents to the same temporal length as the noise.
+            input_latents = input_latents[:, :noise.shape[1]]
+            return {"input_latents": input_latents}
+
+
+def model_fn_ace_step(
+    dit: AceStepDiTModel,
+    latents=None,
+    timestep=None,
+    encoder_hidden_states=None,
+    encoder_attention_mask=None,
+    context_latents=None,
+    attention_mask=None,
+    use_gradient_checkpointing=False,
+    use_gradient_checkpointing_offload=False,
+    **kwargs,
+):
+    decoder_outputs = dit(
+        hidden_states=latents,
+        timestep=timestep,
+        timestep_r=timestep,
+        attention_mask=attention_mask,
+        encoder_hidden_states=encoder_hidden_states,
+        encoder_attention_mask=encoder_attention_mask,
+        context_latents=context_latents,
+        use_gradient_checkpointing=use_gradient_checkpointing,
+        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+    )[0]
+    return decoder_outputs
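A minimal end-to-end sketch for the pipeline above. The `model_configs` entry and the output tensor layout are assumptions about how the ACE-Step 1.5 weights and return values are laid out; adjust to the actual repository:

```python
import torch, torchaudio
from diffsynth.core import ModelConfig
from diffsynth.pipelines.ace_step import AceStepPipeline

pipe = AceStepPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    model_configs=[
        ModelConfig(model_id="ACE-Step/Ace-Step1.5"),  # hypothetical single-repo config
    ],
)
audio = pipe(
    prompt="dreamy synth-pop, female vocal",
    lyrics="[verse] city lights are calling...",
    duration=60,
    num_inference_steps=8,
)
# Assuming a [B, channels, samples] tensor comes back at pipe.sample_rate (48 kHz).
torchaudio.save("song.wav", audio[0].cpu().float(), pipe.sample_rate)
```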
diffsynth/pipelines/anima_image.py (new file, 264 lines)
@@ -0,0 +1,264 @@
+import torch, math
+from PIL import Image
+from typing import Union
+from tqdm import tqdm
+from einops import rearrange
+import numpy as np
+from math import prod
+from transformers import AutoTokenizer
+
+from ..core.device.npu_compatible_device import get_device_type
+from ..diffusion import FlowMatchScheduler
+from ..core import ModelConfig, gradient_checkpoint_forward
+from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
+from ..utils.lora.merge import merge_lora
+
+from ..models.anima_dit import AnimaDiT
+from ..models.z_image_text_encoder import ZImageTextEncoder
+from ..models.wan_video_vae import WanVideoVAE
+
+
+class AnimaImagePipeline(BasePipeline):
+
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
+        super().__init__(
+            device=device, torch_dtype=torch_dtype,
+            height_division_factor=16, width_division_factor=16,
+        )
+        self.scheduler = FlowMatchScheduler("Z-Image")
+        self.text_encoder: ZImageTextEncoder = None
+        self.dit: AnimaDiT = None
+        self.vae: WanVideoVAE = None
+        self.tokenizer: AutoTokenizer = None
+        self.tokenizer_t5xxl: AutoTokenizer = None
+        self.in_iteration_models = ("dit",)
+        self.units = [
+            AnimaUnit_ShapeChecker(),
+            AnimaUnit_NoiseInitializer(),
+            AnimaUnit_InputImageEmbedder(),
+            AnimaUnit_PromptEmbedder(),
+        ]
+        self.model_fn = model_fn_anima
+        self.compilable_models = ["dit"]
+
+    @staticmethod
+    def from_pretrained(
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: Union[str, torch.device] = get_device_type(),
+        model_configs: list[ModelConfig] = [],
+        tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
+        tokenizer_t5xxl_config: ModelConfig = ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
+        vram_limit: float = None,
+    ):
+        # Initialize pipeline
+        pipe = AnimaImagePipeline(device=device, torch_dtype=torch_dtype)
+        model_pool = pipe.download_and_load_models(model_configs, vram_limit)
+
+        # Fetch models
+        pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
+        pipe.dit = model_pool.fetch_model("anima_dit")
+        pipe.vae = model_pool.fetch_model("wan_video_vae")
+        if tokenizer_config is not None:
+            tokenizer_config.download_if_necessary()
+            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
+        if tokenizer_t5xxl_config is not None:
+            tokenizer_t5xxl_config.download_if_necessary()
+            pipe.tokenizer_t5xxl = AutoTokenizer.from_pretrained(tokenizer_t5xxl_config.path)
+        # VRAM Management
+        pipe.vram_management_enabled = pipe.check_vram_management_state()
+        return pipe
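A minimal usage sketch for `AnimaImagePipeline`, mirroring other DiffSynth pipelines. The `model_configs` entries are placeholders to be filled with the actual Anima DiT, Z-Image text encoder, and Wan VAE checkpoints:

```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.anima_image import AnimaImagePipeline

pipe = AnimaImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    model_configs=[
        # ModelConfig(model_id="...", origin_file_pattern="..."),  # anima_dit
        # ModelConfig(model_id="...", origin_file_pattern="..."),  # z_image_text_encoder
        # ModelConfig(model_id="...", origin_file_pattern="..."),  # wan_video_vae
    ],
)
image = pipe(prompt="an anime girl under cherry blossoms", height=1024, width=1024, seed=0)
image.save("anima.png")
```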
+    @torch.no_grad()
+    def __call__(
+        self,
+        # Prompt
+        prompt: str,
+        negative_prompt: str = "",
+        cfg_scale: float = 4.0,
+        # Image
+        input_image: Image.Image = None,
+        denoising_strength: float = 1.0,
+        # Shape
+        height: int = 1024,
+        width: int = 1024,
+        # Randomness
+        seed: int = None,
+        rand_device: str = "cpu",
+        # Steps
+        num_inference_steps: int = 30,
+        sigma_shift: float = None,
+        # Progress bar
+        progress_bar_cmd = tqdm,
+    ):
+        # Scheduler
+        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
+
+        # Parameters
+        inputs_posi = {
+            "prompt": prompt,
+        }
+        inputs_nega = {
+            "negative_prompt": negative_prompt,
+        }
+        inputs_shared = {
+            "cfg_scale": cfg_scale,
+            "input_image": input_image, "denoising_strength": denoising_strength,
+            "height": height, "width": width,
+            "seed": seed, "rand_device": rand_device,
+            "num_inference_steps": num_inference_steps,
+        }
+        for unit in self.units:
+            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
+
+        # Denoise
+        self.load_models_to_device(self.in_iteration_models)
+        models = {name: getattr(self, name) for name in self.in_iteration_models}
+        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+            noise_pred = self.cfg_guided_model_fn(
+                self.model_fn, cfg_scale,
+                inputs_shared, inputs_posi, inputs_nega,
+                **models, timestep=timestep, progress_id=progress_id
+            )
+            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
+
+        # Decode
+        self.load_models_to_device(['vae'])
+        image = self.vae.decode(inputs_shared["latents"].unsqueeze(2), device=self.device).squeeze(2)
+        image = self.vae_output_to_image(image)
+        self.load_models_to_device([])
+
+        return image
+
+
+class AnimaUnit_ShapeChecker(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("height", "width"),
+            output_params=("height", "width"),
+        )
+
+    def process(self, pipe: AnimaImagePipeline, height, width):
+        height, width = pipe.check_resize_height_width(height, width)
+        return {"height": height, "width": width}
+
+
+class AnimaUnit_NoiseInitializer(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("height", "width", "seed", "rand_device"),
+            output_params=("noise",),
+        )
+
+    def process(self, pipe: AnimaImagePipeline, height, width, seed, rand_device):
+        noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
+        return {"noise": noise}
+
+
+class AnimaUnit_InputImageEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("input_image", "noise"),
+            output_params=("latents", "input_latents"),
+            onload_model_names=("vae",)
+        )
+
+    def process(self, pipe: AnimaImagePipeline, input_image, noise):
+        if input_image is None:
+            return {"latents": noise, "input_latents": None}
+        pipe.load_models_to_device(['vae'])
+        if isinstance(input_image, list):
+            input_latents = []
+            for image in input_image:
+                image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
+                input_latents.append(pipe.vae.encode(image))
+            input_latents = torch.concat(input_latents, dim=0)
+        else:
+            image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
+            input_latents = pipe.vae.encode(image.unsqueeze(2), device=pipe.device).squeeze(2)
+        if pipe.scheduler.training:
+            return {"latents": noise, "input_latents": input_latents}
+        else:
+            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
+            return {"latents": latents, "input_latents": input_latents}
+
+
+class AnimaUnit_PromptEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            seperate_cfg=True,
+            input_params_posi={"prompt": "prompt"},
+            input_params_nega={"prompt": "negative_prompt"},
+            output_params=("prompt_emb",),
+            onload_model_names=("text_encoder",)
+        )
+
+    def encode_prompt(
+        self,
+        pipe: AnimaImagePipeline,
+        prompt,
+        device = None,
||||||
|
max_sequence_length: int = 512,
|
||||||
|
):
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
prompt = [prompt]
|
||||||
|
|
||||||
|
text_inputs = pipe.tokenizer(
|
||||||
|
prompt,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=max_sequence_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
text_input_ids = text_inputs.input_ids.to(device)
|
||||||
|
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||||
|
|
||||||
|
prompt_embeds = pipe.text_encoder(
|
||||||
|
input_ids=text_input_ids,
|
||||||
|
attention_mask=prompt_masks,
|
||||||
|
output_hidden_states=True,
|
||||||
|
).hidden_states[-1]
|
||||||
|
|
||||||
|
t5xxl_text_inputs = pipe.tokenizer_t5xxl(
|
||||||
|
prompt,
|
||||||
|
max_length=max_sequence_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
t5xxl_ids = t5xxl_text_inputs.input_ids.to(device)
|
||||||
|
|
||||||
|
return prompt_embeds.to(pipe.torch_dtype), t5xxl_ids
|
||||||
|
|
||||||
|
def process(self, pipe: AnimaImagePipeline, prompt):
|
||||||
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
|
prompt_embeds, t5xxl_ids = self.encode_prompt(pipe, prompt, pipe.device)
|
||||||
|
return {"prompt_emb": prompt_embeds, "t5xxl_ids": t5xxl_ids}
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn_anima(
|
||||||
|
dit: AnimaDiT = None,
|
||||||
|
latents=None,
|
||||||
|
timestep=None,
|
||||||
|
prompt_emb=None,
|
||||||
|
t5xxl_ids=None,
|
||||||
|
use_gradient_checkpointing=False,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
latents = latents.unsqueeze(2)
|
||||||
|
timestep = timestep / 1000
|
||||||
|
model_output = dit(
|
||||||
|
x=latents,
|
||||||
|
timesteps=timestep,
|
||||||
|
context=prompt_emb,
|
||||||
|
t5xxl_ids=t5xxl_ids,
|
||||||
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
|
)
|
||||||
|
model_output = model_output.squeeze(2)
|
||||||
|
return model_output
|
||||||
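
For orientation, here is a minimal usage sketch for the new AnimaImagePipeline. Only from_pretrained and __call__ mirror the code above; the module path and the ModelConfig model ids are illustrative assumptions, not taken from this diff.

# Hypothetical usage sketch; the model ids below are placeholders.
import torch
from diffsynth.core import ModelConfig  # assumed import path, mirroring the relative imports in these files
from diffsynth.pipelines.anima_image import AnimaImagePipeline  # assumed module path

pipe = AnimaImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="<anima_dit_checkpoint>"),             # placeholder
        ModelConfig(model_id="<z_image_text_encoder_checkpoint>"),  # placeholder
        ModelConfig(model_id="<wan_video_vae_checkpoint>"),         # placeholder
    ],
)
image = pipe(prompt="a watercolor fox in a snowy forest", seed=0, height=1024, width=1024)
image.save("anima.png")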
diffsynth/pipelines/ernie_image.py (new file, 266 lines added)
@@ -0,0 +1,266 @@
"""
ERNIE-Image Text-to-Image Pipeline for DiffSynth-Studio.

Architecture: SharedAdaLN DiT + RoPE 3D + Joint Image-Text Attention.
"""

import torch
from typing import Union, Optional
from tqdm import tqdm
from transformers import AutoTokenizer

from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.ernie_image_text_encoder import ErnieImageTextEncoder
from ..models.ernie_image_dit import ErnieImageDiT
from ..models.flux2_vae import Flux2VAE


class ErnieImagePipeline(BasePipeline):

    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
        )
        self.scheduler = FlowMatchScheduler("ERNIE-Image")
        self.text_encoder: ErnieImageTextEncoder = None
        self.dit: ErnieImageDiT = None
        self.vae: Flux2VAE = None
        self.tokenizer: AutoTokenizer = None

        self.in_iteration_models = ("dit",)
        self.units = [
            ErnieImageUnit_ShapeChecker(),
            ErnieImageUnit_PromptEmbedder(),
            ErnieImageUnit_NoiseInitializer(),
            ErnieImageUnit_InputImageEmbedder(),
        ]
        self.model_fn = model_fn_ernie_image
        self.compilable_models = ["dit"]

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
        vram_limit: float = None,
    ):
        pipe = ErnieImagePipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        pipe.text_encoder = model_pool.fetch_model("ernie_image_text_encoder")
        pipe.dit = model_pool.fetch_model("ernie_image_dit")
        pipe.vae = model_pool.fetch_model("flux2_vae")

        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)

        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 4.0,
        # Shape
        height: int = 1024,
        width: int = 1024,
        # Randomness
        seed: int = None,
        rand_device: str = "cuda",
        # Steps
        num_inference_steps: int = 50,
        sigma_shift: float = 3.0,
        # Progress bar
        progress_bar_cmd=tqdm,
    ):
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, shift=sigma_shift)

        # Parameters
        inputs_posi = {"prompt": prompt}
        inputs_nega = {"negative_prompt": negative_prompt}
        inputs_shared = {
            "height": height, "width": width, "seed": seed,
            "cfg_scale": cfg_scale, "num_inference_steps": num_inference_steps,
            "rand_device": rand_device,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)

        # Decode
        self.load_models_to_device(['vae'])
        latents = inputs_shared["latents"]
        image = self.vae.decode(latents)
        image = self.vae_output_to_image(image)
        self.load_models_to_device([])
        return image


class ErnieImageUnit_ShapeChecker(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width"),
            output_params=("height", "width"),
        )

    def process(self, pipe: ErnieImagePipeline, height, width):
        height, width = pipe.check_resize_height_width(height, width)
        return {"height": height, "width": width}


class ErnieImageUnit_PromptEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("prompt_embeds", "prompt_embeds_mask"),
            onload_model_names=("text_encoder",)
        )

    def encode_prompt(self, pipe: ErnieImagePipeline, prompt):
        if isinstance(prompt, str):
            prompt = [prompt]

        text_hiddens = []
        for p in prompt:
            ids = pipe.tokenizer(
                p,
                add_special_tokens=True,
                truncation=True,
                padding=False,
            )["input_ids"]

            if len(ids) == 0:
                if pipe.tokenizer.bos_token_id is not None:
                    ids = [pipe.tokenizer.bos_token_id]
                else:
                    ids = [0]

            input_ids = torch.tensor([ids], device=pipe.device)
            outputs = pipe.text_encoder(
                input_ids=input_ids,
            )
            # The text encoder returns a tuple of (hidden_states_tuple,), with every layer's hidden state included
            all_hidden_states = outputs[0]
            hidden = all_hidden_states[-2][0]  # [T, H] - second-to-last layer
            text_hiddens.append(hidden)

        # Pad to uniform length
        if len(text_hiddens) == 0:
            text_in_dim = pipe.text_encoder.config.hidden_size if hasattr(pipe.text_encoder, 'config') else 3072
            return {
                "prompt_embeds": torch.zeros((0, 0, text_in_dim), device=pipe.device, dtype=pipe.torch_dtype),
                "prompt_embeds_mask": torch.zeros((0,), device=pipe.device, dtype=torch.long),
            }

        normalized = [th.to(pipe.device).to(pipe.torch_dtype) for th in text_hiddens]
        text_lens = torch.tensor([t.shape[0] for t in normalized], device=pipe.device, dtype=torch.long)
        Tmax = int(text_lens.max().item())
        text_in_dim = normalized[0].shape[1]
        text_bth = torch.zeros((len(normalized), Tmax, text_in_dim), device=pipe.device, dtype=pipe.torch_dtype)
        for i, t in enumerate(normalized):
            text_bth[i, :t.shape[0], :] = t

        return {"prompt_embeds": text_bth, "prompt_embeds_mask": text_lens}

    def process(self, pipe: ErnieImagePipeline, prompt):
        pipe.load_models_to_device(self.onload_model_names)
        if pipe.text_encoder is not None:
            return self.encode_prompt(pipe, prompt)
        return {}


class ErnieImageUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device"),
            output_params=("noise",),
        )

    def process(self, pipe: ErnieImagePipeline, height, width, seed, rand_device):
        latent_h = height // pipe.height_division_factor
        latent_w = width // pipe.width_division_factor
        latent_channels = pipe.dit.in_channels

        # Use the pipeline device if rand_device is not specified
        if rand_device is None:
            rand_device = str(pipe.device)

        noise = pipe.generate_noise(
            (1, latent_channels, latent_h, latent_w),
            seed=seed,
            rand_device=rand_device,
            rand_torch_dtype=pipe.torch_dtype,
        )
        return {"noise": noise}


class ErnieImageUnit_InputImageEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: ErnieImagePipeline, input_image, noise):
        if input_image is None:
            # T2I path: use noise directly as initial latents
            return {"latents": noise, "input_latents": None}

        # I2I path: VAE-encode the input image
        pipe.load_models_to_device(['vae'])
        image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
        input_latents = pipe.vae.encode(image)

        if pipe.scheduler.training:
            return {"latents": noise, "input_latents": input_latents}
        else:
            # In inference mode, add noise to the encoded latents
            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
            # "input_latents" added to the return value so that every declared output_param is produced
            return {"latents": latents, "input_latents": input_latents}


def model_fn_ernie_image(
    dit: ErnieImageDiT,
    latents=None,
    timestep=None,
    prompt_embeds=None,
    prompt_embeds_mask=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    output = dit(
        hidden_states=latents,
        timestep=timestep,
        text_bth=prompt_embeds,
        text_lens=prompt_embeds_mask,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )
    return output
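
A minimal usage sketch for the new ErnieImagePipeline follows. The module path and ModelConfig ids are illustrative assumptions; only from_pretrained and __call__ come from the file above. Note that rand_device defaults to "cuda" here, unlike the CPU default in the other pipelines.

# Hypothetical usage sketch; the model ids below are placeholders.
import torch
from diffsynth.core import ModelConfig  # assumed import path
from diffsynth.pipelines.ernie_image import ErnieImagePipeline

pipe = ErnieImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="<ernie_image_dit_checkpoint>"),           # placeholder
        ModelConfig(model_id="<ernie_image_text_encoder_checkpoint>"),  # placeholder
        ModelConfig(model_id="<flux2_vae_checkpoint>"),                 # placeholder
    ],
)
image = pipe(prompt="a lighthouse at dawn, oil painting", seed=7, num_inference_steps=50)
image.save("ernie_image.png")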
@@ -42,6 +42,7 @@ class Flux2ImagePipeline(BasePipeline):
             Flux2Unit_ImageIDs(),
         ]
         self.model_fn = model_fn_flux2
+        self.compilable_models = ["dit"]
 
 
     @staticmethod
@@ -90,6 +91,7 @@ class Flux2ImagePipeline(BasePipeline):
         # Randomness
         seed: int = None,
         rand_device: str = "cpu",
+        initial_noise: torch.Tensor = None,
         # Steps
         num_inference_steps: int = 30,
         # Progress bar
@@ -109,7 +111,7 @@ class Flux2ImagePipeline(BasePipeline):
             "input_image": input_image, "denoising_strength": denoising_strength,
             "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
             "height": height, "width": width,
-            "seed": seed, "rand_device": rand_device,
+            "seed": seed, "rand_device": rand_device, "initial_noise": initial_noise,
             "num_inference_steps": num_inference_steps,
         }
         for unit in self.units:
@@ -429,11 +431,14 @@ class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
 class Flux2Unit_NoiseInitializer(PipelineUnit):
     def __init__(self):
         super().__init__(
-            input_params=("height", "width", "seed", "rand_device"),
+            input_params=("height", "width", "seed", "rand_device", "initial_noise"),
             output_params=("noise",),
         )
 
-    def process(self, pipe: Flux2ImagePipeline, height, width, seed, rand_device):
-        noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
+    def process(self, pipe: Flux2ImagePipeline, height, width, seed, rand_device, initial_noise):
+        if initial_noise is not None:
+            noise = initial_noise.clone()
+        else:
+            noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
         noise = noise.reshape(1, 128, height//16 * width//16).permute(0, 2, 1)
         return {"noise": noise}
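
The change above lets callers hand Flux2ImagePipeline a precomputed starting-noise tensor instead of deriving one from a seed. A short sketch, assuming `pipe` is an already-loaded Flux2ImagePipeline (only the `initial_noise` argument and the noise shape come from this diff):

# Sketch: reusing one noise tensor across two prompts via the new `initial_noise` argument.
import torch

height, width = 1024, 1024
noise = torch.randn(
    (1, 128, height // 16, width // 16),  # the shape Flux2Unit_NoiseInitializer would generate internally
    generator=torch.Generator("cpu").manual_seed(42),
    dtype=pipe.torch_dtype,
)
image_a = pipe(prompt="a red bicycle", height=height, width=width, initial_noise=noise)
image_b = pipe(prompt="a blue bicycle", height=height, width=width, initial_noise=noise)

Because the unit clones the tensor before reshaping it, the same noise object can be passed to multiple calls unchanged.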
@@ -103,6 +103,7 @@ class FluxImagePipeline(BasePipeline):
             FluxImageUnit_LoRAEncode(),
         ]
         self.model_fn = model_fn_flux_image
+        self.compilable_models = ["dit"]
         self.lora_loader = FluxLoRALoader
 
     def enable_lora_merger(self):
diffsynth/pipelines/joyai_image.py (new file, 282 lines added)
@@ -0,0 +1,282 @@
import torch
from PIL import Image
from typing import Union, Optional
from tqdm import tqdm
from einops import rearrange

from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.joyai_image_dit import JoyAIImageDiT
from ..models.joyai_image_text_encoder import JoyAIImageTextEncoder
from ..models.wan_video_vae import WanVideoVAE


class JoyAIImagePipeline(BasePipeline):

    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
        )
        self.scheduler = FlowMatchScheduler("Wan")
        self.text_encoder: JoyAIImageTextEncoder = None
        self.dit: JoyAIImageDiT = None
        self.vae: WanVideoVAE = None
        self.processor = None
        self.in_iteration_models = ("dit",)

        self.units = [
            JoyAIImageUnit_ShapeChecker(),
            JoyAIImageUnit_EditImageEmbedder(),
            JoyAIImageUnit_PromptEmbedder(),
            JoyAIImageUnit_NoiseInitializer(),
            JoyAIImageUnit_InputImageEmbedder(),
        ]
        self.model_fn = model_fn_joyai_image
        self.compilable_models = ["dit"]

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        # Processor
        processor_config: ModelConfig = None,
        # Optional
        vram_limit: float = None,
    ):
        pipe = JoyAIImagePipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        pipe.text_encoder = model_pool.fetch_model("joyai_image_text_encoder")
        pipe.dit = model_pool.fetch_model("joyai_image_dit")
        pipe.vae = model_pool.fetch_model("wan_video_vae")

        if processor_config is not None:
            processor_config.download_if_necessary()
            from transformers import AutoProcessor
            pipe.processor = AutoProcessor.from_pretrained(processor_config.path)

        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 5.0,
        # Image
        edit_image: Image.Image = None,
        denoising_strength: float = 1.0,
        # Shape
        height: int = 1024,
        width: int = 1024,
        # Randomness
        seed: int = None,
        # Steps
        max_sequence_length: int = 4096,
        num_inference_steps: int = 30,
        # Tiling
        tiled: Optional[bool] = False,
        tile_size: Optional[tuple[int, int]] = (30, 52),
        tile_stride: Optional[tuple[int, int]] = (15, 26),
        # Scheduler
        shift: Optional[float] = 4.0,
        # Progress bar
        progress_bar_cmd=tqdm,
    ):
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=shift)

        # Parameters
        inputs_posi = {"prompt": prompt}
        inputs_nega = {"negative_prompt": negative_prompt}
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "edit_image": edit_image,
            "denoising_strength": denoising_strength,
            "height": height, "width": width,
            "seed": seed, "max_sequence_length": max_sequence_length,
            "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
        }

        # Unit chain
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
                unit, self, inputs_shared, inputs_posi, inputs_nega
            )

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)

        # Decode
        self.load_models_to_device(['vae'])
        latents = rearrange(inputs_shared["latents"], "b n c f h w -> (b n) c f h w")
        image = self.vae.decode(latents, device=self.device)[0]
        image = self.vae_output_to_image(image, pattern="C 1 H W")
        self.load_models_to_device([])
        return image


class JoyAIImageUnit_ShapeChecker(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width"),
            output_params=("height", "width"),
        )

    def process(self, pipe: "JoyAIImagePipeline", height, width):
        height, width = pipe.check_resize_height_width(height, width)
        return {"height": height, "width": width}


class JoyAIImageUnit_PromptEmbedder(PipelineUnit):
    prompt_template_encode = {
        'image':
            "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
        'multiple_images':
            "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n{}<|im_start|>assistant\n",
        'video':
            "<|im_start|>system\n \\nDescribe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
    }
    prompt_template_encode_start_idx = {'image': 34, 'multiple_images': 34, 'video': 91}

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt", "positive": "positive"},
            input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
            input_params=("edit_image", "max_sequence_length"),
            output_params=("prompt_embeds", "prompt_embeds_mask"),
            onload_model_names=("joyai_image_text_encoder",),
        )

    def process(self, pipe: "JoyAIImagePipeline", prompt, positive, edit_image, max_sequence_length):
        pipe.load_models_to_device(self.onload_model_names)

        has_image = edit_image is not None

        if has_image:
            prompt_embeds, prompt_embeds_mask = self._encode_with_image(pipe, prompt, edit_image, max_sequence_length)
        else:
            prompt_embeds, prompt_embeds_mask = self._encode_text_only(pipe, prompt, max_sequence_length)

        return {"prompt_embeds": prompt_embeds, "prompt_embeds_mask": prompt_embeds_mask}

    def _encode_with_image(self, pipe, prompt, edit_image, max_sequence_length):
        template = self.prompt_template_encode['multiple_images']
        drop_idx = self.prompt_template_encode_start_idx['multiple_images']

        image_tokens = '<image>\n'
        prompt = f"<|im_start|>user\n{image_tokens}{prompt}<|im_end|>\n"
        prompt = prompt.replace('<image>\n', '<|vision_start|><|image_pad|><|vision_end|>')
        prompt = template.format(prompt)
        inputs = pipe.processor(text=[prompt], images=[edit_image], padding=True, return_tensors="pt").to(pipe.device)
        last_hidden_states = pipe.text_encoder(**inputs)

        prompt_embeds = last_hidden_states[:, drop_idx:]
        prompt_embeds_mask = inputs['attention_mask'][:, drop_idx:]

        if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length:
            prompt_embeds = prompt_embeds[:, -max_sequence_length:, :]
            prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:]

        return prompt_embeds, prompt_embeds_mask

    def _encode_text_only(self, pipe, prompt, max_sequence_length):
        # TODO: text-only encoding may be supported in the future.
        # (An unreachable return statement referencing undefined names was removed here.)
        raise NotImplementedError("Text-only encoding is not implemented yet. Please provide edit_image for now.")


class JoyAIImageUnit_EditImageEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("edit_image", "tiled", "tile_size", "tile_stride", "height", "width"),
            # output_params matched to what process() actually returns
            # (originally declared "num_items" and "is_multi_item", which were never produced)
            output_params=("ref_latents", "edit_image"),
            onload_model_names=("wan_video_vae",),
        )

    def process(self, pipe: "JoyAIImagePipeline", edit_image, tiled, tile_size, tile_stride, height, width):
        if edit_image is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        # Resize the edit image to the target dimensions (from ShapeChecker) so ref_latents matches latents
        edit_image = edit_image.resize((width, height), Image.LANCZOS)
        images = [pipe.preprocess_image(edit_image).transpose(0, 1)]
        latents = pipe.vae.encode(images, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        ref_vae = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=1).to(device=pipe.device, dtype=pipe.torch_dtype)

        return {"ref_latents": ref_vae, "edit_image": edit_image}


class JoyAIImageUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("seed", "height", "width", "rand_device"),
            output_params=("noise",),  # fixed: was ("noise"), a plain string rather than a one-element tuple
        )

    def process(self, pipe: "JoyAIImagePipeline", seed, height, width, rand_device):
        # __call__ does not expose rand_device, so fall back to CPU when it is absent
        if rand_device is None:
            rand_device = "cpu"
        latent_h = height // pipe.vae.upsampling_factor
        latent_w = width // pipe.vae.upsampling_factor
        shape = (1, 1, pipe.vae.z_dim, 1, latent_h, latent_w)
        noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
        return {"noise": noise}


class JoyAIImageUnit_InputImageEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae",),
        )

    def process(self, pipe: JoyAIImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
        if input_image is None:
            return {"latents": noise}
        pipe.load_models_to_device(self.onload_model_names)
        if isinstance(input_image, Image.Image):
            input_image = [input_image]
        input_image = [pipe.preprocess_image(img).transpose(0, 1) for img in input_image]
        latents = pipe.vae.encode(input_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        input_latents = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=(len(input_image)))
        return {"latents": noise, "input_latents": input_latents}


def model_fn_joyai_image(
    dit,
    latents,
    timestep,
    prompt_embeds,
    prompt_embeds_mask,
    ref_latents=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    img = torch.cat([ref_latents, latents], dim=1) if ref_latents is not None else latents

    img = dit(
        hidden_states=img,
        timestep=timestep,
        encoder_hidden_states=prompt_embeds,
        encoder_hidden_states_mask=prompt_embeds_mask,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )

    img = img[:, -latents.size(1):]
    return img
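
A minimal usage sketch for JoyAIImagePipeline follows. The module path, model ids, and processor id are illustrative assumptions; note that an edit image is currently required, since the text-only prompt path raises NotImplementedError (see _encode_text_only above).

# Hypothetical usage sketch; all ids below are placeholders.
import torch
from PIL import Image
from diffsynth.core import ModelConfig  # assumed import path
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline

pipe = JoyAIImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="<joyai_image_dit_checkpoint>"),           # placeholder
        ModelConfig(model_id="<joyai_image_text_encoder_checkpoint>"),  # placeholder
        ModelConfig(model_id="<wan_video_vae_checkpoint>"),             # placeholder
    ],
    processor_config=ModelConfig(model_id="<vision_language_processor>"),  # placeholder
)
source = Image.open("product_photo.jpg").convert("RGB")
image = pipe(prompt="replace the background with a sunlit beach", edit_image=source, seed=0)
image.save("joyai_edit.png")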
@@ -12,16 +12,17 @@ from transformers import AutoImageProcessor, Gemma3Processor
|
|||||||
|
|
||||||
from ..core.device.npu_compatible_device import get_device_type
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
from ..diffusion import FlowMatchScheduler
|
from ..diffusion import FlowMatchScheduler
|
||||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
from ..core import ModelConfig
|
||||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||||
|
|
||||||
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
|
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
|
||||||
from ..models.ltx2_dit import LTXModel
|
from ..models.ltx2_dit import LTXModel
|
||||||
from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier
|
from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier
|
||||||
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier
|
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier, AudioProcessor
|
||||||
from ..models.ltx2_upsampler import LTX2LatentUpsampler
|
from ..models.ltx2_upsampler import LTX2LatentUpsampler
|
||||||
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
|
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
|
||||||
from ..utils.data.media_io_ltx2 import ltx2_preprocess
|
from ..utils.data.media_io_ltx2 import ltx2_preprocess
|
||||||
|
from ..utils.data.audio import convert_to_stereo
|
||||||
|
|
||||||
|
|
||||||
class LTX2AudioVideoPipeline(BasePipeline):
|
class LTX2AudioVideoPipeline(BasePipeline):
|
||||||
@@ -50,6 +51,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|||||||
|
|
||||||
self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)
|
self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)
|
||||||
self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)
|
self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)
|
||||||
|
self.audio_processor: AudioProcessor = AudioProcessor()
|
||||||
|
|
||||||
self.in_iteration_models = ("dit",)
|
self.in_iteration_models = ("dit",)
|
||||||
self.units = [
|
self.units = [
|
||||||
@@ -57,10 +59,53 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|||||||
LTX2AudioVideoUnit_ShapeChecker(),
|
LTX2AudioVideoUnit_ShapeChecker(),
|
||||||
LTX2AudioVideoUnit_PromptEmbedder(),
|
LTX2AudioVideoUnit_PromptEmbedder(),
|
||||||
LTX2AudioVideoUnit_NoiseInitializer(),
|
LTX2AudioVideoUnit_NoiseInitializer(),
|
||||||
|
LTX2AudioVideoUnit_VideoRetakeEmbedder(),
|
||||||
|
LTX2AudioVideoUnit_AudioRetakeEmbedder(),
|
||||||
|
LTX2AudioVideoUnit_InputAudioEmbedder(),
|
||||||
LTX2AudioVideoUnit_InputVideoEmbedder(),
|
LTX2AudioVideoUnit_InputVideoEmbedder(),
|
||||||
LTX2AudioVideoUnit_InputImagesEmbedder(),
|
LTX2AudioVideoUnit_InputImagesEmbedder(),
|
||||||
|
LTX2AudioVideoUnit_InContextVideoEmbedder(),
|
||||||
|
]
|
||||||
|
self.stage2_units = [
|
||||||
|
LTX2AudioVideoUnit_SwitchStage2(),
|
||||||
|
LTX2AudioVideoUnit_NoiseInitializer(),
|
||||||
|
LTX2AudioVideoUnit_LatentsUpsampler(),
|
||||||
|
LTX2AudioVideoUnit_VideoRetakeEmbedder(),
|
||||||
|
LTX2AudioVideoUnit_AudioRetakeEmbedder(),
|
||||||
|
LTX2AudioVideoUnit_InputImagesEmbedder(),
|
||||||
|
LTX2AudioVideoUnit_SetScheduleStage2(),
|
||||||
]
|
]
|
||||||
self.model_fn = model_fn_ltx2
|
self.model_fn = model_fn_ltx2
|
||||||
|
self.compilable_models = ["dit"]
|
||||||
|
|
||||||
|
self.default_negative_prompt = {
|
||||||
|
"LTX-2": (
|
||||||
|
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||||
|
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||||
|
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||||
|
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||||
|
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||||
|
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||||
|
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||||
|
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||||
|
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||||
|
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||||
|
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
),
|
||||||
|
"LTX-2.3": (
|
||||||
|
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||||
|
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||||
|
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||||
|
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||||
|
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||||
|
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||||
|
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||||
|
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||||
|
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||||
|
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||||
|
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
@@ -69,6 +114,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|||||||
model_configs: list[ModelConfig] = [],
|
model_configs: list[ModelConfig] = [],
|
||||||
tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
stage2_lora_config: Optional[ModelConfig] = None,
|
stage2_lora_config: Optional[ModelConfig] = None,
|
||||||
|
stage2_lora_strength: float = 0.8,
|
||||||
vram_limit: float = None,
|
vram_limit: float = None,
|
||||||
):
|
):
|
||||||
# Initialize pipeline
|
# Initialize pipeline
|
||||||
@@ -89,113 +135,22 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|||||||
pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
|
pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
|
||||||
pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
|
pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
|
||||||
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
|
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
|
||||||
|
pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
|
||||||
|
|
||||||
# Stage 2
|
# Stage 2
|
||||||
if stage2_lora_config is not None:
|
if stage2_lora_config is not None:
|
||||||
stage2_lora_config.download_if_necessary()
|
pipe.stage2_lora_config = stage2_lora_config
|
||||||
pipe.stage2_lora_path = stage2_lora_config.path
|
pipe.stage2_lora_strength = stage2_lora_strength
|
||||||
# Optional, currently not used
|
|
||||||
# pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
|
|
||||||
|
|
||||||
# VRAM Management
|
# VRAM Management
|
||||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
|
def denoise_stage(self, inputs_shared, inputs_posi, inputs_nega, units, cfg_scale=1.0, progress_bar_cmd=tqdm, skip_stage=False):
|
||||||
if inputs_shared["use_two_stage_pipeline"]:
|
if skip_stage:
|
||||||
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
self.load_models_to_device('upsampler',)
|
for unit in units:
|
||||||
latent = self.upsampler(latent)
|
|
||||||
latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
|
|
||||||
self.scheduler.set_timesteps(special_case="stage2")
|
|
||||||
inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
|
|
||||||
denoise_mask_video = 1.0
|
|
||||||
if inputs_shared.get("input_images", None) is not None:
|
|
||||||
latent, denoise_mask_video, initial_latents = self.apply_input_images_to_latents(
|
|
||||||
latent, inputs_shared.pop("input_latents"), inputs_shared["input_images_indexes"],
|
|
||||||
inputs_shared["input_images_strength"], latent.clone())
|
|
||||||
inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video})
|
|
||||||
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[
|
|
||||||
"video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latent
|
|
||||||
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (
|
|
||||||
1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
|
|
||||||
|
|
||||||
self.load_models_to_device(self.in_iteration_models)
|
|
||||||
if not inputs_shared["use_distilled_pipeline"]:
|
|
||||||
self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8)
|
|
||||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
|
||||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
|
||||||
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
|
|
||||||
self.model_fn, 1.0, inputs_shared, inputs_posi, inputs_nega,
|
|
||||||
**models, timestep=timestep, progress_id=progress_id
|
|
||||||
)
|
|
||||||
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id,
|
|
||||||
noise_pred=noise_pred_video, inpaint_mask=inputs_shared.get("denoise_mask_video", None),
|
|
||||||
input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
|
|
||||||
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
|
|
||||||
noise_pred=noise_pred_audio, **inputs_shared)
|
|
||||||
return inputs_shared
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
# Prompt
|
|
||||||
prompt: str,
|
|
||||||
negative_prompt: Optional[str] = "",
|
|
||||||
# Image-to-video
|
|
||||||
denoising_strength: float = 1.0,
|
|
||||||
input_images: Optional[list[Image.Image]] = None,
|
|
||||||
input_images_indexes: Optional[list[int]] = None,
|
|
||||||
input_images_strength: Optional[float] = 1.0,
|
|
||||||
# Randomness
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
rand_device: Optional[str] = "cpu",
|
|
||||||
# Shape
|
|
||||||
height: Optional[int] = 512,
|
|
||||||
width: Optional[int] = 768,
|
|
||||||
num_frames=121,
|
|
||||||
# Classifier-free guidance
|
|
||||||
cfg_scale: Optional[float] = 3.0,
|
|
||||||
cfg_merge: Optional[bool] = False,
|
|
||||||
# Scheduler
|
|
||||||
num_inference_steps: Optional[int] = 40,
|
|
||||||
# VAE tiling
|
|
||||||
tiled: Optional[bool] = True,
|
|
||||||
tile_size_in_pixels: Optional[int] = 512,
|
|
||||||
tile_overlap_in_pixels: Optional[int] = 128,
|
|
||||||
tile_size_in_frames: Optional[int] = 128,
|
|
||||||
tile_overlap_in_frames: Optional[int] = 24,
|
|
||||||
# Special Pipelines
|
|
||||||
use_two_stage_pipeline: Optional[bool] = False,
|
|
||||||
use_distilled_pipeline: Optional[bool] = False,
|
|
||||||
# progress_bar
|
|
||||||
progress_bar_cmd=tqdm,
|
|
||||||
):
|
|
||||||
# Scheduler
|
|
||||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength,
|
|
||||||
special_case="ditilled_stage1" if use_distilled_pipeline else None)
|
|
||||||
# Inputs
|
|
||||||
inputs_posi = {
|
|
||||||
"prompt": prompt,
|
|
||||||
}
|
|
||||||
inputs_nega = {
|
|
||||||
"negative_prompt": negative_prompt,
|
|
||||||
}
|
|
||||||
inputs_shared = {
|
|
||||||
"input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
|
|
||||||
"seed": seed, "rand_device": rand_device,
|
|
||||||
"height": height, "width": width, "num_frames": num_frames,
|
|
||||||
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
|
|
||||||
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
|
|
||||||
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
|
|
||||||
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline,
|
|
||||||
"video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
|
|
||||||
}
|
|
||||||
for unit in self.units:
|
|
||||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||||
|
|
||||||
# Denoise Stage 1
|
|
||||||
self.load_models_to_device(self.in_iteration_models)
|
self.load_models_to_device(self.in_iteration_models)
|
||||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
@@ -206,34 +161,93 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|||||||
)
|
)
|
||||||
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video,
|
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video,
|
||||||
inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
|
inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
|
||||||
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
|
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio,
|
||||||
noise_pred=noise_pred_audio, **inputs_shared)
|
inpaint_mask=inputs_shared.get("denoise_mask_audio", None), input_latents=inputs_shared.get("input_latents_audio", None), **inputs_shared)
|
||||||
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
# Denoise Stage 2
|
|
||||||
inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd)
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
# Prompt
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: Optional[str] = "",
|
||||||
|
denoising_strength: float = 1.0,
|
||||||
|
# Image-to-video
|
||||||
|
input_images: Optional[list[Image.Image]] = None,
|
||||||
|
input_images_indexes: Optional[list[int]] = [0],
|
||||||
|
input_images_strength: Optional[float] = 1.0,
|
||||||
|
# In-Context Video Control
|
||||||
|
in_context_videos: Optional[list[list[Image.Image]]] = None,
|
||||||
|
in_context_downsample_factor: Optional[int] = 2,
|
||||||
|
# Video-to-video
|
||||||
|
retake_video: Optional[list[Image.Image]] = None,
|
||||||
|
retake_video_regions: Optional[list[tuple[float, float]]] = None,
|
||||||
|
# Audio-to-video
|
||||||
|
retake_audio: Optional[torch.Tensor] = None,
|
||||||
|
audio_sample_rate: Optional[int] = 48000,
|
||||||
|
retake_audio_regions: Optional[list[tuple[float, float]]] = None,
|
||||||
|
# Randomness
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
rand_device: Optional[str] = "cpu",
|
||||||
|
# Shape
|
||||||
|
height: Optional[int] = 512,
|
||||||
|
width: Optional[int] = 768,
|
||||||
|
num_frames: Optional[int] = 121,
|
||||||
|
frame_rate: Optional[int] = 24,
|
||||||
|
# Classifier-free guidance
|
||||||
|
cfg_scale: Optional[float] = 3.0,
|
||||||
|
# Scheduler
|
||||||
|
num_inference_steps: Optional[int] = 30,
|
||||||
|
# VAE tiling
|
||||||
|
tiled: Optional[bool] = True,
|
||||||
|
tile_size_in_pixels: Optional[int] = 512,
|
||||||
|
tile_overlap_in_pixels: Optional[int] = 128,
|
||||||
|
tile_size_in_frames: Optional[int] = 128,
|
||||||
|
tile_overlap_in_frames: Optional[int] = 24,
|
||||||
|
# Special Pipelines
|
||||||
|
use_two_stage_pipeline: Optional[bool] = False,
|
||||||
|
stage2_spatial_upsample_factor: Optional[int] = 2,
|
||||||
|
clear_lora_before_state_two: Optional[bool] = False,
|
||||||
|
use_distilled_pipeline: Optional[bool] = False,
|
||||||
|
# progress_bar
|
||||||
|
progress_bar_cmd=tqdm,
|
||||||
|
):
|
||||||
|
# Scheduler
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, special_case="ditilled_stage1" if use_distilled_pipeline else None)
|
||||||
|
# Inputs
|
||||||
|
inputs_posi = {
|
||||||
|
"prompt": prompt,
|
||||||
|
}
|
||||||
|
inputs_nega = {
|
||||||
|
"negative_prompt": negative_prompt,
|
||||||
|
}
|
||||||
|
inputs_shared = {
|
||||||
|
"input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
|
||||||
|
"retake_video": retake_video, "retake_video_regions": retake_video_regions,
|
||||||
|
"retake_audio": (retake_audio, audio_sample_rate) if retake_audio is not None else None, "retake_audio_regions": retake_audio_regions,
|
||||||
|
"in_context_videos": in_context_videos, "in_context_downsample_factor": in_context_downsample_factor,
|
||||||
|
"seed": seed, "rand_device": rand_device,
|
||||||
|
"height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate,
|
||||||
|
"cfg_scale": cfg_scale,
|
||||||
|
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
|
||||||
|
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
|
||||||
|
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, "clear_lora_before_state_two": clear_lora_before_state_two, "stage2_spatial_upsample_factor": stage2_spatial_upsample_factor,
|
||||||
|
"video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
|
||||||
|
}
|
||||||
|
# Stage 1
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = self.denoise_stage(inputs_shared, inputs_posi, inputs_nega, self.units, cfg_scale, progress_bar_cmd)
|
||||||
|
# Stage 2
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = self.denoise_stage(inputs_shared, inputs_posi, inputs_nega, self.stage2_units, 1.0, progress_bar_cmd, not inputs_shared["use_two_stage_pipeline"])
|
||||||
# Decode
|
# Decode
|
||||||
self.load_models_to_device(['video_vae_decoder'])
|
self.load_models_to_device(['video_vae_decoder'])
|
||||||
video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels,
|
video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels, tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
|
||||||
tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
|
|
||||||
video = self.vae_output_to_video(video)
|
video = self.vae_output_to_video(video)
|
||||||
self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
|
self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
|
||||||
decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
|
decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
|
||||||
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
|
decoded_audio = self.audio_vocoder(decoded_audio)
|
||||||
|
decoded_audio = self.output_audio_format_check(decoded_audio)
|
||||||
return video, decoded_audio
|
return video, decoded_audio
|
||||||
|
|
||||||
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None, num_frames=121):
|
|
||||||
b, _, f, h, w = latents.shape
|
|
||||||
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device)
|
|
||||||
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
|
|
||||||
for idx, input_latent in zip(input_indexes, input_latents):
|
|
||||||
idx = min(max(1 + (idx-1) // 8, 0), f - 1)
|
|
||||||
input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
|
|
||||||
initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
|
|
||||||
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
|
|
||||||
latents = latents * denoise_mask + initial_latents * (1.0 - denoise_mask)
|
|
||||||
return latents, denoise_mask, initial_latents
|
|
||||||
|
|
||||||
|
|
||||||
 class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
     def __init__(self):
@@ -251,8 +265,8 @@ class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
         if inputs_shared.get("use_two_stage_pipeline", False):
             # the distilled pipeline also uses two stages, but it does not need lora
             if not inputs_shared.get("use_distilled_pipeline", False):
-                if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
-                    raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
+                if not (hasattr(pipe, "stage2_lora_config") and pipe.stage2_lora_config is not None):
+                    raise ValueError("Two-stage pipeline requested, but stage2_lora_config is not set in the pipeline.")
             if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
                 raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
         return inputs_shared, inputs_posi, inputs_nega
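The checker only validates state. A sketch of the preconditions it enforces for a non-distilled two-stage run (attribute names as used above; how they get populated, e.g. during from_pretrained, is outside this hunk):

    assert pipe.stage2_lora_config is not None  # stage-2 refinement LoRA
    assert pipe.upsampler is not None           # latent spatial upsampler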
@@ -262,22 +276,23 @@ class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
     """
     For two-stage pipelines, the resolution must be divisible by 64.
     For one-stage pipelines, the resolution must be divisible by 32.
+    This unit sets height and width to the stage 1 resolution, and also outputs stage_2_width and stage_2_height.
     """
     def __init__(self):
         super().__init__(
-            input_params=("height", "width", "num_frames"),
-            output_params=("height", "width", "num_frames"),
+            input_params=("height", "width", "num_frames", "use_two_stage_pipeline", "stage2_spatial_upsample_factor"),
+            output_params=("height", "width", "num_frames", "stage_2_height", "stage_2_width"),
         )

-    def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False):
+    def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False, stage2_spatial_upsample_factor=2):
         if use_two_stage_pipeline:
-            self.width_division_factor = 64
-            self.height_division_factor = 64
-        height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
-        if use_two_stage_pipeline:
-            self.width_division_factor = 32
-            self.height_division_factor = 32
-        return {"height": height, "width": width, "num_frames": num_frames}
+            height, width = height // stage2_spatial_upsample_factor, width // stage2_spatial_upsample_factor
+            height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
+            stage_2_height, stage_2_width = int(height * stage2_spatial_upsample_factor), int(width * stage2_spatial_upsample_factor)
+        else:
+            stage_2_height, stage_2_width = None, None
+            height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
+        return {"height": height, "width": width, "num_frames": num_frames, "stage_2_height": stage_2_height, "stage_2_width": stage_2_width}
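Worked example of the shape arithmetic above with stage2_spatial_upsample_factor=2: a 1024 x 1536 request is halved to 512 x 768 for stage 1, check_resize_height_width snaps it to the pipeline's division factors (here it already fits), and stage 2 then targets the re-expanded resolution. Purely illustrative arithmetic:

    factor = 2
    height, width = 1024 // factor, 1536 // factor              # stage 1: 512 x 768
    stage_2_height, stage_2_width = 512 * factor, 768 * factor  # stage 2: 1024 x 1536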
 class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
@@ -290,121 +305,20 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
             output_params=("video_context", "audio_context"),
             onload_model_names=("text_encoder", "text_encoder_post_modules"),
         )

-    def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
-        return (attention_mask - 1).to(dtype).reshape(
-            (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
-
-    def _run_connectors(self, pipe, encoded_input: torch.Tensor,
-                        attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype)
-
-        encoded, encoded_connector_attention_mask = pipe.text_encoder_post_modules.embeddings_connector(
-            encoded_input,
-            connector_attention_mask,
-        )
-
-        # restore the mask values to int64
-        attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64)
-        attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
-        encoded = encoded * attention_mask
-
-        encoded_for_audio, _ = pipe.text_encoder_post_modules.audio_embeddings_connector(
-            encoded_input, connector_attention_mask)
-
-        return encoded, encoded_for_audio, attention_mask.squeeze(-1)
-
-    def _norm_and_concat_padded_batch(
-        self,
-        encoded_text: torch.Tensor,
-        sequence_lengths: torch.Tensor,
-        padding_side: str = "right",
-    ) -> torch.Tensor:
-        """Normalize and flatten multi-layer hidden states, respecting padding.
-        Performs per-batch, per-layer normalization using masked mean and range,
-        then concatenates across the layer dimension.
-        Args:
-            encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
-            sequence_lengths: Number of valid (non-padded) tokens per batch item.
-            padding_side: Whether padding is on "left" or "right".
-        Returns:
-            Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
-            with padded positions zeroed out.
-        """
-        b, t, d, l = encoded_text.shape  # noqa: E741
-        device = encoded_text.device
-        # Build mask: [B, T, 1, 1]
-        token_indices = torch.arange(t, device=device)[None, :]  # [1, T]
-        if padding_side == "right":
-            # For right padding, valid tokens are from 0 to sequence_length-1
-            mask = token_indices < sequence_lengths[:, None]  # [B, T]
-        elif padding_side == "left":
-            # For left padding, valid tokens are from (T - sequence_length) to T-1
-            start_indices = t - sequence_lengths[:, None]  # [B, 1]
-            mask = token_indices >= start_indices  # [B, T]
-        else:
-            raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
-        mask = rearrange(mask, "b t -> b t 1 1")
-        eps = 1e-6
-        # Compute masked mean: [B, 1, 1, L]
-        masked = encoded_text.masked_fill(~mask, 0.0)
-        denom = (sequence_lengths * d).view(b, 1, 1, 1)
-        mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
-        # Compute masked min/max: [B, 1, 1, L]
-        x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
-        x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
-        range_ = x_max - x_min
-        # Normalize only the valid tokens
-        normed = 8 * (encoded_text - mean) / (range_ + eps)
-        # concat to be [Batch, T, D * L] - this preserves the original structure
-        normed = normed.reshape(b, t, -1)  # [B, T, D * L]
-        # Apply mask to preserve original padding (set padded positions to 0)
-        mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
-        normed = normed.masked_fill(~mask_flattened, 0.0)
-
-        return normed
-
-    def _run_feature_extractor(self,
-                               pipe,
-                               hidden_states: torch.Tensor,
-                               attention_mask: torch.Tensor,
-                               padding_side: str = "right") -> torch.Tensor:
-        encoded_text_features = torch.stack(hidden_states, dim=-1)
-        encoded_text_features_dtype = encoded_text_features.dtype
-        sequence_lengths = attention_mask.sum(dim=-1)
-        normed_concated_encoded_text_features = self._norm_and_concat_padded_batch(encoded_text_features,
-                                                                                   sequence_lengths,
-                                                                                   padding_side=padding_side)
-
-        return pipe.text_encoder_post_modules.feature_extractor_linear(
-            normed_concated_encoded_text_features.to(encoded_text_features_dtype))
-
     def _preprocess_text(
         self,
         pipe,
         text: str,
-        padding_side: str = "left",
     ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
-        """
-        Encode a given string into feature tensors suitable for downstream tasks.
-        Args:
-            text (str): Input string to encode.
-        Returns:
-            tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask.
-        """
         token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
         input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device)
         attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device)
         outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
-        projected = self._run_feature_extractor(pipe,
-                                                hidden_states=outputs.hidden_states,
-                                                attention_mask=attention_mask,
-                                                padding_side=padding_side)
-        return projected, attention_mask
+        return outputs.hidden_states, attention_mask

     def encode_prompt(self, pipe, text, padding_side="left"):
-        encoded_inputs, attention_mask = self._preprocess_text(pipe, text, padding_side)
-        video_encoding, audio_encoding, attention_mask = self._run_connectors(pipe, encoded_inputs, attention_mask)
+        hidden_states, attention_mask = self._preprocess_text(pipe, text)
+        video_encoding, audio_encoding, attention_mask = pipe.text_encoder_post_modules.process_hidden_states(
+            hidden_states, attention_mask, padding_side)
         return video_encoding, audio_encoding, attention_mask

     def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
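The removed helpers are not deleted outright; per the new encode_prompt, the same connector-and-normalization logic apparently now lives behind pipe.text_encoder_post_modules.process_hidden_states. On a single unpadded layer, the normalization that _norm_and_concat_padded_batch applied reduces to the following toy sketch (for intuition only, not the module API):

    import torch
    x = torch.randn(1, 5, 4)                  # [batch, tokens, hidden]
    mean = x.mean(dim=(1, 2), keepdim=True)
    rng = x.amax(dim=(1, 2), keepdim=True) - x.amin(dim=(1, 2), keepdim=True)
    normed = 8 * (x - mean) / (rng + 1e-6)    # the 8 * (x - mean) / (range + eps) scaling from above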
@@ -416,13 +330,13 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
 class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
     def __init__(self):
         super().__init__(
-            input_params=("height", "width", "num_frames", "seed", "rand_device", "use_two_stage_pipeline"),
-            output_params=("video_noise", "audio_noise",),
+            input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate"),
+            output_params=("video_noise", "audio_noise", "video_positions", "audio_positions", "video_latent_shape", "audio_latent_shape")
         )

     def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
         video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
-        video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
+        video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=128)
         video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)

         latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
@@ -442,36 +356,125 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
             "audio_latent_shape": audio_latent_shape
         }

-    def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0, use_two_stage_pipeline=False):
-        if use_two_stage_pipeline:
-            stage1_dict = self.process_stage(pipe, height // 2, width // 2, num_frames, seed, rand_device, frame_rate)
-            stage2_dict = self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
-            initial_dict = stage1_dict
-            initial_dict.update({"stage2_" + k: v for k, v in stage2_dict.items()})
-            return initial_dict
-        else:
-            return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
+    def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
+        return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
 class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
     def __init__(self):
         super().__init__(
-            input_params=("input_video", "video_noise", "audio_noise", "tiled", "tile_size", "tile_stride"),
-            output_params=("video_latents", "audio_latents"),
+            input_params=("input_video", "video_noise", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels"),
+            output_params=("video_latents", "input_latents"),
             onload_model_names=("video_vae_encoder")
         )

-    def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, audio_noise, tiled, tile_size, tile_stride):
-        if input_video is None:
-            return {"video_latents": video_noise, "audio_latents": audio_noise}
-        else:
-            # TODO: implement video-to-video
-            raise NotImplementedError("Video-to-video not implemented yet.")
+    def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
+        if input_video is None or not pipe.scheduler.training:
+            return {"video_latents": video_noise}
+        else:
+            pipe.load_models_to_device(self.onload_model_names)
+            input_video = pipe.preprocess_video(input_video)
+            input_latents = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
+            return {"video_latents": input_latents, "input_latents": input_latents}
+
+
+class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("input_audio", "audio_noise"),
+            output_params=("audio_latents", "audio_input_latents", "audio_positions", "audio_latent_shape"),
+            onload_model_names=("audio_vae_encoder",)
+        )
+
+    def process(self, pipe: LTX2AudioVideoPipeline, input_audio, audio_noise):
+        if input_audio is None or not pipe.scheduler.training:
+            return {"audio_latents": audio_noise}
+        else:
+            input_audio, sample_rate = input_audio
+            input_audio = convert_to_stereo(input_audio)
+            pipe.load_models_to_device(self.onload_model_names)
+            input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype)
+            audio_input_latents = pipe.audio_vae_encoder(input_audio)
+            audio_latent_shape = AudioLatentShape.from_torch_shape(audio_input_latents.shape)
+            audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
+            return {"audio_latents": audio_input_latents, "audio_input_latents": audio_input_latents, "audio_positions": audio_positions, "audio_latent_shape": audio_latent_shape}
+
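convert_to_stereo comes from the project's audio utilities and its body is not shown in this diff; a common implementation (an assumption here, not this commit's code) duplicates a mono channel and truncates anything beyond two channels:

    def convert_to_stereo(waveform):  # hypothetical sketch; waveform: [channels, samples]
        return waveform.repeat(2, 1) if waveform.shape[0] == 1 else waveform[:2]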
+class LTX2AudioVideoUnit_VideoRetakeEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("retake_video", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "video_positions", "retake_video_regions"),
+            output_params=("input_latents_video", "denoise_mask_video"),
+            onload_model_names=("video_vae_encoder")
+        )
+
+    def process(self, pipe: LTX2AudioVideoPipeline, retake_video, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_positions, retake_video_regions=None):
+        if retake_video is None:
+            return {}
+        pipe.load_models_to_device(self.onload_model_names)
+        resized_video = [frame.resize((width, height)) for frame in retake_video]
+        input_video = pipe.preprocess_video(resized_video)
+        input_latents_video = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
+
+        b, c, f, h, w = input_latents_video.shape
+        denoise_mask_video = torch.zeros((b, 1, f, h, w), device=input_latents_video.device, dtype=input_latents_video.dtype)
+        if retake_video_regions is not None and len(retake_video_regions) > 0:
+            for start_time, end_time in retake_video_regions:
+                t_start, t_end = video_positions[0, 0].unbind(dim=-1)
+                in_region = (t_end >= start_time) & (t_start <= end_time)
+                in_region = pipe.video_patchifier.unpatchify_video(in_region.unsqueeze(0).unsqueeze(-1), f, h, w)
+                denoise_mask_video = torch.where(in_region, torch.ones_like(denoise_mask_video), denoise_mask_video)
+
+        return {"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video}
+
+class LTX2AudioVideoUnit_AudioRetakeEmbedder(PipelineUnit):
+    """
+    Functionality of audio-to-video and audio retaking.
+    """
+    def __init__(self):
+        super().__init__(
+            input_params=("retake_audio", "seed", "rand_device", "retake_audio_regions"),
+            output_params=("input_latents_audio", "audio_noise", "audio_positions", "audio_latent_shape", "denoise_mask_audio"),
+            onload_model_names=("audio_vae_encoder",)
+        )
+
+    def process(self, pipe: LTX2AudioVideoPipeline, retake_audio, seed, rand_device, retake_audio_regions=None):
+        if retake_audio is None:
+            return {}
+        else:
+            input_audio, sample_rate = retake_audio
+            input_audio = convert_to_stereo(input_audio)
+            pipe.load_models_to_device(self.onload_model_names)
+            input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype, device=pipe.device)
+            input_latents_audio = pipe.audio_vae_encoder(input_audio)
+            audio_latent_shape = AudioLatentShape.from_torch_shape(input_latents_audio.shape)
+            audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
+            # Regenerate noise for the new shape if retake_audio is provided, to avoid shape mismatch.
+            audio_noise = pipe.generate_noise(input_latents_audio.shape, seed=seed, rand_device=rand_device)
+
+            b, c, t, f = input_latents_audio.shape
+            denoise_mask_audio = torch.zeros((b, 1, t, 1), device=input_latents_audio.device, dtype=input_latents_audio.dtype)
+            if retake_audio_regions is not None and len(retake_audio_regions) > 0:
+                for start_time, end_time in retake_audio_regions:
+                    t_start, t_end = audio_positions[:, 0, :, 0], audio_positions[:, 0, :, 1]
+                    in_region = (t_end >= start_time) & (t_start <= end_time)
+                    in_region = pipe.audio_patchifier.unpatchify_audio(in_region.unsqueeze(-1), 1, 1)
+                    denoise_mask_audio = torch.where(in_region, torch.ones_like(denoise_mask_audio), denoise_mask_audio)
+
+            return {
+                "input_latents_audio": input_latents_audio,
+                "denoise_mask_audio": denoise_mask_audio,
+                "audio_noise": audio_noise,
+                "audio_positions": audio_positions,
+                "audio_latent_shape": audio_latent_shape,
+            }
+
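Both retake units build their masks with the standard interval-overlap test: a latent position whose time span [t_start, t_end] satisfies (t_end >= start_time) & (t_start <= end_time) falls inside a requested retake region and gets denoise mask 1 (re-generate); everything else keeps mask 0 (preserve). For example, with a retake region of (2.0, 3.0) seconds, a latent covering (2.9, 3.4) is re-generated, while one covering (3.1, 3.6) is kept.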
 class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
     def __init__(self):
         super().__init__(
-            input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "num_frames", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
-            output_params=("video_latents"),
+            input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "frame_rate", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "input_latents_video", "denoise_mask_video"),
+            output_params=("denoise_mask_video", "input_latents_video", "ref_frames_latents", "ref_frames_positions"),
             onload_model_names=("video_vae_encoder")
         )
@@ -480,30 +483,166 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
         image = torch.Tensor(np.array(image, dtype=np.float32)).to(dtype=pipe.torch_dtype, device=pipe.device)
         image = image / 127.5 - 1.0
         image = repeat(image, f"H W C -> B C F H W", B=1, F=1)
-        latent = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
-        return latent
+        latents = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
+        return latents

-    def process(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, video_latents, height, width, num_frames, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=False):
+    def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, input_latents_video=None, denoise_mask_video=None):
+        b, _, f, h, w = latents.shape
+        denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if denoise_mask_video is None else denoise_mask_video
+        input_latents_video = torch.zeros_like(latents) if input_latents_video is None else input_latents_video
+        for idx, input_latent in zip(input_indexes, input_latents):
+            idx = min(max(1 + (idx-1) // 8, 0), f - 1)
+            input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
+            input_latents_video[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
+            denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
+        return input_latents_video, denoise_mask
+
+    def process(
+        self,
+        pipe: LTX2AudioVideoPipeline,
+        video_latents,
+        input_images,
+        height,
+        width,
+        frame_rate,
+        tiled,
+        tile_size_in_pixels,
+        tile_overlap_in_pixels,
+        input_images_indexes=[0],
+        input_images_strength=1.0,
+        input_latents_video=None,
+        denoise_mask_video=None,
+    ):
         if input_images is None or len(input_images) == 0:
-            return {"video_latents": video_latents}
+            return {}
+        else:
+            if len(input_images_indexes) != len(set(input_images_indexes)):
+                raise ValueError("Input images must have unique indexes.")
+            pipe.load_models_to_device(self.onload_model_names)
+            frame_conditions = {"input_latents_video": None, "denoise_mask_video": None, "ref_frames_latents": [], "ref_frames_positions": []}
+            for img, index in zip(input_images, input_images_indexes):
+                latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels)
+                # first_frame by replacing latents
+                if index == 0:
+                    input_latents_video, denoise_mask_video = self.apply_input_images_to_latents(
+                        video_latents, [latents], [0], input_images_strength, input_latents_video, denoise_mask_video)
+                    frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video})
+                # other frames by adding reference latents
+                else:
+                    latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(latents.shape), device=pipe.device)
+                    video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, False).float()
+                    video_positions[:, 0, ...] = (video_positions[:, 0, ...] + index) / frame_rate
+                    video_positions = video_positions.to(pipe.torch_dtype)
+                    frame_conditions["ref_frames_latents"].append(latents)
+                    frame_conditions["ref_frames_positions"].append(video_positions)
+            if len(frame_conditions["ref_frames_latents"]) == 0:
+                frame_conditions.update({"ref_frames_latents": None, "ref_frames_positions": None})
+            return frame_conditions
+
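The mapping idx = min(max(1 + (idx - 1) // 8, 0), f - 1) used in apply_input_images_to_latents converts a pixel-frame index into a latent-frame index under the VAE's 8x temporal compression, keeping frame 0 in its own latent slot: index 0 maps to latent 0, indexes 1-8 map to latent 1, indexes 9-16 to latent 2, and so on, clamped to the last latent frame f - 1.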
+class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("in_context_videos", "height", "width", "num_frames", "frame_rate", "in_context_downsample_factor", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels"),
+            output_params=("in_context_video_latents", "in_context_video_positions"),
+            onload_model_names=("video_vae_encoder")
+        )
+
+    def check_in_context_video(self, pipe, in_context_video, height, width, num_frames, in_context_downsample_factor):
+        if in_context_video is None or len(in_context_video) == 0:
+            raise ValueError("In-context video is None or empty.")
+        in_context_video = in_context_video[:num_frames]
+        expected_height = height // in_context_downsample_factor
+        expected_width = width // in_context_downsample_factor
+        current_h, current_w, current_f = in_context_video[0].size[1], in_context_video[0].size[0], len(in_context_video)
+        h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f, verbose=0)
+        if current_h != h or current_w != w:
+            in_context_video = [img.resize((w, h)) for img in in_context_video]
+        if current_f != f:
+            # pad black frames at the end
+            in_context_video = in_context_video + [Image.new("RGB", (w, h), (0, 0, 0))] * (f - current_f)
+        return in_context_video
+
+    def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height, width, num_frames, frame_rate, in_context_downsample_factor, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
+        if in_context_videos is None or len(in_context_videos) == 0:
+            return {}
         else:
             pipe.load_models_to_device(self.onload_model_names)
-            output_dicts = {}
-            stage1_height = height // 2 if use_two_stage_pipeline else height
-            stage1_width = width // 2 if use_two_stage_pipeline else width
-            stage1_latents = [
-                self.get_image_latent(pipe, img, stage1_height, stage1_width, tiled, tile_size_in_pixels,
-                                      tile_overlap_in_pixels) for img in input_images
-            ]
-            video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength, num_frames=num_frames)
-            output_dicts.update({"video_latents": video_latents, "denoise_mask_video": denoise_mask_video, "input_latents_video": initial_latents})
-            if use_two_stage_pipeline:
-                stage2_latents = [
-                    self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels,
-                                          tile_overlap_in_pixels) for img in input_images
-                ]
-                output_dicts.update({"stage2_input_latents": stage2_latents})
-            return output_dicts
+            latents, positions = [], []
+            for in_context_video in in_context_videos:
+                in_context_video = self.check_in_context_video(pipe, in_context_video, height, width, num_frames, in_context_downsample_factor)
+                in_context_video = pipe.preprocess_video(in_context_video)
+                in_context_latents = pipe.video_vae_encoder.encode(in_context_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
+
+                latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(in_context_latents.shape), device=pipe.device)
+                video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
+                video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
+                video_positions[:, 1, ...] *= in_context_downsample_factor  # height axis
+                video_positions[:, 2, ...] *= in_context_downsample_factor  # width axis
+                video_positions = video_positions.to(pipe.torch_dtype)

+                latents.append(in_context_latents)
+                positions.append(video_positions)
+            latents = torch.cat(latents, dim=1)
+            positions = torch.cat(positions, dim=1)
+            return {"in_context_video_latents": latents, "in_context_video_positions": positions}
+
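In the in-context unit above, the conditioning videos are encoded at a resolution reduced by in_context_downsample_factor, so their positional coordinates are multiplied back by that factor on the height and width axes to land on the base-resolution pixel grid (with factor 2, a token at latent row 64 is re-addressed as row 128), while the time axis is converted to seconds by dividing by frame_rate.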
+class LTX2AudioVideoUnit_SwitchStage2(PipelineUnit):
+    """
+    1. switch height and width to stage 2 resolution
+    2. clear in_context_video_latents and in_context_video_positions
+    3. switch stage 2 lora model
+    """
+    def __init__(self):
+        super().__init__(
+            input_params=("stage_2_height", "stage_2_width", "clear_lora_before_state_two", "use_distilled_pipeline"),
+            output_params=("height", "width", "in_context_video_latents", "in_context_video_positions"),
+        )
+
+    def process(self, pipe: LTX2AudioVideoPipeline, stage_2_height, stage_2_width, clear_lora_before_state_two, use_distilled_pipeline):
+        stage2_params = {}
+        stage2_params.update({"height": stage_2_height, "width": stage_2_width})
+        stage2_params.update({"in_context_video_latents": None, "in_context_video_positions": None})
+        stage2_params.update({"input_latents_video": None, "denoise_mask_video": None})
+        if clear_lora_before_state_two:
+            pipe.clear_lora()
+        if not use_distilled_pipeline:
+            pipe.load_lora(pipe.dit, pipe.stage2_lora_config, alpha=pipe.stage2_lora_strength, state_dict=pipe.stage2_lora_config.state_dict)
+        return stage2_params
+
+
+class LTX2AudioVideoUnit_SetScheduleStage2(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("video_latents", "video_noise", "audio_latents", "audio_noise"),
+            output_params=("video_latents", "audio_latents"),
+        )
+
+    def process(self, pipe: LTX2AudioVideoPipeline, video_latents, video_noise, audio_latents, audio_noise):
+        pipe.scheduler.set_timesteps(special_case="stage2")
+        video_latents = pipe.scheduler.add_noise(video_latents, video_noise, pipe.scheduler.timesteps[0])
+        audio_latents = pipe.scheduler.add_noise(audio_latents, audio_noise, pipe.scheduler.timesteps[0])
+        return {"video_latents": video_latents, "audio_latents": audio_latents}
+
+
+class LTX2AudioVideoUnit_LatentsUpsampler(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("video_latents",),
+            output_params=("video_latents",),
+            onload_model_names=("upsampler",),
+        )
+
+    def process(self, pipe: LTX2AudioVideoPipeline, video_latents):
+        if video_latents is None or pipe.upsampler is None:
+            raise ValueError("No upsampler or no video latents before stage 2.")
+        else:
+            pipe.load_models_to_device(self.onload_model_names)
+            video_latents = pipe.video_vae_encoder.per_channel_statistics.un_normalize(video_latents)
+            video_latents = pipe.upsampler(video_latents)
+            video_latents = pipe.video_vae_encoder.per_channel_statistics.normalize(video_latents)
+            return {"video_latents": video_latents}
+
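LTX2AudioVideoUnit_SetScheduleStage2 restarts denoising from an intermediate point rather than from pure noise: after set_timesteps(special_case="stage2"), add_noise blends the upsampled stage-1 latents with fresh noise at the first stage-2 timestep. For a flow-matching scheduler this is typically x_t = (1 - sigma) * x + sigma * noise with sigma taken from timesteps[0]; the exact blend is defined by the scheduler, not by this file.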
 def model_fn_ltx2(
@@ -517,7 +656,19 @@ def model_fn_ltx2(
     audio_positions=None,
     audio_patchifier=None,
     timestep=None,
+    # First Frame Conditioning
+    input_latents_video=None,
     denoise_mask_video=None,
+    # Other Frames Conditioning
+    ref_frames_latents=None,
+    ref_frames_positions=None,
+    # In-Context Conditioning
+    in_context_video_latents=None,
+    in_context_video_positions=None,
+    # Audio Inputs
+    input_latents_audio=None,
+    denoise_mask_audio=None,
+    # Gradient Checkpointing
     use_gradient_checkpointing=False,
     use_gradient_checkpointing_offload=False,
     **kwargs,
@@ -527,13 +678,38 @@ def model_fn_ltx2(
     # patchify
     b, c_v, f, h, w = video_latents.shape
     video_latents = video_patchifier.patchify(video_latents)
+    seq_len_video = video_latents.shape[1]
     video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
-    if denoise_mask_video is not None:
-        video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps
+    # First frame conditioning by replacing the video latents
+    if input_latents_video is not None:
+        denoise_mask_video = video_patchifier.patchify(denoise_mask_video)
+        video_latents = video_latents * denoise_mask_video + video_patchifier.patchify(input_latents_video) * (1.0 - denoise_mask_video)
+        video_timesteps = denoise_mask_video * video_timesteps
+
+    # Reference conditioning by appending the reference video or frame latents
+    total_ref_latents = ref_frames_latents if ref_frames_latents is not None else []
+    total_ref_positions = ref_frames_positions if ref_frames_positions is not None else []
+    total_ref_latents += [in_context_video_latents] if in_context_video_latents is not None else []
+    total_ref_positions += [in_context_video_positions] if in_context_video_positions is not None else []
+    if len(total_ref_latents) > 0:
+        for ref_frames_latent, ref_frames_position in zip(total_ref_latents, total_ref_positions):
+            ref_frames_latent = video_patchifier.patchify(ref_frames_latent)
+            ref_frames_timestep = timestep.repeat(1, ref_frames_latent.shape[1], 1) * 0.
+            video_latents = torch.cat([video_latents, ref_frames_latent], dim=1)
+            video_positions = torch.cat([video_positions, ref_frames_position], dim=2)
+            video_timesteps = torch.cat([video_timesteps, ref_frames_timestep], dim=1)
+
+    if audio_latents is not None:
         _, c_a, _, mel_bins = audio_latents.shape
         audio_latents = audio_patchifier.patchify(audio_latents)
         audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
-    #TODO: support gradient checkpointing in training
+    else:
+        audio_timesteps = None
+    if input_latents_audio is not None:
+        denoise_mask_audio = audio_patchifier.patchify(denoise_mask_audio)
+        audio_latents = audio_latents * denoise_mask_audio + audio_patchifier.patchify(input_latents_audio) * (1.0 - denoise_mask_audio)
+        audio_timesteps = denoise_mask_audio * audio_timesteps
+
     vx, ax = dit(
         video_latents=video_latents,
         video_positions=video_positions,
@@ -543,8 +719,13 @@ def model_fn_ltx2(
         audio_positions=audio_positions,
         audio_context=audio_context,
         audio_timesteps=audio_timesteps,
+        sigma=timestep,
+        use_gradient_checkpointing=use_gradient_checkpointing,
+        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
     )

+    vx = vx[:, :seq_len_video, ...]
     # unpatchify
     vx = video_patchifier.unpatchify_video(vx, f, h, w)
-    ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins)
+    ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins) if ax is not None else None
     return vx, ax
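Shape bookkeeping for the reference-conditioning path in model_fn_ltx2: reference and in-context latents are appended to the token sequence with an all-zero timestep, so the DiT treats them as already clean and never denoises them, and the truncation vx = vx[:, :seq_len_video, ...] drops those extra tokens again before unpatchifying. Concretely, a [B, N, C] token sequence becomes [B, N + R, C] inside the transformer and is cut back to its first N tokens on the way out.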
diffsynth/pipelines/mova_audio_video.py  (new file, +461 lines)
@@ -0,0 +1,461 @@
import sys
import torch, types
from PIL import Image
from typing import Optional, Union
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional

from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit

from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d, set_to_torch_norm
from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer
from ..models.wan_video_vae import WanVideoVAE
from ..models.mova_audio_dit import MovaAudioDit
from ..models.mova_audio_vae import DacVAE
from ..models.mova_dual_tower_bridge import DualTowerConditionalBridge
from ..utils.data.audio import convert_to_mono, resample_waveform


class MovaAudioVideoPipeline(BasePipeline):

    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
        )
        self.scheduler = FlowMatchScheduler("Wan")
        self.tokenizer: HuggingfaceTokenizer = None
        self.text_encoder: WanTextEncoder = None
        self.video_dit: WanModel = None  # high noise model
        self.video_dit2: WanModel = None  # low noise model
        self.audio_dit: MovaAudioDit = None
        self.dual_tower_bridge: DualTowerConditionalBridge = None
        self.video_vae: WanVideoVAE = None
        self.audio_vae: DacVAE = None

        self.in_iteration_models = ("video_dit", "audio_dit", "dual_tower_bridge")
        self.in_iteration_models_2 = ("video_dit2", "audio_dit", "dual_tower_bridge")

        self.units = [
            MovaAudioVideoUnit_ShapeChecker(),
            MovaAudioVideoUnit_NoiseInitializer(),
            MovaAudioVideoUnit_InputVideoEmbedder(),
            MovaAudioVideoUnit_InputAudioEmbedder(),
            MovaAudioVideoUnit_PromptEmbedder(),
            MovaAudioVideoUnit_ImageEmbedderVAE(),
            MovaAudioVideoUnit_UnifiedSequenceParallel(),
        ]
        self.model_fn = model_fn_mova_audio_video
        self.compilable_models = ["video_dit", "video_dit2", "audio_dit"]

    def enable_usp(self):
        from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward
        for block in self.video_dit.blocks + self.audio_dit.blocks + self.video_dit2.blocks:
            block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
        self.sp_size = get_sequence_parallel_world_size()
        self.use_unified_sequence_parallel = True

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"),
        use_usp: bool = False,
        vram_limit: float = None,
    ):
        if use_usp:
            from ..utils.xfuser import initialize_usp
            initialize_usp(device)
            import torch.distributed as dist
            from ..core.device.npu_compatible_device import get_device_name
            if dist.is_available() and dist.is_initialized():
                device = get_device_name()
        # Initialize pipeline
        pipe = MovaAudioVideoPipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder")
        dit = model_pool.fetch_model("wan_video_dit", index=2)
        if isinstance(dit, list):
            pipe.video_dit, pipe.video_dit2 = dit
        else:
            pipe.video_dit = dit
        pipe.audio_dit = model_pool.fetch_model("mova_audio_dit")
        pipe.dual_tower_bridge = model_pool.fetch_model("mova_dual_tower_bridge")
        pipe.video_vae = model_pool.fetch_model("wan_video_vae")
        pipe.audio_vae = model_pool.fetch_model("mova_audio_vae")
        set_to_torch_norm([pipe.video_dit, pipe.audio_dit, pipe.dual_tower_bridge] + ([pipe.video_dit2] if pipe.video_dit2 is not None else []))

        # Size division factor
        if pipe.video_vae is not None:
            pipe.height_division_factor = pipe.video_vae.upsampling_factor * 2
            pipe.width_division_factor = pipe.video_vae.upsampling_factor * 2

        # Initialize tokenizer and processor
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace')

        # Unified Sequence Parallel
        if use_usp: pipe.enable_usp()

        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe
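A minimal, hypothetical loading-and-inference sketch based on the signature above; the import path follows this file's location, the model_configs entries are placeholders, and only the tokenizer default ("openmoss/MOVA-720p") actually appears in this file:

    import torch
    from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline

    pipe = MovaAudioVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        model_configs=[...],  # ModelConfig entries for the DiTs, VAEs, text encoder and bridge
    )
    video, audio = pipe(prompt="a street musician playing guitar", num_frames=81, seed=42)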
    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: Optional[str] = "",
        # Image-to-video
        input_image: Optional[Image.Image] = None,
        # First-last-frame-to-video
        end_image: Optional[Image.Image] = None,
        # Video-to-video
        denoising_strength: Optional[float] = 1.0,
        # Randomness
        seed: Optional[int] = None,
        rand_device: Optional[str] = "cpu",
        # Shape
        height: Optional[int] = 352,
        width: Optional[int] = 640,
        num_frames: Optional[int] = 81,
        frame_rate: Optional[int] = 24,
        # Classifier-free guidance
        cfg_scale: Optional[float] = 5.0,
        # Boundary
        switch_DiT_boundary: Optional[float] = 0.9,
        # Scheduler
        num_inference_steps: Optional[int] = 50,
        sigma_shift: Optional[float] = 5.0,
        # VAE tiling
        tiled: Optional[bool] = True,
        tile_size: Optional[tuple[int, int]] = (30, 52),
        tile_stride: Optional[tuple[int, int]] = (15, 26),
        # progress_bar
        progress_bar_cmd=tqdm,
    ):
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

        # Inputs
        inputs_posi = {
            "prompt": prompt,
        }
        inputs_nega = {
            "negative_prompt": negative_prompt,
        }
        inputs_shared = {
            "input_image": input_image,
            "end_image": end_image,
            "denoising_strength": denoising_strength,
            "seed": seed, "rand_device": rand_device,
            "height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate,
            "cfg_scale": cfg_scale,
            "sigma_shift": sigma_shift,
            "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            # Switch DiT if necessary
            if timestep.item() < switch_DiT_boundary * 1000 and self.video_dit2 is not None and not models["video_dit"] is self.video_dit2:
                self.load_models_to_device(self.in_iteration_models_2)
                models["video_dit"] = self.video_dit2
            # Timestep
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            # Scheduler
            inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared)
            inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)

        # Decode
        self.load_models_to_device(['video_vae'])
        video = self.video_vae.decode(inputs_shared["video_latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        video = self.vae_output_to_video(video)
        self.load_models_to_device(["audio_vae"])
        audio = self.audio_vae.decode(inputs_shared["audio_latents"])
        audio = self.output_audio_format_check(audio)
        self.load_models_to_device([])
        return video, audio


class MovaAudioVideoUnit_ShapeChecker(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames"),
            output_params=("height", "width", "num_frames"),
        )

    def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames):
        height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
        return {"height": height, "width": width, "num_frames": num_frames}


class MovaAudioVideoUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate"),
            output_params=("video_noise", "audio_noise")
        )

    def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate):
        length = (num_frames - 1) // 4 + 1
        video_shape = (1, pipe.video_vae.model.z_dim, length, height // pipe.video_vae.upsampling_factor, width // pipe.video_vae.upsampling_factor)
        video_noise = pipe.generate_noise(video_shape, seed=seed, rand_device=rand_device)

        audio_num_samples = (int(pipe.audio_vae.sample_rate * num_frames / frame_rate) - 1) // int(pipe.audio_vae.hop_length) + 1
        audio_shape = (1, pipe.audio_vae.latent_dim, audio_num_samples)
        audio_noise = pipe.generate_noise(audio_shape, seed=seed, rand_device=rand_device)
        return {"video_noise": video_noise, "audio_noise": audio_noise}
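Worked example of the noise shapes above: with num_frames=81, the video latent length is (81 - 1) // 4 + 1 = 21 under Wan's 4x temporal compression. For audio, assuming for illustration sample_rate=44100, frame_rate=24 and hop_length=512 (the real values come from the loaded DAC VAE), the latent step count is (int(44100 * 81 / 24) - 1) // 512 + 1 = 291.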
class MovaAudioVideoUnit_InputVideoEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_video", "video_noise", "tiled", "tile_size", "tile_stride"),
            output_params=("video_latents", "input_latents"),
            onload_model_names=("video_vae",)
        )

    def process(self, pipe: MovaAudioVideoPipeline, input_video, video_noise, tiled, tile_size, tile_stride):
        if input_video is None or not pipe.scheduler.training:
            return {"video_latents": video_noise}
        else:
            pipe.load_models_to_device(self.onload_model_names)
            input_video = pipe.preprocess_video(input_video)
            input_latents = pipe.video_vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
            return {"input_latents": input_latents}


class MovaAudioVideoUnit_InputAudioEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_audio", "audio_noise"),
            output_params=("audio_latents", "audio_input_latents"),
            onload_model_names=("audio_vae",)
        )

    def process(self, pipe: MovaAudioVideoPipeline, input_audio, audio_noise):
        if input_audio is None or not pipe.scheduler.training:
            return {"audio_latents": audio_noise}
        else:
            pipe.load_models_to_device(self.onload_model_names)
            input_audio, sample_rate = input_audio
            input_audio = convert_to_mono(input_audio)
            input_audio = resample_waveform(input_audio, sample_rate, pipe.audio_vae.sample_rate)
            input_audio = pipe.audio_vae.preprocess(input_audio.unsqueeze(0), pipe.audio_vae.sample_rate)
            z, _, _, _, _ = pipe.audio_vae.encode(input_audio)
            return {"audio_input_latents": z.mode()}


class MovaAudioVideoUnit_PromptEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("context",),
            onload_model_names=("text_encoder",)
        )

    def encode_prompt(self, pipe: MovaAudioVideoPipeline, prompt):
        ids, mask = pipe.tokenizer(
            prompt,
            padding="max_length",
            max_length=512,
            truncation=True,
            add_special_tokens=True,
            return_mask=True,
            return_tensors="pt",
        )
        ids = ids.to(pipe.device)
        mask = mask.to(pipe.device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        prompt_emb = pipe.text_encoder(ids, mask)
        for i, v in enumerate(seq_lens):
            prompt_emb[:, v:] = 0
        return prompt_emb

    def process(self, pipe: MovaAudioVideoPipeline, prompt) -> dict:
        pipe.load_models_to_device(self.onload_model_names)
        prompt_emb = self.encode_prompt(pipe, prompt)
        return {"context": prompt_emb}


class MovaAudioVideoUnit_ImageEmbedderVAE(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            output_params=("y",),
            onload_model_names=("video_vae",)
        )

    def process(self, pipe: MovaAudioVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
        if input_image is None or not pipe.video_dit.require_vae_embedding:
            return {}
        pipe.load_models_to_device(self.onload_model_names)

        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
        msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
        msk[:, 1:] = 0
        if end_image is not None:
            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
            vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
            msk[:, -1:] = 1
        else:
            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)

        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
        msk = msk.transpose(1, 2)[0]

        y = pipe.video_vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        y = torch.concat([msk, y])
        y = y.unsqueeze(0)
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"y": y}
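The mask arithmetic in ImageEmbedderVAE mirrors the temporal compression: msk starts with num_frames entries, repeating the first entry four times makes the count divisible by 4 (81 frames become 4 + 80 = 84 entries), and view(1, 84 // 4, 4, ...) folds them into 21 latent frames, matching the (num_frames - 1) // 4 + 1 video latent length used by the noise initializer.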
class MovaAudioVideoUnit_UnifiedSequenceParallel(PipelineUnit):
    def __init__(self):
        super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",))

    def process(self, pipe: MovaAudioVideoPipeline):
        if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel:
            return {"use_unified_sequence_parallel": True}
        return {"use_unified_sequence_parallel": False}
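model_fn_mova_audio_video below walks the two towers in lockstep for the first len(audio_dit.blocks) layers, letting the dual-tower bridge exchange information at the depths where should_interact is true, then finishes the remaining deeper video blocks alone. Schematically (names hypothetical):

    for i in range(len(audio_blocks)):                       # joint segment
        maybe_bridge(i); run_video_block(i); run_audio_block(i)
    for i in range(len(audio_blocks), len(video_blocks)):    # video-only tail
        run_video_block(i)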
def model_fn_mova_audio_video(
    video_dit: WanModel,
    audio_dit: MovaAudioDit,
    dual_tower_bridge: DualTowerConditionalBridge,
    video_latents: torch.Tensor = None,
    audio_latents: torch.Tensor = None,
    timestep: torch.Tensor = None,
    context: torch.Tensor = None,
    y: Optional[torch.Tensor] = None,
    frame_rate: Optional[int] = 24,
    use_unified_sequence_parallel: bool = False,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
    **kwargs,
):
    video_x, audio_x = video_latents, audio_latents
    # First-Last Frame
    if y is not None:
        video_x = torch.cat([video_x, y], dim=1)

    # Timestep
    video_t = video_dit.time_embedding(sinusoidal_embedding_1d(video_dit.freq_dim, timestep))
    video_t_mod = video_dit.time_projection(video_t).unflatten(1, (6, video_dit.dim))
    audio_t = audio_dit.time_embedding(sinusoidal_embedding_1d(audio_dit.freq_dim, timestep))
    audio_t_mod = audio_dit.time_projection(audio_t).unflatten(1, (6, audio_dit.dim))

    # Context
    video_context = video_dit.text_embedding(context)
    audio_context = audio_dit.text_embedding(context)

    # Patchify
    video_x = video_dit.patch_embedding(video_x)
    f_v, h, w = video_x.shape[2:]
    video_x = rearrange(video_x, 'b c f h w -> b (f h w) c').contiguous()
    seq_len_video = video_x.shape[1]

    audio_x = audio_dit.patch_embedding(audio_x)
    f_a = audio_x.shape[2]
    audio_x = rearrange(audio_x, 'b c f -> b f c').contiguous()
    seq_len_audio = audio_x.shape[1]

    # Freqs
    video_freqs = torch.cat([
        video_dit.freqs[0][:f_v].view(f_v, 1, 1, -1).expand(f_v, h, w, -1),
        video_dit.freqs[1][:h].view(1, h, 1, -1).expand(f_v, h, w, -1),
        video_dit.freqs[2][:w].view(1, 1, w, -1).expand(f_v, h, w, -1)
    ], dim=-1).reshape(f_v * h * w, 1, -1).to(video_x.device)
    audio_freqs = torch.cat([
        audio_dit.freqs[0][:f_a].view(f_a, -1).expand(f_a, -1),
        audio_dit.freqs[1][:f_a].view(f_a, -1).expand(f_a, -1),
        audio_dit.freqs[2][:f_a].view(f_a, -1).expand(f_a, -1),
    ], dim=-1).reshape(f_a, 1, -1).to(audio_x.device)

    video_rope, audio_rope = dual_tower_bridge.build_aligned_freqs(
        video_fps=frame_rate,
        grid_size=(f_v, h, w),
        audio_steps=audio_x.shape[1],
        device=video_x.device,
        dtype=video_x.dtype,
    )
    # usp func
    if use_unified_sequence_parallel:
        from ..utils.xfuser import get_current_chunk, gather_all_chunks
    else:
        get_current_chunk = lambda x, dim=1: x
        gather_all_chunks = lambda x, seq_len, dim=1: x
    # Forward blocks
    for block_id in range(len(audio_dit.blocks)):
        if dual_tower_bridge.should_interact(block_id, "a2v"):
            video_x, audio_x = dual_tower_bridge(
                block_id,
                video_x,
                audio_x,
                x_freqs=video_rope,
                y_freqs=audio_rope,
                condition_scale=1.0,
                video_grid_size=(f_v, h, w),
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            )
        video_x = get_current_chunk(video_x, dim=1)
        video_x = gradient_checkpoint_forward(
            video_dit.blocks[block_id],
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            video_x, video_context, video_t_mod, video_freqs
        )
        video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1)
        audio_x = get_current_chunk(audio_x, dim=1)
        audio_x = gradient_checkpoint_forward(
            audio_dit.blocks[block_id],
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            audio_x, audio_context, audio_t_mod, audio_freqs
        )
        audio_x = gather_all_chunks(audio_x, seq_len=seq_len_audio, dim=1)

    video_x = get_current_chunk(video_x, dim=1)
    for block_id in range(len(audio_dit.blocks), len(video_dit.blocks)):
        video_x = gradient_checkpoint_forward(
            video_dit.blocks[block_id],
|
||||||
|
use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload,
|
||||||
|
video_x, video_context, video_t_mod, video_freqs
|
||||||
|
)
|
||||||
|
video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1)
|
||||||
|
|
||||||
|
# Head
|
||||||
|
video_x = video_dit.head(video_x, video_t)
|
||||||
|
video_x = video_dit.unpatchify(video_x, (f_v, h, w))
|
||||||
|
|
||||||
|
audio_x = audio_dit.head(audio_x, audio_t)
|
||||||
|
audio_x = audio_dit.unpatchify(audio_x, (f_a,))
|
||||||
|
return video_x, audio_x
|
||||||
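For orientation, a small worked example of the mask bookkeeping in the image-embedder unit above. This is a standalone sketch with illustrative sizes only (81 pixel frames, a 480x832 output): the first-frame mask row is repeated four times so that the mask reshapes cleanly into groups of 4, matching Wan's 4x temporal compression in the VAE latents.

import torch

num_frames, height, width = 81, 480, 832
msk = torch.ones(1, num_frames, height // 8, width // 8)
msk[:, 1:] = 0  # only the first frame is "known"

# Repeat the first-frame row 4x: 81 rows -> 84 rows, i.e. 21 groups of 4.
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
msk = msk.transpose(1, 2)[0]
print(msk.shape)  # torch.Size([4, 21, 60, 104]) -- 4 mask channels per latent frame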
@@ -56,6 +56,7 @@ class QwenImagePipeline(BasePipeline):
             QwenImageUnit_BlockwiseControlNet(),
         ]
         self.model_fn = model_fn_qwen_image
+        self.compilable_models = ["dit"]


     @staticmethod
@@ -682,14 +683,16 @@ class QwenImageUnit_Image2LoRADecode(PipelineUnit):
 class QwenImageUnit_ContextImageEmbedder(PipelineUnit):
     def __init__(self):
         super().__init__(
-            input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride"),
+            input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride", "layer_input_image"),
             output_params=("context_latents",),
             onload_model_names=("vae",)
         )

-    def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride):
+    def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride, layer_input_image=None):
         if context_image is None:
             return {}
+        if layer_input_image is not None:
+            context_image = context_image.convert("RGBA")
         pipe.load_models_to_device(self.onload_model_names)
         context_image = pipe.preprocess_image(context_image.resize((width, height))).to(device=pipe.device, dtype=pipe.torch_dtype)
         context_latents = pipe.vae.encode(context_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
@@ -75,15 +75,19 @@ class WanVideoPipeline(BasePipeline):
             WanVideoUnit_TeaCache(),
             WanVideoUnit_CfgMerger(),
             WanVideoUnit_LongCatVideo(),
+            WanVideoUnit_WanToDance_ProcessInputs(),
+            WanVideoUnit_WanToDance_RefImageEmbedder(),
+            WanVideoUnit_WanToDance_ImageKeyframesEmbedder(),
         ]
         self.post_units = [
             WanVideoPostUnit_S2V(),
         ]
         self.model_fn = model_fn_wan_video
+        self.compilable_models = ["dit", "dit2"]


     def enable_usp(self):
-        from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward
+        from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward, usp_vace_forward

         for block in self.dit.blocks:
             block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
@@ -92,6 +96,14 @@ class WanVideoPipeline(BasePipeline):
             for block in self.dit2.blocks:
                 block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
             self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
+        if self.vace is not None:
+            for block in self.vace.vace_blocks:
+                block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+            self.vace.forward = types.MethodType(usp_vace_forward, self.vace)
+        if self.vace2 is not None:
+            for block in self.vace2.vace_blocks:
+                block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+            self.vace2.forward = types.MethodType(usp_vace_forward, self.vace2)
         self.sp_size = get_sequence_parallel_world_size()
         self.use_unified_sequence_parallel = True
@@ -244,6 +256,13 @@ class WanVideoPipeline(BasePipeline):
         # Teacache
         tea_cache_l1_thresh: Optional[float] = None,
         tea_cache_model_id: Optional[str] = "",
+        # WanToDance
+        wantodance_music_path: Optional[str] = None,
+        wantodance_reference_image: Optional[Image.Image] = None,
+        wantodance_fps: Optional[float] = 30,
+        wantodance_keyframes: Optional[list[Image.Image]] = None,
+        wantodance_keyframes_mask: Optional[list[int]] = None,
+        framewise_decoding: bool = False,
         # progress_bar
         progress_bar_cmd=tqdm,
         output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
@@ -280,6 +299,9 @@ class WanVideoPipeline(BasePipeline):
             "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
             "animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
             "vap_video": vap_video,
+            "wantodance_music_path": wantodance_music_path, "wantodance_reference_image": wantodance_reference_image, "wantodance_fps": wantodance_fps,
+            "wantodance_keyframes": wantodance_keyframes, "wantodance_keyframes_mask": wantodance_keyframes_mask,
+            "framewise_decoding": framewise_decoding,
         }
         for unit in self.units:
             inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -325,6 +347,9 @@ class WanVideoPipeline(BasePipeline):
             inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
         # Decode
         self.load_models_to_device(['vae'])
+        if framewise_decoding:
+            video = self.vae.decode_framewise(inputs_shared["latents"], device=self.device)
+        else:
             video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
         if output_type == "quantized":
             video = self.vae_output_to_video(video)
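Assuming the flag is exposed as shown in the hunk above, enabling frame-by-frame VAE decoding from user code would look roughly like this (a sketch only; `pipe` is a loaded WanVideoPipeline, and the prompt is a placeholder):

video = pipe(
    prompt="a dancer on a rooftop at sunset",  # placeholder prompt
    num_frames=81,
    framewise_decoding=True,  # decode latents frame by frame instead of tiled decoding
)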
@@ -371,16 +396,19 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
 class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
     def __init__(self):
         super().__init__(
-            input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"),
+            input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image", "framewise_decoding"),
             output_params=("latents", "input_latents"),
             onload_model_names=("vae",)
         )

-    def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image):
+    def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image, framewise_decoding):
         if input_video is None:
             return {"latents": noise}
         pipe.load_models_to_device(self.onload_model_names)
         input_video = pipe.preprocess_video(input_video)
+        if framewise_decoding:
+            input_latents = pipe.vae.encode_framewise(input_video, device=pipe.device)
+        else:
             input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
         if vace_reference_image is not None:
             if not isinstance(vace_reference_image, list):
@@ -1018,6 +1046,111 @@ class WanVideoUnit_LongCatVideo(PipelineUnit):
         return {"longcat_latents": longcat_latents}


+class WanVideoUnit_WanToDance_ProcessInputs(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            take_over=True,
+        )
+
+    def get_music_base_feature(self, music_path, fps=30):
+        import librosa
+        hop_length = 512
+        sr = fps * hop_length
+        data, sr = librosa.load(music_path, sr=sr)
+        sr = 22050
+        envelope = librosa.onset.onset_strength(y=data, sr=sr)
+        mfcc = librosa.feature.mfcc(y=data, sr=sr, n_mfcc=20).T
+        chroma = librosa.feature.chroma_cens(
+            y=data, sr=sr, hop_length=hop_length, n_chroma=12
+        ).T
+        peak_idxs = librosa.onset.onset_detect(
+            onset_envelope=envelope.flatten(), sr=sr, hop_length=hop_length
+        )
+        peak_onehot = np.zeros_like(envelope, dtype=np.float32)
+        peak_onehot[peak_idxs] = 1.0
+        start_bpm = librosa.beat.tempo(y=librosa.load(music_path)[0])[0]
+        _, beat_idxs = librosa.beat.beat_track(
+            onset_envelope=envelope,
+            sr=sr,
+            hop_length=hop_length,
+            start_bpm=start_bpm,
+            tightness=100,
+        )
+        beat_onehot = np.zeros_like(envelope, dtype=np.float32)
+        beat_onehot[beat_idxs] = 1.0
+        audio_feature = np.concatenate(
+            [envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],
+            axis=-1,
+        )
+        return torch.from_numpy(audio_feature)
+
+    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
+        if pipe.dit.wantodance_enable_global:
+            inputs_nega["skip_9th_layer"] = True
+        if inputs_shared.get("wantodance_music_path", None) is not None:
+            inputs_shared["music_feature"] = self.get_music_base_feature(inputs_shared["wantodance_music_path"]).to(dtype=pipe.torch_dtype, device=pipe.device)
+        return inputs_shared, inputs_posi, inputs_nega
+
+
+class WanVideoUnit_WanToDance_RefImageEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("wantodance_reference_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
+            output_params=("wantodance_refimage_feature",),
+            onload_model_names=("image_encoder", "vae")
+        )
+
+    def process(self, pipe: WanVideoPipeline, wantodance_reference_image, num_frames, height, width, tiled, tile_size, tile_stride):
+        if wantodance_reference_image is None:
+            return {}
+        pipe.load_models_to_device(self.onload_model_names)
+        if isinstance(wantodance_reference_image, list):
+            wantodance_reference_image = wantodance_reference_image[0]
+        image = pipe.preprocess_image(wantodance_reference_image.resize((width, height))).to(pipe.device)  # B,C,H,W; B=1
+        refimage_feature = pipe.image_encoder.encode_image([image])
+        refimage_feature = refimage_feature.to(dtype=pipe.torch_dtype, device=pipe.device)
+        return {"wantodance_refimage_feature": refimage_feature}
+
+
+class WanVideoUnit_WanToDance_ImageKeyframesEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("wantodance_keyframes", "wantodance_keyframes_mask", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
+            output_params=("clip_feature", "y"),
+            onload_model_names=("image_encoder", "vae")
+        )
+
+    def process(self, pipe: WanVideoPipeline, wantodance_keyframes, wantodance_keyframes_mask, num_frames, height, width, tiled, tile_size, tile_stride):
+        if wantodance_keyframes is None:
+            return {}
+        wantodance_keyframes_mask = torch.tensor(wantodance_keyframes_mask)
+        pipe.load_models_to_device(self.onload_model_names)
+        images = []
+        for input_image in wantodance_keyframes:
+            input_image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
+            images.append(input_image)
+
+        clip_context = pipe.image_encoder.encode_image(images[:1])  # use the first frame as the CLIP input
+        msk = torch.zeros(1, num_frames, height // 8, width // 8, device=pipe.device)
+        msk[:, wantodance_keyframes_mask == 1, :, :] = torch.ones(1, height // 8, width // 8, device=pipe.device)  # set keyframe positions in the mask to 1
+
+        images = [image.transpose(0, 1) for image in images]  # 3, num_frames, h, w
+        images = torch.concat(images, dim=1)
+        vae_input = images
+
+        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)  # expand the first-frame mask from N to N + 3
+        msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
+        msk = msk.transpose(1, 2)[0]
+
+        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
+        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+        y = torch.concat([msk, y])
+        y = y.unsqueeze(0)
+        clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
+        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+        return {"clip_feature": clip_context, "y": y}
+
+
 class TeaCache:
     def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
         self.num_inference_steps = num_inference_steps
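The feature vector assembled by get_music_base_feature concatenates, per audio analysis frame: a 1-dim onset envelope, 20 MFCCs, 12 chroma bins, and two one-hot indicators (onset peaks and beats), i.e. 1 + 20 + 12 + 1 + 1 = 35 channels. A quick standalone check of that arithmetic:

import numpy as np

T = 100  # illustrative number of audio analysis frames
envelope = np.zeros((T, 1), dtype=np.float32)
mfcc = np.zeros((T, 20), dtype=np.float32)
chroma = np.zeros((T, 12), dtype=np.float32)
peak_onehot = np.zeros((T, 1), dtype=np.float32)
beat_onehot = np.zeros((T, 1), dtype=np.float32)
feature = np.concatenate([envelope, mfcc, chroma, peak_onehot, beat_onehot], axis=-1)
print(feature.shape)  # (100, 35)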
@@ -1123,6 +1256,22 @@ class TemporalTiler_BCTHW:
         return value


+def wantodance_get_single_freqs(freqs, frame_num, fps):
+    total_frame = int(30.0 / (fps + 1e-6) * frame_num + 0.5)
+    interval_frame = 30.0 / (fps + 1e-6)
+    freqs_0 = freqs[:total_frame]
+    freqs_new = torch.zeros((frame_num, freqs_0.shape[1]), device=freqs_0.device, dtype=freqs_0.dtype)
+    freqs_new[0] = freqs_0[0]
+    freqs_new[-1] = freqs_0[total_frame - 1]
+    for i in range(1, frame_num - 1):
+        pos = i * interval_frame
+        low_idx = int(pos)
+        high_idx = min(low_idx + 1, total_frame - 1)
+        weight_high = pos - low_idx
+        weight_low = 1.0 - weight_high
+        freqs_new[i] = freqs_0[low_idx] * weight_low + freqs_0[high_idx] * weight_high
+    return freqs_new
+
+
 def model_fn_wan_video(
     dit: WanModel,
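wantodance_get_single_freqs linearly interpolates the 30 fps RoPE table onto a different frame rate: frame i of the target clip sits at position i * (30 / fps) in the source table, and its row is a convex blend of the two neighbouring rows. A standalone sketch of the same idea on a toy table:

import torch

def lerp_rows(table: torch.Tensor, frame_num: int, fps: float) -> torch.Tensor:
    # table: [T, D] RoPE rows laid out at 30 fps
    interval = 30.0 / (fps + 1e-6)
    out = torch.zeros(frame_num, table.shape[1], dtype=table.dtype)
    for i in range(frame_num):
        pos = i * interval
        lo = int(pos)
        hi = min(lo + 1, table.shape[0] - 1)
        w = pos - lo
        out[i] = table[lo] * (1.0 - w) + table[hi] * w
    return out

table = torch.arange(10, dtype=torch.float32).unsqueeze(1)  # toy 1-dim rows 0..9
print(lerp_rows(table, 4, 15.0).flatten())  # approximately [0., 2., 4., 6.] -- 15 fps strides 2 source rows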
@@ -1158,6 +1307,10 @@ def model_fn_wan_video(
     use_gradient_checkpointing_offload: bool = False,
     control_camera_latents_input = None,
     fuse_vae_embedding_in_latents: bool = False,
+    wantodance_refimage_feature = None,
+    wantodance_fps: float = 30.0,
+    music_feature = None,
+    skip_9th_layer: bool = False,
     **kwargs,
 ):
     if sliding_window_size is not None and sliding_window_stride is not None:
@@ -1255,6 +1408,9 @@ def model_fn_wan_video(
         context = torch.cat([clip_embdding, context], dim=1)

     # Camera control
+    if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global and int(wantodance_fps + 0.5) != 30:
+        x = dit.patchify(x, control_camera_latents_input, enable_wantodance_global=True)
+    else:
         x = dit.patchify(x, control_camera_latents_input)

     # Animate
@@ -1304,12 +1460,59 @@ def model_fn_wan_video(
     else:
         tea_cache_update = False

-    if vace_context is not None:
-        vace_hints = vace(
-            x, vace_context, context, t_mod, freqs,
-            use_gradient_checkpointing=use_gradient_checkpointing,
-            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
-        )
+    # WanToDance
+    if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global:
+        if wantodance_refimage_feature is not None:
+            refimage_feature_embedding = dit.img_emb_refimage(wantodance_refimage_feature)
+            context = torch.cat([refimage_feature_embedding, context], dim=1)
+        if (dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel) and int(wantodance_fps + 0.5) != 30:
+            freqs_0 = wantodance_get_single_freqs(dit.freqs[0], f, wantodance_fps)
+            freqs = torch.cat([
+                freqs_0.view(f, 1, 1, -1).expand(f, h, w, -1),
+                dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+                dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+            ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
+        if dit.wantodance_enable_global or dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel:
+            if use_unified_sequence_parallel:
+                length = int(float(music_feature.shape[0]) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size()
+                music_feature = music_feature[:length]
+                music_feature = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()]
+            if not dit.training:
+                dit.music_encoder.to(x.device, dtype=x.dtype)  # evaluation only
+            music_feature = music_feature.to(x.device, dtype=x.dtype)
+            music_feature = dit.music_projection(music_feature)
+            music_feature = dit.music_encoder(music_feature)
+            if music_feature.dim() == 2:
+                music_feature = music_feature.unsqueeze(0)
+            if use_unified_sequence_parallel:
+                if dist.is_initialized() and dist.get_world_size() > 1:
+                    music_feature = get_sp_group().all_gather(music_feature, dim=1)
+                music_feature = music_feature.unsqueeze(1)  # [1, 1, 149, 4800]
+                N = 149
+                M = 4800
+                music_feature = torch.nn.functional.interpolate(music_feature, size=(N, M), mode='bilinear')
+                music_feature = music_feature.squeeze(1)  # shape: [1, 149, 4800]
+        if music_feature is not None:
+            if music_feature.dim() == 2:
+                music_feature = music_feature.unsqueeze(0)
+            music_feature = music_feature.to(x.device, dtype=x.dtype)
+            interp_mode = 'bilinear'
+            if interp_mode == 'bilinear':
+                frame_num = latents.shape[2] if len(latents.shape) == 5 else latents.shape[1]  # 21
+                context_shape_end = context.shape[2]  # 14B: 5120
+                music_feature = music_feature.unsqueeze(1)  # shape: [1, 1, 149, 4800]
+                if use_unified_sequence_parallel:
+                    N = int(float(frame_num * 8) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size()
+                else:
+                    N = frame_num * 8
+                music_feature = torch.nn.functional.interpolate(music_feature, size=(N, context_shape_end), mode='bilinear')
+                music_feature = music_feature.squeeze(1)  # shape: [1, N, context_shape_end]
+                if use_unified_sequence_parallel:
+                    dit.merged_audio_emb = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+                else:
+                    dit.merged_audio_emb = music_feature
+            else:
+                dit.merged_audio_emb = music_feature

     # blocks
     if use_unified_sequence_parallel:
@@ -1318,6 +1521,13 @@ def model_fn_wan_video(
         pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
         chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
         x = chunks[get_sequence_parallel_rank()]
+
+    if vace_context is not None:
+        vace_hints = vace(
+            x, vace_context, context, t_mod, freqs,
+            use_gradient_checkpointing=use_gradient_checkpointing,
+            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
+        )
     if tea_cache_update:
         x = tea_cache.update(x)
     else:
@@ -1326,8 +1536,12 @@ def model_fn_wan_video(
                 return vap(block, *inputs)
             return custom_forward

-    for block_id, block in enumerate(dit.blocks):
     # Block
+    for block_id, block in enumerate(dit.blocks):
+        if skip_9th_layer:
+            # This is only used in WanToDance
+            if block_id == 9:
+                continue
         if vap is not None and block_id in vap.mot_layers_mapping:
             if use_gradient_checkpointing_offload:
                 with torch.autograd.graph.save_on_cpu():
@@ -1356,18 +1570,23 @@ def model_fn_wan_video(
         # VACE
         if vace_context is not None and block_id in vace.vace_layers_mapping:
             current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
-            if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
-                current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
-                current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
             x = x + current_vace_hint * vace_scale

         # Animate
         if pose_latents is not None and face_pixel_values is not None:
             x = animate_adapter.after_transformer_block(block_id, x, motion_vec)
+
+        # WanToDance
+        if hasattr(dit, "wantodance_enable_music_inject") and dit.wantodance_enable_music_inject:
+            x = dit.wantodance_after_transformer_block(block_id, x)
         if tea_cache is not None:
             tea_cache.store(x)

+    if hasattr(dit, "wantodance_enable_unimodel") and dit.wantodance_enable_unimodel and int(wantodance_fps + 0.5) != 30:
+        x = dit.head_global(x, t)
+    else:
         x = dit.head(x, t)

     if use_unified_sequence_parallel:
         if dist.is_initialized() and dist.get_world_size() > 1:
             x = get_sp_group().all_gather(x, dim=1)
@@ -54,6 +54,7 @@ class ZImagePipeline(BasePipeline):
             ZImageUnit_PAIControlNet(),
         ]
         self.model_fn = model_fn_z_image
+        self.compilable_models = ["dit"]


     @staticmethod
@@ -299,7 +300,7 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):

     def process(self, pipe: ZImagePipeline, prompt, edit_image):
         pipe.load_models_to_device(self.onload_model_names)
-        if hasattr(pipe, "dit") and pipe.dit.siglip_embedder is not None:
+        if hasattr(pipe, "dit") and pipe.dit is not None and pipe.dit.siglip_embedder is not None:
             # Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods.
             # We determine which encoding method to use based on the model architecture.
             # If you are using two-stage split training,
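The extra `pipe.dit is not None` check matters because hasattr returns True even when the attribute is set to None, e.g. when the DiT has not been attached yet in a two-stage split-training setup. A minimal illustration:

class Stub:
    dit = None  # attribute exists but is unset

pipe = Stub()
print(hasattr(pipe, "dit"))   # True -- the old check passed here...
print(pipe.dit is not None)   # False -- ...then crashed on pipe.dit.siglip_embedder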
@@ -116,7 +116,7 @@ class VideoData:
         if self.height is not None and self.width is not None:
             return self.height, self.width
         else:
-            height, width, _ = self.__getitem__(0).shape
+            width, height = self.__getitem__(0).size
             return height, width

     def __getitem__(self, item):
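The VideoData fix reflects that __getitem__ now returns a PIL image rather than a numpy array: PIL's .size is (width, height), while a numpy frame's .shape is (height, width, channels). For example:

import numpy as np
from PIL import Image

frame = Image.new("RGB", (832, 480))   # PIL: size is (width, height)
print(frame.size)                      # (832, 480)
print(np.asarray(frame).shape)         # (480, 832, 3) -- numpy: (height, width, channels)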
diffsynth/utils/data/audio.py (new file, 109 lines)
@@ -0,0 +1,109 @@
import torch
import torchaudio


def convert_to_mono(audio_tensor: torch.Tensor) -> torch.Tensor:
    """
    Convert audio to mono by averaging channels.
    Supports [C, T] or [B, C, T]. Output shape: [1, T] or [B, 1, T].
    """
    return audio_tensor.mean(dim=-2, keepdim=True)


def convert_to_stereo(audio_tensor: torch.Tensor) -> torch.Tensor:
    """
    Convert audio to stereo.
    Supports [C, T] or [B, C, T]. Duplicates mono, keeps stereo.
    """
    if audio_tensor.size(-2) == 1:
        return audio_tensor.repeat(1, 2, 1) if audio_tensor.dim() == 3 else audio_tensor.repeat(2, 1)
    return audio_tensor


def resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor:
    """Resample waveform to the target sample rate if needed."""
    if source_rate == target_rate:
        return waveform
    resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
    return resampled.to(dtype=waveform.dtype)


def read_audio_with_torchcodec(
    path: str,
    start_time: float = 0,
    duration: float | None = None,
) -> tuple[torch.Tensor, int]:
    """
    Read audio from file natively using torchcodec, with optional start time and duration.

    Args:
        path (str): The file path to the audio file.
        start_time (float, optional): The start time in seconds to read from. Defaults to 0.
        duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None.

    Returns:
        tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate.
            The audio tensor shape is [C, T] where C is the number of channels and T is the number of audio frames.
    """
    from torchcodec.decoders import AudioDecoder
    decoder = AudioDecoder(path)
    stop_seconds = None if duration is None else start_time + duration
    waveform = decoder.get_samples_played_in_range(start_seconds=start_time, stop_seconds=stop_seconds).data
    return waveform, decoder.metadata.sample_rate


def read_audio(
    path: str,
    start_time: float = 0,
    duration: float | None = None,
    resample: bool = False,
    resample_rate: int = 48000,
    backend: str = "torchcodec",
) -> tuple[torch.Tensor, int]:
    """
    Read audio from file, with optional start time, duration, and resampling.

    Args:
        path (str): The file path to the audio file.
        start_time (float, optional): The start time in seconds to read from. Defaults to 0.
        duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None.
        resample (bool, optional): Whether to resample the audio to a different sample rate. Defaults to False.
        resample_rate (int, optional): The target sample rate for resampling if resample is True. Defaults to 48000.
        backend (str, optional): The audio backend to use for reading. Defaults to "torchcodec".

    Returns:
        tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate.
            The audio tensor shape is [C, T] where C is the number of channels and T is the number of audio frames.
    """
    if backend == "torchcodec":
        waveform, sample_rate = read_audio_with_torchcodec(path, start_time, duration)
    else:
        raise ValueError(f"Unsupported audio backend: {backend}")

    if resample:
        waveform = resample_waveform(waveform, sample_rate, resample_rate)
        sample_rate = resample_rate

    return waveform, sample_rate


def save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend: str = "torchcodec"):
    """
    Save audio tensor to file.

    Args:
        waveform (torch.Tensor): The audio tensor to save. Shape can be [C, T] or [B, C, T].
        sample_rate (int): The sample rate of the audio.
        save_path (str): The file path to save the audio to.
        backend (str, optional): The audio backend to use for saving. Defaults to "torchcodec".
    """
    if waveform.dim() == 3:
        waveform = waveform[0]
    waveform = waveform.cpu()  # move to CPU before encoding (the bare `waveform.cpu()` call discarded its result)

    if backend == "torchcodec":
        from torchcodec.encoders import AudioEncoder
        encoder = AudioEncoder(waveform, sample_rate=sample_rate)
        encoder.to_file(dest=save_path)
    else:
        raise ValueError(f"Unsupported audio backend: {backend}")
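A minimal round-trip through the helpers above (a sketch only; the file names are placeholders, and torchcodec must be installed):

from diffsynth.utils.data.audio import read_audio, convert_to_mono, save_audio

waveform, sample_rate = read_audio("input.wav", start_time=1.0, duration=5.0,
                                   resample=True, resample_rate=48000)
print(waveform.shape)  # [C, T], e.g. [2, 240000] for 5 s of stereo at 48 kHz
mono = convert_to_mono(waveform)  # [1, T]
save_audio(mono, sample_rate, "output.wav")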
diffsynth/utils/data/audio_video.py (new file, 134 lines)
@@ -0,0 +1,134 @@
import av
from fractions import Fraction
import torch
from PIL import Image
from tqdm import tqdm
from .audio import convert_to_stereo


def _resample_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
    cc = audio_stream.codec_context

    # Use the encoder's format/layout/rate as the *target*
    target_format = cc.format or "fltp"  # AAC → usually fltp
    target_layout = cc.layout or "stereo"
    target_rate = cc.sample_rate or frame_in.sample_rate

    audio_resampler = av.audio.resampler.AudioResampler(
        format=target_format,
        layout=target_layout,
        rate=target_rate,
    )

    audio_next_pts = 0
    for rframe in audio_resampler.resample(frame_in):
        if rframe.pts is None:
            rframe.pts = audio_next_pts
            audio_next_pts += rframe.samples
        rframe.sample_rate = frame_in.sample_rate
        container.mux(audio_stream.encode(rframe))

    # flush audio encoder
    for packet in audio_stream.encode():
        container.mux(packet)


def _write_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
    if samples.ndim == 1:
        samples = samples.unsqueeze(0)
    samples = convert_to_stereo(samples)
    assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2"
    samples = samples.T
    # Convert to int16 packed for ingestion; the resampler converts to the encoder format.
    if samples.dtype != torch.int16:
        samples = torch.clip(samples, -1.0, 1.0)
        samples = (samples * 32767.0).to(torch.int16)

    frame_in = av.AudioFrame.from_ndarray(
        samples.contiguous().reshape(1, -1).cpu().numpy(),
        format="s16",
        layout="stereo",
    )
    frame_in.sample_rate = audio_sample_rate

    _resample_audio(container, audio_stream, frame_in)


def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
    """
    Prepare the audio stream for writing.
    """
    audio_stream = container.add_stream("aac")
    supported_sample_rates = audio_stream.codec_context.codec.audio_rates
    if supported_sample_rates:
        best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate))
        if best_rate != audio_sample_rate:
            print(f"Using closest supported audio sample rate: {best_rate}")
    else:
        best_rate = audio_sample_rate
    audio_stream.codec_context.sample_rate = best_rate
    audio_stream.codec_context.layout = "stereo"
    audio_stream.codec_context.time_base = Fraction(1, best_rate)
    return audio_stream


def write_video_audio(
    video: list[Image.Image],
    audio: torch.Tensor | None,
    output_path: str,
    fps: int = 24,
    audio_sample_rate: int | None = None,
) -> None:
    """
    Writes a sequence of images and an audio tensor to a video file.

    This function uses PyAV to encode a list of PIL images into a video stream
    and multiplex a PyTorch tensor as the audio stream into the output container.

    Args:
        video (list[Image.Image]): A list of PIL Image objects representing the video frames.
            The length of this list determines the total duration of the video based on the FPS.
        audio (torch.Tensor | None): The audio data as a PyTorch tensor.
            The shape is typically (channels, samples). If no audio is required, pass None.
            channels can be 1 or 2: 1 for mono, 2 for stereo.
        output_path (str): The file path (including extension) where the output video will be saved.
        fps (int, optional): The frame rate (frames per second) for the video. Defaults to 24.
        audio_sample_rate (int | None, optional): The sample rate (e.g., 44100, 48000) for the audio.
            If the audio tensor is provided and this is None, the function attempts to infer the rate
            based on the audio tensor's length and the video duration.

    Raises:
        ValueError: If an audio tensor is provided but the sample rate cannot be determined.
    """
    duration = len(video) / fps
    if audio is not None and audio_sample_rate is None:
        audio_sample_rate = int(audio.shape[-1] / duration)

    width, height = video[0].size
    container = av.open(output_path, mode="w")
    stream = container.add_stream("libx264", rate=int(fps))
    stream.width = width
    stream.height = height
    stream.pix_fmt = "yuv420p"

    if audio is not None:
        if audio_sample_rate is None:
            raise ValueError("audio_sample_rate is required when audio is provided")
        audio_stream = _prepare_audio_stream(container, audio_sample_rate)

    for frame in tqdm(video, total=len(video)):
        frame = av.VideoFrame.from_image(frame)
        for packet in stream.encode(frame):
            container.mux(packet)

    # Flush encoder
    for packet in stream.encode():
        container.mux(packet)

    if audio is not None:
        _write_audio(container, audio_stream, audio, audio_sample_rate)

    container.close()
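And a sketch of muxing frames plus audio with the new writer (placeholder content: one second of 24 fps frames with a 48 kHz stereo sine tone):

import torch
from PIL import Image
from diffsynth.utils.data.audio_video import write_video_audio

frames = [Image.new("RGB", (832, 480), (i, 0, 0)) for i in range(24)]  # 1 s at 24 fps
t = torch.linspace(0, 1, 48000)
audio = torch.sin(2 * torch.pi * 440 * t).repeat(2, 1)  # [2, 48000] stereo A4 tone
write_video_audio(frames, audio, "out.mp4", fps=24, audio_sample_rate=48000)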
@@ -1,113 +1,7 @@
-from fractions import Fraction
-import torch
 import av
-from tqdm import tqdm
-from PIL import Image
 import numpy as np
 from io import BytesIO
-from collections.abc import Generator, Iterator
+from .audio_video import write_video_audio as write_video_audio_ltx2
-
-
-def _resample_audio(
-    container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
-) -> None:
-    cc = audio_stream.codec_context
-
-    # Use the encoder's format/layout/rate as the *target*
-    target_format = cc.format or "fltp"  # AAC → usually fltp
-    target_layout = cc.layout or "stereo"
-    target_rate = cc.sample_rate or frame_in.sample_rate
-
-    audio_resampler = av.audio.resampler.AudioResampler(
-        format=target_format,
-        layout=target_layout,
-        rate=target_rate,
-    )
-
-    audio_next_pts = 0
-    for rframe in audio_resampler.resample(frame_in):
-        if rframe.pts is None:
-            rframe.pts = audio_next_pts
-            audio_next_pts += rframe.samples
-        rframe.sample_rate = frame_in.sample_rate
-        container.mux(audio_stream.encode(rframe))
-
-    # flush audio encoder
-    for packet in audio_stream.encode():
-        container.mux(packet)
-
-
-def _write_audio(
-    container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
-) -> None:
-    if samples.ndim == 1:
-        samples = samples[:, None]
-
-    if samples.shape[1] != 2 and samples.shape[0] == 2:
-        samples = samples.T
-
-    if samples.shape[1] != 2:
-        raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
-
-    # Convert to int16 packed for ingestion; resampler converts to encoder fmt.
-    if samples.dtype != torch.int16:
-        samples = torch.clip(samples, -1.0, 1.0)
-        samples = (samples * 32767.0).to(torch.int16)
-
-    frame_in = av.AudioFrame.from_ndarray(
-        samples.contiguous().reshape(1, -1).cpu().numpy(),
-        format="s16",
-        layout="stereo",
-    )
-    frame_in.sample_rate = audio_sample_rate
-
-    _resample_audio(container, audio_stream, frame_in)
-
-
-def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
-    """
-    Prepare the audio stream for writing.
-    """
-    audio_stream = container.add_stream("aac", rate=audio_sample_rate)
-    audio_stream.codec_context.sample_rate = audio_sample_rate
-    audio_stream.codec_context.layout = "stereo"
-    audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
-    return audio_stream
-
-
-def write_video_audio_ltx2(
-    video: list[Image.Image],
-    audio: torch.Tensor | None,
-    output_path: str,
-    fps: int = 24,
-    audio_sample_rate: int | None = 24000,
-) -> None:
-
-    width, height = video[0].size
-    container = av.open(output_path, mode="w")
-    stream = container.add_stream("libx264", rate=int(fps))
-    stream.width = width
-    stream.height = height
-    stream.pix_fmt = "yuv420p"
-
-    if audio is not None:
-        if audio_sample_rate is None:
-            raise ValueError("audio_sample_rate is required when audio is provided")
-        audio_stream = _prepare_audio_stream(container, audio_sample_rate)
-
-    for frame in tqdm(video, total=len(video)):
-        frame = av.VideoFrame.from_image(frame)
-        for packet in stream.encode(frame):
-            container.mux(packet)
-
-    # Flush encoder
-    for packet in stream.encode():
-        container.mux(packet)
-
-    if audio is not None:
-        _write_audio(container, audio_stream, audio, audio_sample_rate)
-
-    container.close()


 def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None:
@@ -1,4 +1,4 @@
-import torch
+import torch, warnings


 class GeneralLoRALoader:
@@ -26,7 +26,11 @@ class GeneralLoRALoader:
             keys.pop(0)
             keys.pop(-1)
             target_name = ".".join(keys)
-            lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key))
+            # Alpha: Deprecated but retained for compatibility.
+            key_alpha = key.replace(lora_B_key + ".weight", "alpha").replace(lora_B_key + ".default.weight", "alpha")
+            if key_alpha == key or key_alpha not in lora_state_dict:
+                key_alpha = None
+            lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key), key_alpha)
         return lora_name_dict


@@ -36,6 +40,10 @@ class GeneralLoRALoader:
         for name in name_dict:
             weight_up = state_dict[name_dict[name][0]]
             weight_down = state_dict[name_dict[name][1]]
+            if name_dict[name][2] is not None:
+                warnings.warn("Alpha detected in the LoRA file. This may be a LoRA model not trained by DiffSynth-Studio. To ensure compatibility, the LoRA weights will be converted to weight * alpha / rank.")
+                alpha = state_dict[name_dict[name][2]] / weight_down.shape[0]
+                weight_down = weight_down * alpha
             state_dict_[name + f".lora_B{suffix}"] = weight_up
             state_dict_[name + f".lora_A{suffix}"] = weight_down
         return state_dict_
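The compatibility path rescales externally trained LoRAs so that applying lora_B @ lora_A reproduces weight * alpha / rank: with rank r = lora_A.shape[0], folding alpha / r into lora_A gives the same product. A toy check:

import torch

rank, d_in, d_out, alpha = 4, 16, 16, 8.0
lora_A = torch.randn(rank, d_in)   # "weight_down"
lora_B = torch.randn(d_out, rank)  # "weight_up"

expected = (lora_B @ lora_A) * (alpha / rank)
folded_A = lora_A * (alpha / rank)  # what the loader stores as lora_A
print(torch.allclose(lora_B @ folded_A, expected))  # True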
diffsynth/utils/ses/README.md (new file, 1 line)
@@ -0,0 +1 @@
Please see `docs/en/Research_Tutorial/inference_time_scaling.md` or `docs/zh/Research_Tutorial/inference_time_scaling.md` for more details.

diffsynth/utils/ses/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .ses import ses_search
diffsynth/utils/ses/ses.py (new file, 117 lines)
@@ -0,0 +1,117 @@
import torch
import pywt
import numpy as np
from tqdm import tqdm


def split_dwt(z_tensor_cpu, wavelet_name, dwt_level):
    all_clow_np = []
    all_chigh_list = []
    z_tensor_cpu = z_tensor_cpu.cpu().float()

    for i in range(z_tensor_cpu.shape[0]):
        z_numpy_ch = z_tensor_cpu[i].numpy()

        coeffs_ch = pywt.wavedec2(z_numpy_ch, wavelet_name, level=dwt_level, mode='symmetric', axes=(-2, -1))

        clow_np = coeffs_ch[0]
        chigh_list = coeffs_ch[1:]

        all_clow_np.append(clow_np)
        all_chigh_list.append(chigh_list)

    all_clow_tensor = torch.from_numpy(np.stack(all_clow_np, axis=0))
    return all_clow_tensor, all_chigh_list


def reconstruct_dwt(c_low_tensor_cpu, c_high_coeffs, wavelet_name, original_shape):
    H_high, W_high = original_shape
    c_low_tensor_cpu = c_low_tensor_cpu.cpu().float()

    clow_np = c_low_tensor_cpu.numpy()

    if clow_np.ndim == 4 and clow_np.shape[0] == 1:
        clow_np = clow_np[0]

    coeffs_combined = [clow_np] + c_high_coeffs
    z_recon_np = pywt.waverec2(coeffs_combined, wavelet_name, mode='symmetric', axes=(-2, -1))
    if z_recon_np.shape[-2] != H_high or z_recon_np.shape[-1] != W_high:
        z_recon_np = z_recon_np[..., :H_high, :W_high]
    z_recon_tensor = torch.from_numpy(z_recon_np)
    if z_recon_tensor.ndim == 3:
        z_recon_tensor = z_recon_tensor.unsqueeze(0)
    return z_recon_tensor


def ses_search(
    base_latents,
    objective_reward_fn,
    total_eval_budget=30,
    popsize=10,
    k_elites=5,
    wavelet_name="db1",
    dwt_level=4,
):
    latent_h, latent_w = base_latents.shape[-2], base_latents.shape[-1]
    c_low_init, c_high_fixed_batch = split_dwt(base_latents, wavelet_name, dwt_level)
    c_high_fixed = c_high_fixed_batch[0]
    c_low_shape = c_low_init.shape[1:]
    mu = torch.zeros_like(c_low_init.view(-1).cpu())
    sigma_sq = torch.ones_like(mu) * 1.0

    best_overall = {"fitness": -float('inf'), "score": -float('inf'), "c_low": c_low_init[0]}
    eval_count = 0

    elite_db = []
    n_generations = (total_eval_budget // popsize) + 5
    pbar = tqdm(total=total_eval_budget, desc="[SES] Searching", unit="img")

    for gen in range(n_generations):
        if eval_count >= total_eval_budget:
            break

        std = torch.sqrt(torch.clamp(sigma_sq, min=1e-9))
        z_noise = torch.randn(popsize, mu.shape[0])
        samples_flat = mu + z_noise * std
        samples_reshaped = samples_flat.view(popsize, *c_low_shape)

        batch_results = []

        for i in range(popsize):
            if eval_count >= total_eval_budget:
                break

            c_low_sample = samples_reshaped[i].unsqueeze(0)
            z_recon = reconstruct_dwt(c_low_sample, c_high_fixed, wavelet_name, (latent_h, latent_w))
            z_recon = z_recon.to(base_latents.device, dtype=base_latents.dtype)
            score = objective_reward_fn(z_recon)
            res = {
                "score": score,
                "c_low": c_low_sample.cpu()
            }
            batch_results.append(res)
            if score > best_overall['score']:
                best_overall = res

            eval_count += 1
            pbar.update(1)

        if not batch_results:
            break
        elite_db.extend(batch_results)
        elite_db.sort(key=lambda x: x['score'], reverse=True)
        elite_db = elite_db[:k_elites]
        elites_flat = torch.stack([x['c_low'].view(-1) for x in elite_db])
        mu_new = torch.mean(elites_flat, dim=0)

        if len(elite_db) > 1:
            sigma_sq_new = torch.var(elites_flat, dim=0, unbiased=True) + 1e-7
        else:
            sigma_sq_new = sigma_sq
        mu = mu_new
        sigma_sq = sigma_sq_new
    pbar.close()
    best_c_low = best_overall['c_low']
    final_latents = reconstruct_dwt(best_c_low, c_high_fixed, wavelet_name, (latent_h, latent_w))

    return final_latents.to(base_latents.device, dtype=base_latents.dtype)
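ses_search is a simple (mu, sigma) evolutionary search over the low-frequency DWT band of an initial latent: sample a population, score each reconstruction with the reward function, refit the mean and variance on the elites, and repeat until the evaluation budget is spent. A toy invocation with a synthetic reward (a sketch only; a real reward would render the latent and score the resulting image):

import torch
from diffsynth.utils.ses import ses_search

target = torch.randn(1, 4, 64, 64)  # toy "ideal" latent, [B, C, H, W]

def reward(latents: torch.Tensor) -> float:
    # Toy reward: negative distance to the target latent.
    return -torch.mean((latents - target) ** 2).item()

base = torch.randn(1, 4, 64, 64)
best = ses_search(base, reward, total_eval_budget=30, popsize=10, k_elites=5)
print(best.shape)  # same shape as the input latents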
@@ -0,0 +1,13 @@
def AceStepConditionEncoderStateDictConverter(state_dict):
    new_state_dict = {}
    prefix = "encoder."

    for key in state_dict:
        if key.startswith(prefix):
            new_key = key[len(prefix):]
            new_state_dict[new_key] = state_dict[key]

    if "null_condition_emb" in state_dict:
        new_state_dict["null_condition_emb"] = state_dict["null_condition_emb"]

    return new_state_dict

diffsynth/utils/state_dict_converters/ace_step_dit.py (new file, 10 lines)
@@ -0,0 +1,10 @@
def AceStepDiTModelStateDictConverter(state_dict):
    new_state_dict = {}
    prefix = "decoder."

    for key in state_dict:
        if key.startswith(prefix):
            new_key = key[len(prefix):]
            new_state_dict[new_key] = state_dict[key]

    return new_state_dict
@@ -0,0 +1,15 @@
def AceStepTextEncoderStateDictConverter(state_dict):
    new_state_dict = {}
    prefix = "model."
    nested_prefix = "model.model."

    for key in state_dict:
        if key.startswith(nested_prefix):
            new_key = key
        elif key.startswith(prefix):
            new_key = "model." + key
        else:
            new_key = "model." + key
        new_state_dict[new_key] = state_dict[key]

    return new_state_dict

@@ -0,0 +1,8 @@
def AceStepTokenizerStateDictConverter(state_dict):
    new_state_dict = {}

    for key in state_dict:
        if key.startswith("tokenizer.") or key.startswith("detokenizer."):
            new_state_dict[key] = state_dict[key]

    return new_state_dict
diffsynth/utils/state_dict_converters/anima_dit.py (new file, 6 lines)
@@ -0,0 +1,6 @@
def AnimaDiTStateDictConverter(state_dict):
    new_state_dict = {}
    for key in state_dict:
        value = state_dict[key]
        new_state_dict[key.replace("net.", "")] = value
    return new_state_dict
@@ -0,0 +1,21 @@
def ErnieImageTextEncoderStateDictConverter(state_dict):
    """
    Maps checkpoint keys from the multimodal Mistral3Model format
    to the text-only Ministral3Model format.

    Checkpoint keys (Mistral3Model):
        language_model.model.layers.0.input_layernorm.weight
        language_model.model.norm.weight

    Model keys (ErnieImageTextEncoder → self.model = Ministral3Model):
        model.layers.0.input_layernorm.weight
        model.norm.weight

    Mapping: language_model. → model.
    """
    new_state_dict = {}
    for key in state_dict:
        if key.startswith("language_model.model."):
            new_key = key.replace("language_model.model.", "model.", 1)
            new_state_dict[new_key] = state_dict[key]
    return new_state_dict
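These converters are pure key-remapping functions, so they can be sanity-checked on a toy state dict. A minimal illustration (the two keys below are taken from the docstring above; the tensors are dummies):

```python
import torch

checkpoint = {
    "language_model.model.layers.0.input_layernorm.weight": torch.ones(8),
    "language_model.model.norm.weight": torch.ones(8),
}
converted = ErnieImageTextEncoderStateDictConverter(checkpoint)
print(sorted(converted))
# ['model.layers.0.input_layernorm.weight', 'model.norm.weight']
```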
@@ -0,0 +1,20 @@
def JoyAIImageTextEncoderStateDictConverter(state_dict):
    """Convert HuggingFace Qwen3VL checkpoint keys to DiffSynth wrapper keys.

    Mapping (checkpoint -> wrapper):
    - lm_head.weight -> model.lm_head.weight
    - model.language_model.* -> model.model.language_model.*
    - model.visual.* -> model.model.visual.*
    """
    state_dict_ = {}
    for key in state_dict:
        if key == "lm_head.weight":
            new_key = "model.lm_head.weight"
        elif key.startswith("model.language_model."):
            new_key = "model.model." + key[len("model."):]
        elif key.startswith("model.visual."):
            new_key = "model.model." + key[len("model."):]
        else:
            new_key = key
        state_dict_[new_key] = state_dict[key]
    return state_dict_
@@ -27,6 +27,6 @@ def LTX2VocoderStateDictConverter(state_dict):
     state_dict_ = {}
     for name in state_dict:
         if name.startswith("vocoder."):
-            new_name = name.replace("vocoder.", "")
+            new_name = name[len("vocoder."):]
             state_dict_[new_name] = state_dict[name]
     return state_dict_
@@ -6,6 +6,7 @@ def LTX2VideoEncoderStateDictConverter(state_dict):
             state_dict_[new_name] = state_dict[name]
         elif name.startswith("vae.per_channel_statistics."):
             new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
+            if new_name not in ["per_channel_statistics.channel", "per_channel_statistics.mean-of-stds", "per_channel_statistics.mean-of-stds_over_std-of-means"]:
                 state_dict_[new_name] = state_dict[name]
     return state_dict_
@@ -18,5 +19,6 @@ def LTX2VideoDecoderStateDictConverter(state_dict):
             state_dict_[new_name] = state_dict[name]
         elif name.startswith("vae.per_channel_statistics."):
             new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
+            if new_name not in ["per_channel_statistics.channel", "per_channel_statistics.mean-of-stds", "per_channel_statistics.mean-of-stds_over_std-of-means"]:
                 state_dict_[new_name] = state_dict[name]
     return state_dict_
diffsynth/utils/state_dict_converters/z_image_dit.py (new file, 3 lines)
@@ -0,0 +1,3 @@
def ZImageDiTStateDictConverter(state_dict):
    state_dict_ = {name.replace("model.diffusion_model.", ""): state_dict[name] for name in state_dict}
    return state_dict_
@@ -1 +1 @@
-from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp
+from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, usp_vace_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks
@@ -117,6 +117,39 @@ def usp_dit_forward(self,
     return x
 
 
+def usp_vace_forward(
+    self, x, vace_context, context, t_mod, freqs,
+    use_gradient_checkpointing: bool = False,
+    use_gradient_checkpointing_offload: bool = False,
+):
+    # Compute full sequence length from the sharded x
+    full_seq_len = x.shape[1] * get_sequence_parallel_world_size()
+
+    # Embed vace_context via patch embedding
+    c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
+    c = [u.flatten(2).transpose(1, 2) for u in c]
+    c = torch.cat([
+        torch.cat([u, u.new_zeros(1, full_seq_len - u.size(1), u.size(2))],
+                  dim=1) for u in c
+    ])
+
+    # Chunk VACE context along sequence dim BEFORE processing through blocks
+    c = torch.chunk(c, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+
+    # Process through vace_blocks (self_attn already monkey-patched to usp_attn_forward)
+    for block in self.vace_blocks:
+        c = gradient_checkpoint_forward(
+            block,
+            use_gradient_checkpointing,
+            use_gradient_checkpointing_offload,
+            c, x, context, t_mod, freqs
+        )
+
+    # Hints are already sharded per-rank
+    hints = torch.unbind(c)[:-1]
+    return hints
+
+
 def usp_attn_forward(self, x, freqs):
     q = self.norm_q(self.q(x))
     k = self.norm_k(self.k(x))
@@ -144,3 +177,30 @@ def usp_attn_forward(self, x, freqs):
     del q, k, v
     getattr(torch, parse_device_type(x.device)).empty_cache()
     return self.o(x)
+
+
+def get_current_chunk(x, dim=1):
+    chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=dim)
+    ndims = len(chunks[0].shape)
+    pad_list = [0] * (2 * ndims)
+    pad_end_index = 2 * (ndims - 1 - dim) + 1
+    max_size = chunks[0].size(dim)
+    chunks = [
+        torch.nn.functional.pad(
+            chunk,
+            tuple(pad_list[:pad_end_index] + [max_size - chunk.size(dim)] + pad_list[pad_end_index+1:]),
+            value=0
+        )
+        for chunk in chunks
+    ]
+    x = chunks[get_sequence_parallel_rank()]
+    return x
+
+
+def gather_all_chunks(x, seq_len=None, dim=1):
+    x = get_sp_group().all_gather(x, dim=dim)
+    if seq_len is not None:
+        slices = [slice(None)] * x.ndim
+        slices[dim] = slice(0, seq_len)
+        x = x[tuple(slices)]
+    return x
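`get_current_chunk` right-pads every chunk to the size of the first one, so each rank holds an equally shaped shard even when the sequence length does not divide evenly; `gather_all_chunks` undoes this by all-gathering and trimming back to the true sequence length. The padding arithmetic can be checked without a distributed group; a sketch with the world size and rank fixed by hand (the distributed helpers are emulated with a plain concatenation):

```python
import torch
import torch.nn.functional as F

world_size, rank = 4, 3
x = torch.randn(1, 10, 8)  # sequence length 10 does not divide evenly by 4

chunks = list(torch.chunk(x, world_size, dim=1))  # chunk sizes: 3, 3, 3, 1
max_size = chunks[0].size(1)
# Right-pad each chunk along dim=1 up to the size of the first chunk.
chunks = [F.pad(c, (0, 0, 0, max_size - c.size(1)), value=0) for c in chunks]
shard = chunks[rank]  # every rank now holds a (1, 3, 8) shard

# all_gather + trim to the true sequence length, emulated on one process:
gathered = torch.cat(chunks, dim=1)[:, :10]
assert torch.allclose(gathered, x)
```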
docs/en/Model_Details/ACE-Step.md (new file, 164 lines)
@@ -0,0 +1,164 @@
# ACE-Step

ACE-Step 1.5 is an open-source music generation model based on the DiT architecture. It supports text-to-music, audio cover, repainting, and other tasks, and runs efficiently on consumer-grade hardware.

## Installation

Before performing model inference and training, please install DiffSynth-Studio first.

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).

## Quick Start

Running the following code will load the [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) model for inference. VRAM management is enabled, so the framework automatically controls parameter loading based on available VRAM; a minimum of 3GB of VRAM is required.

```python
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch


vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}

pipe = AceStepPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
        ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
        ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)

prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
    prompt=prompt,
    lyrics=lyrics,
    duration=160,
    bpm=100,
    keyscale="B minor",
    timesignature="4",
    vocal_language="zh",
    seed=42,
)

save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
```

## Model Overview

|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|

## Model Inference

The model is loaded via `AceStepPipeline.from_pretrained`; see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.

The input parameters for `AceStepPipeline` inference include:

* `prompt`: Text description of the music.
* `cfg_scale`: Classifier-free guidance scale, defaults to 1.0.
* `lyrics`: Lyrics text.
* `task_type`: Task type; optional values are `"text2music"` (text-to-music), `"cover"` (audio cover), and `"repaint"` (repainting). Defaults to `"text2music"`.
* `reference_audios`: List of reference audio tensors for timbre reference.
* `src_audio`: Source audio tensor for cover or repaint tasks.
* `denoising_strength`: Denoising strength, controlling how much the output is influenced by the source audio, defaults to 1.0.
* `audio_cover_strength`: Audio cover step ratio, controlling how many steps use the cover condition in cover tasks, defaults to 1.0.
* `audio_code_string`: Input audio code string for cover tasks with discrete audio codes.
* `repainting_ranges`: List of repainting time ranges (tuples of floats, in seconds) for repaint tasks.
* `repainting_strength`: Repainting intensity, controlling the degree of change in repainted areas, defaults to 1.0.
* `duration`: Audio duration in seconds, defaults to 60.
* `bpm`: Beats per minute, defaults to 100.
* `keyscale`: Musical key scale, defaults to "B minor".
* `timesignature`: Time signature, defaults to "4".
* `vocal_language`: Vocal language, defaults to "unknown".
* `seed`: Random seed.
* `rand_device`: Device for noise generation, defaults to "cpu".
* `num_inference_steps`: Number of inference steps, defaults to 8.
* `shift`: Timestep shift parameter for the scheduler, defaults to 1.0.
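The repaint-specific parameters combine as in the following minimal sketch, reusing `pipe`, `prompt`, and `lyrics` from the quick start (how `src_audio` is loaded, and its exact tensor layout, are assumptions left to the reader):

```python
# src_audio: a previously loaded waveform tensor (loading code omitted;
# the expected tensor layout is an assumption, not verified here).
audio = pipe(
    prompt=prompt,
    lyrics=lyrics,
    task_type="repaint",
    src_audio=src_audio,
    repainting_ranges=[(30.0, 45.0)],  # regenerate only the 30s-45s window
    repainting_strength=0.8,
    duration=160,
    seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-repaint.wav")
```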
## Model Training

Models in the ACE-Step series are trained uniformly via `examples/ace_step/model_training/train.py`. The script parameters include:

* General Training Parameters
    * Dataset Configuration
        * `--dataset_base_path`: Root directory of the dataset.
        * `--dataset_metadata_path`: Path to the dataset metadata file.
        * `--dataset_repeat`: Number of dataset repeats per epoch.
        * `--dataset_num_workers`: Number of processes per DataLoader.
        * `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
    * Model Loading Configuration
        * `--model_paths`: Paths to load models from, in JSON format.
        * `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
        * `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
        * `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
    * Basic Training Configuration
        * `--learning_rate`: Learning rate.
        * `--num_epochs`: Number of epochs.
        * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
        * `--find_unused_parameters`: Whether unused parameters exist in DDP training.
        * `--weight_decay`: Weight decay magnitude.
        * `--task`: Training task, defaults to `sft`.
    * Output Configuration
        * `--output_path`: Path to save the model.
        * `--remove_prefix_in_ckpt`: Remove a prefix in the model's state dict.
        * `--save_steps`: Interval in training steps at which to save the model.
    * LoRA Configuration
        * `--lora_base_model`: Which model to add LoRA to.
        * `--lora_target_modules`: Which layers to add LoRA to.
        * `--lora_rank`: Rank of the LoRA.
        * `--lora_checkpoint`: Path to a LoRA checkpoint.
        * `--preset_lora_path`: Path to a preset LoRA checkpoint for LoRA differential training.
        * `--preset_lora_model`: Which model to integrate the preset LoRA into, e.g., `dit`.
    * Gradient Configuration
        * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
        * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
        * `--gradient_accumulation_steps`: Number of gradient accumulation steps.
    * Resolution Configuration
        * `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
        * `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
        * `--max_pixels`: Maximum pixel area; during dynamic-resolution training, larger images will be scaled down.
        * `--num_frames`: Number of frames for video (video generation models only).
* ACE-Step Specific Parameters
    * `--tokenizer_path`: Tokenizer path, in the format `model_id:origin_pattern`.
    * `--silence_latent_path`: Silence latent path, in the format `model_id:origin_pattern`.
    * `--initialize_model_on_cpu`: Whether to initialize models on the CPU.

### Example Dataset

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```

We provide recommended training scripts for each model; please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
docs/en/Model_Details/Anima.md (new file, 139 lines)
@@ -0,0 +1,139 @@
# Anima

Anima is an image generation model trained and open-sourced by CircleStone Labs and Comfy Org.

## Installation

Before using this project for model inference and training, please install DiffSynth-Studio first.

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

For more installation information, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).

## Quick Start

The following code demonstrates how to quickly load the [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) model for inference. VRAM management is enabled by default, allowing the framework to automatically control model parameter loading based on available VRAM. A minimum of 8GB of VRAM is required.

```python
from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
import torch

vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": "disk",
    "onload_device": "disk",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = AnimaImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config),
        ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config),
        ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
    tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
image = pipe(prompt, negative_prompt=negative_prompt, seed=0, num_inference_steps=50)
image.save("image.jpg")
```

## Model Overview

|Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|
|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)|

Special training scripts:

* Differential LoRA Training: [doc](../Training/Differential_LoRA.md)
* FP8 Precision Training: [doc](../Training/FP8_Precision.md)
* Two-Stage Split Training: [doc](../Training/Split_Training.md)
* End-to-End Direct Distillation: [doc](../Training/Direct_Distill.md)

## Model Inference

Models are loaded through `AnimaImagePipeline.from_pretrained`; see [Model Inference](../Pipeline_Usage/Model_Inference.md#loading-models) for details.

Input parameters for `AnimaImagePipeline` inference include:

* `prompt`: Text description of the desired image content.
* `negative_prompt`: Content to exclude from the generated image (default: `""`).
* `cfg_scale`: Classifier-free guidance parameter (default: 4.0).
* `input_image`: Input image for image-to-image generation (default: `None`).
* `denoising_strength`: Controls similarity to the input image (default: 1.0).
* `height`: Image height (must be a multiple of 16, default: 1024).
* `width`: Image width (must be a multiple of 16, default: 1024).
* `seed`: Random seed (default: `None`).
* `rand_device`: Device for random noise generation (default: `"cpu"`).
* `num_inference_steps`: Inference steps (default: 30).
* `sigma_shift`: Scheduler sigma offset (default: `None`).
* `progress_bar_cmd`: Progress bar implementation (default: `tqdm.tqdm`).
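The `input_image` and `denoising_strength` parameters enable image-to-image generation. A minimal sketch, reusing `pipe` and `prompt` from the quick start (the input file name is a placeholder):

```python
from PIL import Image

init_image = Image.open("image.jpg")  # e.g., the text-to-image output above
image = pipe(
    prompt,
    input_image=init_image,
    denoising_strength=0.6,  # lower values stay closer to the input image
    seed=1,
    num_inference_steps=50,
)
image.save("image_edited.jpg")
```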
For VRAM constraints, enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). Recommended low-VRAM configurations are provided in the "Model Overview" table above.

## Model Training

Anima models are trained through [`examples/anima/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/train.py) with parameters including:

* General Training Parameters
    * Dataset Configuration
        * `--dataset_base_path`: Dataset root directory.
        * `--dataset_metadata_path`: Metadata file path.
        * `--dataset_repeat`: Dataset repetition per epoch.
        * `--dataset_num_workers`: Dataloader worker count.
        * `--data_file_keys`: Metadata fields to load (comma-separated).
    * Model Loading
        * `--model_paths`: Model paths (JSON format).
        * `--model_id_with_origin_paths`: Model IDs with origin paths (e.g., `"anima-team/anima-1B:text_encoder/*.safetensors"`).
        * `--extra_inputs`: Additional pipeline inputs (e.g., `controlnet_inputs` for ControlNet).
        * `--fp8_models`: FP8-formatted models (same format as `--model_paths`).
    * Training Configuration
        * `--learning_rate`: Learning rate.
        * `--num_epochs`: Training epochs.
        * `--trainable_models`: Trainable components (e.g., `dit`, `vae`, `text_encoder`).
        * `--find_unused_parameters`: Handle unused parameters in DDP training.
        * `--weight_decay`: Weight decay value.
        * `--task`: Training task (default: `sft`).
    * Output Configuration
        * `--output_path`: Model output directory.
        * `--remove_prefix_in_ckpt`: Remove state dict prefixes.
        * `--save_steps`: Model saving interval.
    * LoRA Configuration
        * `--lora_base_model`: Target model for LoRA.
        * `--lora_target_modules`: Target modules for LoRA.
        * `--lora_rank`: LoRA rank.
        * `--lora_checkpoint`: LoRA checkpoint path.
        * `--preset_lora_path`: Preloaded LoRA checkpoint path.
        * `--preset_lora_model`: Model to merge LoRA with (e.g., `dit`).
    * Gradient Configuration
        * `--use_gradient_checkpointing`: Enable gradient checkpointing.
        * `--use_gradient_checkpointing_offload`: Offload checkpointing to CPU.
        * `--gradient_accumulation_steps`: Gradient accumulation steps.
    * Image Resolution
        * `--height`: Image height (empty for dynamic resolution).
        * `--width`: Image width (empty for dynamic resolution).
        * `--max_pixels`: Maximum pixel area for dynamic resolution.
* Anima-Specific Parameters
    * `--tokenizer_path`: Tokenizer path for text-to-image models.
    * `--tokenizer_t5xxl_path`: T5-XXL tokenizer path.

We provide a sample image dataset for testing:

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```

For training script details, refer to [Model Training](../Pipeline_Usage/Model_Training.md). For advanced training techniques, see [Training Framework Documentation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/).
docs/en/Model_Details/ERNIE-Image.md (new file, 134 lines)
@@ -0,0 +1,134 @@
# ERNIE-Image

ERNIE-Image is a powerful 8B-parameter image generation model developed by Baidu, featuring a compact, efficient architecture with strong instruction-following capability. Built on an 8B DiT backbone, it delivers performance comparable to larger (20B+) models in certain scenarios while maintaining parameter efficiency, and it offers reliable instruction understanding and execution, text generation (English/Chinese/Japanese), and overall stability.

## Installation

Before performing model inference and training, please install DiffSynth-Studio first.

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).

## Quick Start

Running the following code will load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model for inference. VRAM management is enabled, so the framework automatically controls parameter loading based on available VRAM; a minimum of 3GB of VRAM is required.

```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch

vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)

image = pipe(
    prompt="一只黑白相间的中华田园犬",
    negative_prompt="",
    height=1024,
    width=1024,
    seed=42,
    num_inference_steps=50,
    cfg_scale=4.0,
)
image.save("output.jpg")
```

## Model Overview

|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|

## Model Inference

The model is loaded via `ErnieImagePipeline.from_pretrained`; see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.

The input parameters for `ErnieImagePipeline` inference include:

* `prompt`: The prompt describing the content to appear in the image.
* `negative_prompt`: The negative prompt describing what should not appear in the image, default value is `""`.
* `cfg_scale`: Classifier-free guidance parameter, default value is 4.0.
* `height`: Image height, must be a multiple of 16, default value is 1024.
* `width`: Image width, must be a multiple of 16, default value is 1024.
* `seed`: Random seed. Default is `None`, meaning completely random.
* `rand_device`: The computing device for generating random Gaussian noise matrices, default is `"cuda"`. When set to `"cuda"`, different GPU models may produce different results.
* `num_inference_steps`: Number of inference steps, default value is 50.
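Because `rand_device` defaults to `"cuda"`, the initial noise, and therefore the output, can differ across GPU models even with a fixed seed; setting it to `"cpu"` makes a seeded run device-independent. A minimal sketch, reusing `pipe` from the quick start:

```python
image = pipe(
    prompt="一只黑白相间的中华田园犬",
    height=1024,
    width=1024,
    seed=42,
    rand_device="cpu",  # CPU-generated noise reproduces across different GPUs
    num_inference_steps=50,
)
image.save("output_reproducible.jpg")
```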
If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low-VRAM configurations for each model in the "Model Overview" table above.

## Model Training

ERNIE-Image series models are trained uniformly via [`examples/ernie_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/train.py). The script parameters include:

* General Training Parameters
    * Dataset Configuration
        * `--dataset_base_path`: Root directory of the dataset.
        * `--dataset_metadata_path`: Path to the dataset metadata file.
        * `--dataset_repeat`: Number of dataset repeats per epoch.
        * `--dataset_num_workers`: Number of processes per DataLoader.
        * `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
    * Model Loading Configuration
        * `--model_paths`: Paths to load models from, in JSON format.
        * `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"PaddlePaddle/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors"`, separated by commas.
        * `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
        * `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
    * Basic Training Configuration
        * `--learning_rate`: Learning rate.
        * `--num_epochs`: Number of epochs.
        * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
        * `--find_unused_parameters`: Whether unused parameters exist in DDP training.
        * `--weight_decay`: Weight decay magnitude.
        * `--task`: Training task, defaults to `sft`.
    * Output Configuration
        * `--output_path`: Path to save the model.
        * `--remove_prefix_in_ckpt`: Remove a prefix in the model's state dict.
        * `--save_steps`: Interval in training steps at which to save the model.
    * LoRA Configuration
        * `--lora_base_model`: Which model to add LoRA to.
        * `--lora_target_modules`: Which layers to add LoRA to.
        * `--lora_rank`: Rank of the LoRA.
        * `--lora_checkpoint`: Path to a LoRA checkpoint.
        * `--preset_lora_path`: Path to a preset LoRA checkpoint for LoRA differential training.
        * `--preset_lora_model`: Which model to integrate the preset LoRA into, e.g., `dit`.
    * Gradient Configuration
        * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
        * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
        * `--gradient_accumulation_steps`: Number of gradient accumulation steps.
    * Resolution Configuration
        * `--height`: Height of the image. Leave empty to enable dynamic resolution.
        * `--width`: Width of the image. Leave empty to enable dynamic resolution.
        * `--max_pixels`: Maximum pixel area; during dynamic-resolution training, larger images will be scaled down.
* ERNIE-Image Specific Parameters
    * `--tokenizer_path`: Path to the tokenizer; leave empty to auto-download it from the remote repository.

We provide an example image dataset for testing, which can be downloaded with the following command:

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```

We provide recommended training scripts for each model; please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
@@ -81,27 +81,27 @@ graph LR;
|
|||||||
|
|
||||||
| Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
| Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||||
| - | - | - | - | - | - | - | - |
|
| - | - | - | - | - | - | - | - |
|
||||||
| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py) |
|
| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev.py) |
|
||||||
| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) |
|
| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) |
|
||||||
| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) |
|
| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) |
|
||||||
| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) |
|
| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) |
|
||||||
| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) |
|
| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) |
|
||||||
| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) |
|
| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) |
|
||||||
| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) |
|
| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) |
| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) |
| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) |
| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - |
| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - |
| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Step1X-Edit.py) |
| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLEX.2-preview.py) |
| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Nexus-Gen.py) |

Special Training Scripts:

* Differential LoRA Training: [doc](../Training/Differential_LoRA.md)
* FP8 Precision Training: [doc](../Training/FP8_Precision.md)
* Two-stage Split Training: [doc](../Training/Split_Training.md)
* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md)

## Model Inference

@@ -147,7 +147,7 @@ If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_

## Model Training

FLUX series models are uniformly trained through [`examples/flux/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/train.py), and the script parameters include:

* General Training Parameters
  * Dataset Basic Configuration

@@ -195,7 +195,7 @@ FLUX series models are uniformly trained through [`examples/flux/model_training/

We have built a sample image dataset for your testing. You can download this dataset with the following command:

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```

We have written recommended training scripts for each model; please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

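To make the parameter groups above concrete, here is a minimal, hypothetical LoRA invocation of the training script. The launcher, file pattern, target-module names, and hyperparameter values are illustrative assumptions, not the recommended settings; use the scripts linked in the table above for real runs.

```shell
# Hypothetical sketch only: the launcher, file pattern, target modules,
# and hyperparameters below are placeholders, not recommended settings.
accelerate launch examples/flux/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset \
  --dataset_metadata_path data/diffsynth_example_dataset/metadata.csv \
  --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 1 \
  --lora_base_model dit \
  --lora_target_modules "to_q,to_k,to_v" \
  --lora_rank 32 \
  --output_path ./models/train/FLUX.1-dev_lora
```
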
@@ -61,11 +61,11 @@ image.save("image.jpg")

| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|

Special Training Scripts:

@@ -99,7 +99,7 @@ If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_

## Model Training

FLUX.2 series models are uniformly trained through [`examples/flux2/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/train.py), and the script parameters include:

* General Training Parameters
  * Dataset Basic Configuration

@@ -145,7 +145,7 @@ FLUX.2 series models are uniformly trained through [`examples/flux2/model_traini

We have built a sample image dataset for your testing. You can download this dataset with the following command:

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```

We have written recommended training scripts for each model; please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

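Analogously, a hypothetical full-parameter run of the FLUX.2 training script might be sketched as follows; again, every value here is a placeholder rather than a recommendation, and the linked `full/*.sh` scripts carry the tested settings.

```shell
# Hypothetical sketch only: all flag values are placeholders.
accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset \
  --dataset_metadata_path data/diffsynth_example_dataset/metadata.csv \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:*.safetensors" \
  --trainable_models dit \
  --learning_rate 1e-5 \
  --num_epochs 1 \
  --use_gradient_checkpointing \
  --output_path ./models/train/FLUX.2-klein-4B_full
```
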

154
docs/en/Model_Details/JoyAI-Image.md
Normal file
@@ -0,0 +1,154 @@

# JoyAI-Image

JoyAI-Image is a unified multi-modal foundation model open-sourced by JD.com, supporting image understanding, text-to-image generation, and instruction-guided image editing.

## Installation

Before performing model inference and training, please install DiffSynth-Studio first.

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).

## Quick Start

Running the following code will load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model for inference. VRAM management is enabled; the framework automatically controls parameter loading based on available VRAM, and a minimum of 4GB of VRAM is required.

```python
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download

# Download dataset
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
    local_dir="data/diffsynth_example_dataset",
    allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)

vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}

pipe = JoyAIImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
    ],
    processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)

# Use first sample from dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"  # "Change the dress to pink"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")

output = pipe(
    prompt=prompt,
    edit_image=edit_image,
    height=1024,
    width=1024,
    seed=0,
    num_inference_steps=30,
    cfg_scale=5.0,
)

output.save("output_joyai_edit_low_vram.png")
```

## Model Overview

|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|

## Model Inference

The model is loaded via `JoyAIImagePipeline.from_pretrained`; see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.

The input parameters for `JoyAIImagePipeline` inference include:

* `prompt`: Text prompt describing the desired image editing effect.
* `negative_prompt`: Negative prompt specifying what should not appear in the result; defaults to an empty string.
* `cfg_scale`: Classifier-free guidance scale factor, defaults to 5.0. Higher values make the output follow the prompt more closely.
* `edit_image`: Image to be edited.
* `denoising_strength`: Denoising strength controlling how much the input image is repainted, defaults to 1.0.
* `height`: Height of the output image, defaults to 1024. Must be divisible by 16.
* `width`: Width of the output image, defaults to 1024. Must be divisible by 16.
* `seed`: Random seed for reproducibility. Set to `None` for a random seed.
* `max_sequence_length`: Maximum sequence length for the text encoder, defaults to 4096.
* `num_inference_steps`: Number of inference steps, defaults to 30. More steps typically yield better quality.
* `tiled`: Whether to enable tiling for reduced VRAM usage, defaults to False.
* `tile_size`: Tile size, defaults to (30, 52).
* `tile_stride`: Tile stride, defaults to (15, 26).
* `shift`: Shift parameter for the scheduler, controlling the Flow Match scheduling curve, defaults to 4.0.
* `progress_bar_cmd`: Progress bar display mode, defaults to `tqdm`.

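As a small usage sketch, tiling can be combined with the parameters above to trade speed for peak VRAM. The snippet below reuses `pipe`, `prompt`, and `edit_image` from the Quick Start; the tile values simply restate the documented defaults.

```python
# Sketch: assumes `pipe`, `prompt`, and `edit_image` from the Quick Start above.
output = pipe(
    prompt=prompt,
    edit_image=edit_image,
    height=1024, width=1024,                   # both must be divisible by 16
    num_inference_steps=30,                    # more steps typically yield better quality
    cfg_scale=5.0,                             # higher -> follows the prompt more closely
    tiled=True,                                # enable tiling to reduce VRAM usage
    tile_size=(30, 52), tile_stride=(15, 26),  # the documented defaults
    seed=0,                                    # fixed seed for reproducibility
)
output.save("output_tiled.png")
```
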
## Model Training

Models in the joyai_image series are trained uniformly via `examples/joyai_image/model_training/train.py`. The script parameters include (a minimal illustrative invocation is sketched after the dataset download below):

* General Training Parameters
  * Dataset Configuration
    * `--dataset_base_path`: Root directory of the dataset.
    * `--dataset_metadata_path`: Path to the dataset metadata file.
    * `--dataset_repeat`: Number of dataset repeats per epoch.
    * `--dataset_num_workers`: Number of processes per DataLoader.
    * `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
  * Model Loading Configuration
    * `--model_paths`: Paths to load models from, in JSON format.
    * `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
    * `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
    * `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
  * Basic Training Configuration
    * `--learning_rate`: Learning rate.
    * `--num_epochs`: Number of epochs.
    * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
    * `--find_unused_parameters`: Whether unused parameters exist in DDP training.
    * `--weight_decay`: Weight decay magnitude.
    * `--task`: Training task, defaults to `sft`.
  * Output Configuration
    * `--output_path`: Path to save the model.
    * `--remove_prefix_in_ckpt`: Remove a prefix in the model's state dict.
    * `--save_steps`: Interval in training steps to save the model.
  * LoRA Configuration
    * `--lora_base_model`: Which model to add LoRA to.
    * `--lora_target_modules`: Which layers to add LoRA to.
    * `--lora_rank`: Rank of LoRA.
    * `--lora_checkpoint`: Path to a LoRA checkpoint.
    * `--preset_lora_path`: Path to a preset LoRA checkpoint for LoRA differential training.
    * `--preset_lora_model`: Which model to integrate the preset LoRA into, e.g., `dit`.
  * Gradient Configuration
    * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
    * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
    * `--gradient_accumulation_steps`: Number of gradient accumulation steps.
  * Resolution Configuration
    * `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
    * `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
    * `--max_pixels`: Maximum pixel area; images larger than this will be scaled down during dynamic resolution.
    * `--num_frames`: Number of frames for video (video generation models only).
* JoyAI-Image Specific Parameters
  * `--processor_path`: Path to the processor for processing text and image encoder inputs.
  * `--initialize_model_on_cpu`: Whether to initialize models on CPU. By default, models are initialized on the accelerator device.

We have built a sample image dataset for your testing. You can download this dataset with the following command:

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```

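A minimal LoRA fine-tuning invocation assembled purely from the parameters listed above might look like this. The launcher, metadata filename, file pattern, target modules, and hyperparameters are illustrative assumptions; the linked `lora/JoyAI-Image-Edit.sh` script carries the recommended settings.

```shell
# Hypothetical sketch only: all flag values are placeholders.
accelerate launch examples/joyai_image/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset \
  --dataset_metadata_path data/diffsynth_example_dataset/metadata.csv \
  --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth" \
  --extra_inputs "edit_image" \
  --learning_rate 1e-4 \
  --num_epochs 1 \
  --lora_base_model dit \
  --lora_target_modules "to_q,to_k,to_v" \
  --lora_rank 32 \
  --use_gradient_checkpointing \
  --output_path ./models/train/JoyAI-Image-Edit_lora
```
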
We provide recommended training scripts for each model; please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

@@ -16,7 +16,7 @@ For more information about installation, please refer to [Installation Dependenc

## Quick Start

Run the following code to quickly load the [Lightricks/LTX-2.3](https://www.modelscope.cn/models/Lightricks/LTX-2.3) model and perform inference. VRAM management is enabled, and the framework will automatically control model parameter loading based on remaining VRAM; it can run with a minimum of 8GB of VRAM.

```python
import torch
@@ -24,11 +24,11 @@ from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelCo
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2

vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
@@ -38,48 +38,52 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
    device="cuda",
    model_configs=[
        ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
        ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config),
        ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
    stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"),
)
prompt = "Two cute orange cats, wearing boxing gloves, stand in a boxing ring and fight each other. They are punching each other fast and yelling: 'I will win!'"
negative_prompt = pipe.default_negative_prompt["LTX-2.3"]
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    seed=43,
    height=1024, width=1536, num_frames=121,
    tiled=True, use_two_stage_pipeline=True,
)
write_video_audio_ltx2(video=video, audio=audio, output_path='video.mp4', fps=24, audio_sample_rate=pipe.audio_vocoder.output_sampling_rate)
```

## Model Overview

|Model ID|Additional Parameters|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|-|
|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py)|
|[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2.3: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)|
|[Lightricks/LTX-2.3: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2.3: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2.3: A2V](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2.3: Retake](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`retake_video`,`retake_video_regions`,`retake_audio`,`audio_sample_rate`,`retake_audio_regions`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py)|-|-|-|-|
|[Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Motion-Track-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|
|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-|

## Model Inference

@@ -113,4 +117,55 @@ If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_

## Model Training

The LTX-2 series models currently do not support training functionality. We will add related support as soon as possible.
|
LTX-2 series models are uniformly trained through [`examples/ltx2/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/train.py), and the script parameters include:
|
||||||
|
|
* General Training Parameters (an illustrative launch command combining these flags is sketched after this list)
  * Dataset Basic Configuration
    * `--dataset_base_path`: Root directory of the dataset.
    * `--dataset_metadata_path`: Path to the dataset's metadata file.
    * `--dataset_repeat`: Number of times the dataset is repeated in each epoch.
    * `--dataset_num_workers`: Number of worker processes for each DataLoader.
    * `--data_file_keys`: Field names to load from the metadata, usually image or video file paths, separated by `,`.
  * Model Loading Configuration
    * `--model_paths`: Paths of the models to load, in JSON format.
    * `--model_id_with_origin_paths`: Model IDs with original file patterns, e.g., `"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors"`, separated by commas.
    * `--extra_inputs`: Extra input parameters required by the model pipeline, e.g., the additional inputs used when training image editing models, separated by `,`.
    * `--fp8_models`: Models to load in FP8 format, in the same format as `--model_paths` or `--model_id_with_origin_paths`. Currently only supported for models whose parameters are not updated by gradients (either no gradient backpropagation at all, or gradients that only update the model's LoRA).
  * Training Basic Configuration
    * `--learning_rate`: Learning rate.
    * `--num_epochs`: Number of epochs.
    * `--trainable_models`: Models to train, e.g., `dit`, `vae`, `text_encoder`.
    * `--find_unused_parameters`: Whether DDP training should tolerate unused parameters. Some models contain redundant parameters that do not participate in gradient computation; enable this setting to avoid errors in multi-GPU training.
    * `--weight_decay`: Weight decay coefficient; see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
    * `--task`: Training task, `sft` by default. Some models support more training modes; please refer to the documentation of each specific model.
  * Output Configuration
    * `--output_path`: Path where the model is saved.
    * `--remove_prefix_in_ckpt`: Prefix to remove from the keys of the saved model's state dict.
    * `--save_steps`: Interval, in training steps, at which the model is saved. If this parameter is left blank, the model is saved once per epoch.
  * LoRA Configuration
    * `--lora_base_model`: Which model to add LoRA to.
    * `--lora_target_modules`: Which layers to add LoRA to.
    * `--lora_rank`: Rank of the LoRA.
    * `--lora_checkpoint`: Path of a LoRA checkpoint. If this path is provided, LoRA weights are loaded from this checkpoint.
    * `--preset_lora_path`: Path of a preset LoRA checkpoint. If this path is provided, this LoRA is loaded and merged into the base model. This parameter is used for LoRA differential training.
    * `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`.
  * Gradient Configuration
    * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
    * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
    * `--gradient_accumulation_steps`: Number of gradient accumulation steps.
  * Video Width/Height Configuration
    * `--height`: Height of the video. Leave both `height` and `width` blank to enable dynamic resolution.
    * `--width`: Width of the video. Leave both `height` and `width` blank to enable dynamic resolution.
    * `--max_pixels`: Maximum pixel area of video frames. When dynamic resolution is enabled, frames with a resolution larger than this value are downscaled while smaller frames remain unchanged; for example, with `--max_pixels 921600` (1280×720), a 1920×1080 frame is downscaled but an 832×480 frame is kept as-is.
    * `--num_frames`: Number of frames in the video.
  * LTX-2 Series Specific Parameters
    * `--tokenizer_path`: Path of the tokenizer, applicable to text-to-video models; leave blank to download it automatically from the remote repository.
    * `--frame_rate`: Frame rate of the training videos.
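
For reference, here is a minimal sketch of a LoRA training launch that combines the flags above. The script path, model ID, and all values are illustrative placeholders rather than recommended settings; please substitute the script and configuration suggested for your model in the "Model Overview" table.

```shell
# A minimal sketch only: the script path, model ID, and values below are placeholders.
accelerate launch examples/wanvideo/model_training/train.py \
  --dataset_base_path ./data/diffsynth_example_dataset \
  --dataset_metadata_path ./data/diffsynth_example_dataset/metadata.csv \
  --data_file_keys "video" \
  --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --lora_base_model dit \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --use_gradient_checkpointing \
  --height 480 \
  --width 832 \
  --num_frames 81 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path ./models/train/Wan2.1-T2V-1.3B_lora
# Local checkpoint files can be passed instead of a model ID via --model_paths,
# which takes a JSON list of file paths.
```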

We have built a sample video dataset for your testing. You can download this dataset with the following command:

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```
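
The metadata file referenced by `--dataset_metadata_path` pairs each media file with its prompt, and its field names must match `--data_file_keys`. The layout below is an assumed sketch (column names and paths are illustrative), not the exact schema of the example dataset:

```shell
# Assumed metadata layout; align the column names with --data_file_keys.
cat > ./data/metadata.csv <<'EOF'
video,prompt
videos/sample_1.mp4,"a cat running on the grass"
videos/sample_2.mp4,"a drone shot of a coastline at sunset"
EOF
```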

We have written recommended training scripts for each model; please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
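
As a counterpart to the LoRA sketch above, full training replaces the LoRA flags with `--trainable_models`. Again, this is a hedged sketch with placeholder values, not a recommended configuration:

```shell
# Illustrative full-training variant of the sketch above; values are placeholders.
accelerate launch examples/wanvideo/model_training/train.py \
  --dataset_base_path ./data/diffsynth_example_dataset \
  --dataset_metadata_path ./data/diffsynth_example_dataset/metadata.csv \
  --data_file_keys "video" \
  --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors" \
  --trainable_models dit \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --use_gradient_checkpointing_offload \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path ./models/train/Wan2.1-T2V-1.3B_full
```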
@@ -69,19 +69,19 @@ graph LR;

| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
-| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
+| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
-| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
+| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
-| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
+| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
-| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
+| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
-| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
+| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
-| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
+| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
-| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) |
+| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) |
-| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) |
+| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) |
-| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) |
+| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) |
-| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) |
+| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) |
-| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |
+| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |
-| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |
+| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |
-| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |
+| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |

## FLUX Series

@@ -149,20 +149,20 @@ graph LR;

| Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - | - |
-| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py) |
+| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev.py) |
-| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) |
+| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) |
-| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) |
+| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) |
-| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) |
+| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) |
-| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) |
+| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) |
-| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) |
+| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) |
-| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) |
+| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) |
-| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) |
+| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) |
-| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) |
+| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) |
-| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - |
+| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - |
-| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - |
+| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - |
-| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](/examples/flux/model_inference/Step1X-Edit.py) | [code](/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](/examples/flux/model_training/full/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_lora/Step1X-Edit.py) |
+| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Step1X-Edit.py) |
-| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](/examples/flux/model_inference/FLEX.2-preview.py) | [code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py) |
+| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLEX.2-preview.py) |
-| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](/examples/flux/model_training/full/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_lora/Nexus-Gen.py) |
+| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Nexus-Gen.py) |

## Wan Series

@@ -254,38 +254,38 @@ graph LR;
|
|||||||
|
|
||||||
| Model ID | Extra Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
| Model ID | Extra Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||||
| - | - | - | - | - | - | - |
|
| - | - | - | - | - | - | - |
|
||||||
| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) |
|
| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) |
|
||||||
| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) |
|
| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) |
|
||||||
| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) |
|
| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) |
|
||||||
| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) |
|
| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) |
|
||||||
| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) |
|
| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) |
|
||||||
| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) |
|
| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) |
|
||||||
| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) |
|
| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) |
|
||||||
| [Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) |
|
| [Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) |
|
||||||
| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) |
|
| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) |
|
||||||
| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) |
|
| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) |
|
||||||
| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) |
|
| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) |
|
||||||
| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) |
|
| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) |
|
||||||
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) |
|
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) |
|
||||||
| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) |
|
| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) |
| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) |
| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) |
| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) |
| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) |
| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) |
| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) |
| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) |
| [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) |
| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) |
| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) |
| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) |
| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) |
| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) |
* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/direct_distill/)
@@ -80,32 +80,35 @@ graph LR;
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[FireRedTeam/FireRed-Image-Edit-1.0](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.0)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/FireRed-Image-Edit-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.0.py)|
|[FireRedTeam/FireRed-Image-Edit-1.1](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/FireRed-Image-Edit-1.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.1.py)|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered-Control-V2.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control-V2.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control-V2.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control-V2.py)|
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |
| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |
| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
Special Training Scripts:
* Differential LoRA Training: [doc](../Training/Differential_LoRA.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/differential_training/)
* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)
DeepSpeed ZeRO Stage 3 Training: The Qwen-Image series models support DeepSpeed ZeRO Stage 3 training, which partitions the model across multiple GPUs. Taking full parameter training of the Qwen-Image model as an example, the following modifications are required:
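As a rough sketch of what such a setup typically involves (the JSON keys below are standard DeepSpeed options and the launch flags exist in the `accelerate` CLI, but the `train.py` arguments are illustrative assumptions rather than the script's confirmed interface; copy the real flags from the full-training script linked in the table above):

```shell
# Hedged ZeRO Stage 3 sketch: write a minimal DeepSpeed config, then point
# `accelerate` at it when launching the training script. Stage 3 partitions
# parameters, gradients, and optimizer states across GPUs.
cat > ds_zero3.json << 'EOF'
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 1,
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "stage3_gather_16bit_weights_on_model_save": true
  }
}
EOF

# The train.py argument below is an illustrative placeholder, not a verified flag.
accelerate launch --use_deepspeed --deepspeed_config_file ds_zero3.json \
  examples/qwen_image/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset
```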
@@ -149,7 +152,7 @@ If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_
## Model Training
Qwen-Image series models are uniformly trained through [`examples/qwen_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/train.py), and the script parameters include:
* General Training Parameters
* Dataset Basic Configuration
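Assembled into a command line, a typical run combines the general and dataset parameter groups above. The sketch below is hedged: the flag names are assumptions modeled on the per-model scripts linked in the table, not a verified interface.

```shell
# Hedged LoRA-training sketch; copy the exact arguments from the per-model
# training scripts for real runs.
accelerate launch examples/qwen_image/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset \
  --learning_rate 1e-4 \
  --num_epochs 1 \
  --lora_rank 32 \
  --output_path ./models/train/Qwen-Image_lora
```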
@@ -196,7 +199,7 @@ Qwen-Image series models are uniformly trained through [`examples/qwen_image/mod
We have built a sample image dataset for your testing. You can download this dataset with the following command:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```
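The downloaded directory follows the usual image-plus-annotations layout. The listing below is an illustrative assumption (filenames and the metadata schema may differ); the dataset card is authoritative.

```shell
# Illustrative layout only; actual filenames may differ.
ls ./data/diffsynth_example_dataset
# metadata.csv   <- per-sample annotations, e.g. an image path column and a prompt column
# image_1.jpg
# image_2.jpg
# ...
```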
We have written recommended training scripts for each model; please refer to the table in the "Model Overview" section above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
@@ -104,43 +104,47 @@ graph LR;
</details>
| Model ID | Extra Inputs | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|-|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||||
| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) |
|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||||
| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) |
|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||||
| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) |
|
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|
||||||
| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) |
|
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|
||||||
| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) |
|
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|
||||||
| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) |
|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||||
| [Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) |
|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||||
| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) |
|
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|
||||||
| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) |
|
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|
||||||
| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) |
|
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|
||||||
| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) |
|
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|
||||||
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) |
|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|
||||||
| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) |
|
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|
||||||
| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) |
|
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|
||||||
| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) |
|
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|
||||||
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) |
|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|
||||||
| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) |
|
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|
||||||
| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) |
|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||||
| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) |
|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|
||||||
| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](/examples/wanvideo/model_inference/LongCat-Video.py) | [code](/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) |
|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|
||||||
| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) |
|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|
||||||
| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) |
|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||||
| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) |
|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||||
| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) |
|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||||
| [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) |
|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|
||||||
| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) |
|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|
||||||
| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) |
|
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|
||||||
| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) |
|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|
||||||
| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) |
|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|
||||||
| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) |
|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|
||||||
|
|[openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py)|
|
||||||
|
|[openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py)|
|
||||||
|
|[Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py)|
|
||||||
|
|[Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B)|`wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference_low_vram/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py)|
|
||||||
|
|
||||||
* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](/examples/wanvideo/model_training/special/fp8_training/)
|
* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/)
|
||||||
* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](/examples/wanvideo/model_training/special/split_training/)
|
* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/)
|
||||||
* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](/examples/wanvideo/model_training/special/direct_distill/)
|
* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/direct_distill/)
|
||||||
|
|
||||||
DeepSpeed ZeRO Stage 3 Training: The Wan series models support DeepSpeed ZeRO Stage 3 training, which partitions the model across multiple GPUs. Taking full parameter training of the Wan2.1-T2V-14B model as an example, the following modifications are required:
|
DeepSpeed ZeRO Stage 3 Training: The Wan series models support DeepSpeed ZeRO Stage 3 training, which partitions the model across multiple GPUs. Taking full parameter training of the Wan2.1-T2V-14B model as an example, the following modifications are required:
|
||||||
|
|
||||||
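The "Extra Inputs" column in the table above maps directly onto keyword arguments of the `WanVideoPipeline` call. The sketch below illustrates this for an image-to-video model; it is not one of the linked example scripts, and the `ModelConfig` file patterns (in particular the CLIP checkpoint name) are assumptions, so treat the per-model example code linked in the table as authoritative.

```python
# Hedged sketch: extra inputs from the table (here `input_image`) are passed
# as keyword arguments to the pipeline call. The file patterns below are
# assumptions; see the linked Wan2.1-I2V-14B-480P example script for the
# exact configuration.
import torch
from PIL import Image
from diffsynth.utils.data import save_video
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
        ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),  # assumed file name
        ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth"),
    ],
    tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
)
video = pipe(
    prompt="A cat runs on the grass.",
    input_image=Image.open("first_frame.jpg"),  # the table's extra input for this model
    seed=0, tiled=True,
)
save_video(video, "video.mp4", fps=15, quality=5)
```

Models with other extra inputs (`end_image`, `control_video`, `vace_reference_image`, and so on) follow the same pattern: the column name is the keyword argument.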
@@ -201,9 +205,53 @@ Input parameters for `WanVideoPipeline` inference include:
 
 If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
 
+### Multi-GPU Parallel Acceleration
+
+To enable multi-GPU parallel acceleration, please install `flash_attn` and `xfuser`:
+
+```shell
+pip install flash-attn --no-build-isolation
+pip install xfuser
+```
+
+Please modify your code as follows ([example code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/acceleration/unified_sequence_parallel.py)):
+
+```diff
+import torch
+from PIL import Image
+from diffsynth.utils.data import save_video, VideoData
+from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
++ import torch.distributed as dist
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
++    use_usp=True,
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
+    ],
+    tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
+)
+video = pipe(
+    prompt="An astronaut in a spacesuit rides a mechanical horse across the Martian surface, facing the camera. The red, desolate terrain stretches into the distance, dotted with massive craters and unusual rock formations. The mechanical horse moves with steady strides, kicking up faint dust, embodying a perfect fusion of futuristic technology and primal exploration. The astronaut holds a control device, with a determined gaze, as if pioneering new frontiers for humanity. Against a backdrop of the deep cosmos and the blue Earth, the scene is both sci-fi and hopeful, evoking imagination about future interstellar life.",
+    negative_prompt="oversaturated colors, overexposed, static, blurry details, subtitles, style, artwork, painting, still image, overall gray tone, worst quality, low quality, JPEG compression artifacts, ugly, malformed, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, frozen frame, cluttered background, three legs, crowd in background, walking backwards",
+    seed=0, tiled=True,
+)
++ if dist.get_rank() == 0:
++     save_video(video, "video1.mp4", fps=15, quality=5)
+```
+
+When running multi-GPU parallel inference, please use `torchrun`, where `--nproc_per_node` specifies the number of GPUs:
+
+```shell
+torchrun --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
+```
+
 ## Model Training
 
-Wan series models are uniformly trained through [`examples/wanvideo/model_training/train.py`](/examples/wanvideo/model_training/train.py), and the script parameters include:
+Wan series models are uniformly trained through [`examples/wanvideo/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/train.py), and the script parameters include:
 
 * General Training Parameters
 * Dataset Basic Configuration
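On the VRAM note in the hunk above: the following is a plausible minimal low-VRAM setup rather than the project's definitive recipe. It assumes `ModelConfig` accepts an `offload_device` argument and that the pipeline exposes `enable_vram_management()`; the linked VRAM Management doc and the per-model low-VRAM example scripts are authoritative.

```python
# Hedged sketch of low-VRAM inference: keep weights on CPU and let the
# pipeline page them onto the GPU on demand. `offload_device` and
# `enable_vram_management()` are assumptions; see the VRAM Management doc.
import torch
from diffsynth.utils.data import save_video
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
    ],
    tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
)
pipe.enable_vram_management()  # assumed API: swaps offloaded weights in and out on demand
video = pipe(prompt="A cat runs on the grass.", seed=0, tiled=True)
save_video(video, "video.mp4", fps=15, quality=5)
```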
@@ -251,7 +299,7 @@ Wan series models are uniformly trained through [`examples/wanvideo/model_traini
 We have built a sample video dataset for your testing. You can download this dataset with the following command:
 
 ```shell
-modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset
+modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
 ```
 
-We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
+We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
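Once the sample dataset is downloaded, a clip can be wrapped as a pipeline input. A small sketch, assuming `VideoData` accepts a file path plus target height and width, as in the inference example earlier on this page (the file path itself is hypothetical):

```python
# Hedged sketch: wrap a downloaded clip for use as a pipeline input.
# The file path is hypothetical; check the dataset layout after download.
from diffsynth.utils.data import VideoData

control_video = VideoData("data/diffsynth_example_dataset/video.mp4", height=480, width=832)
# Pass it under the keyword named in the model table, e.g.:
#   video = pipe(prompt="...", control_video=control_video, seed=0, tiled=True)
```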
@@ -52,17 +52,17 @@ image.save("image.jpg")
|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
-|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
+|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image.py)|
-|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
+|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-i2L.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
-|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
+|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
-|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
+|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
-|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
+|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
-|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
+|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|

Special Training Scripts:

-* Differential LoRA Training: [doc](../Training/Differential_LoRA.md), [code](/examples/z_image/model_training/special/differential_training/)
+* Differential LoRA Training: [doc](../Training/Differential_LoRA.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/differential_training/)
-* Trajectory Imitation Distillation Training (Experimental Feature): [code](/examples/z_image/model_training/special/trajectory_imitation/)
+* Trajectory Imitation Distillation Training (Experimental Feature): [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/trajectory_imitation/)

## Model Inference
@@ -88,7 +88,7 @@ If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_
## Model Training

-Z-Image series models are uniformly trained through [`examples/z_image/model_training/train.py`](/examples/z_image/model_training/train.py), and the script parameters include:
+Z-Image series models are uniformly trained through [`examples/z_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/train.py), and the script parameters include:

* General Training Parameters
  * Dataset Basic Configuration
@@ -134,16 +134,16 @@ Z-Image series models are uniformly trained through [`examples/z_image/model_tra
We have built a sample image dataset for your testing. You can download this dataset with the following command:

```shell
-modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
+modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```

-We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
+We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

Training Tips:

* [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) is a distilled acceleration model. Therefore, direct training will quickly cause the model to lose its acceleration capability. The effect of inference with "acceleration configuration" (`num_inference_steps=8`, `cfg_scale=1`) becomes worse, while the effect of inference with "no acceleration configuration" (`num_inference_steps=30`, `cfg_scale=2`) becomes better. The following training and inference schemes can be adopted:
-  * Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + No Acceleration Configuration Inference
+  * Standard SFT Training ([code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + No Acceleration Configuration Inference
-  * Differential LoRA Training ([code](/examples/z_image/model_training/special/differential_training/)) + Acceleration Configuration Inference
+  * Differential LoRA Training ([code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/differential_training/)) + Acceleration Configuration Inference
    * An additional LoRA needs to be loaded in differential LoRA training, e.g., [ostris/zimage_turbo_training_adapter](https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter)
-  * Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Trajectory Imitation Distillation Training ([code](/examples/z_image/model_training/special/trajectory_imitation/)) + Acceleration Configuration Inference
+  * Standard SFT Training ([code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Trajectory Imitation Distillation Training ([code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/trajectory_imitation/)) + Acceleration Configuration Inference
-  * Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Load Distillation Acceleration LoRA During Inference ([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillPatch)) + Acceleration Configuration Inference
+  * Standard SFT Training ([code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Load Distillation Acceleration LoRA During Inference ([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillPatch)) + Acceleration Configuration Inference
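As an illustration of the last scheme, a minimal inference sketch follows. It is a sketch only: the pipeline module, class name, model config, the `load_lora` helper, and both LoRA paths are assumptions modeled on the Z-Image inference examples linked in the table above; the generation arguments are the "acceleration configuration" described in the tip.

```python
import torch
# Assumption: the pipeline class and loading pattern mirror the Z-Image
# inference examples linked above; names and paths here are hypothetical.
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig

pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo")],  # hypothetical config
)
# Load the LoRA produced by standard SFT training, then the distillation
# acceleration LoRA (Z-Image-Turbo-DistillPatch); both paths are hypothetical.
pipe.load_lora(pipe.dit, "models/train/Z-Image-Turbo_lora/epoch-4.safetensors")
pipe.load_lora(pipe.dit, "models/Z-Image-Turbo-DistillPatch/lora.safetensors")
# "Acceleration configuration" inference, as described in the tip above.
image = pipe(prompt="a cat", seed=0, num_inference_steps=8, cfg_scale=1)
image.save("image.jpg")
```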
docs/en/Pipeline_Usage/Accelerated_Inference.md · new file · 94 lines
@@ -0,0 +1,94 @@
# Inference Acceleration

The denoising process of diffusion models is typically time-consuming. To improve inference speed, various acceleration techniques can be applied, including lossless solutions such as multi-GPU parallel inference and computation graph compilation, as well as lossy solutions such as caching and quantization.

Currently, most diffusion models are built on Diffusion Transformers (DiT), and efficient attention mechanisms are also a common acceleration method. DiffSynth-Studio currently supports several lossless inference acceleration features. This section introduces acceleration methods along two dimensions: multi-GPU parallel inference and computation graph compilation.

## Efficient Attention Mechanisms

For details on the acceleration of attention mechanisms, please refer to [Attention Mechanism Implementation](../API_Reference/core/attention.md).

## Multi-GPU Parallel Inference

DiffSynth-Studio adopts a multi-GPU inference solution using Unified Sequence Parallel (USP). It splits the token sequence in the DiT across multiple GPUs for parallel processing. The underlying implementation is based on [xDiT](https://github.com/xdit-project/xDiT). Please note that unified sequence parallelism introduces additional communication overhead, so the actual speedup ratio is usually lower than the number of GPUs.

Currently, DiffSynth-Studio supports unified sequence parallel acceleration for the [Wan](../Model_Details/Wan.md) and [MOVA](../Model_Details/Wan.md) models.

First, install the `xDiT` dependency.

```bash
pip install "xfuser[flash-attn]>=0.4.3"
```

Then, use `torchrun` to launch multi-GPU inference.

```bash
torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
```

When building the pipeline, simply configure `use_usp=True` to enable USP parallel inference. A code example is shown below.

```python
import torch
from PIL import Image
from diffsynth.utils.data import save_video
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
import torch.distributed as dist

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    use_usp=True,
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
    ],
    tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
)

# Text-to-video
video = pipe(
    prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    seed=0, tiled=True,
)
if dist.get_rank() == 0:
    save_video(video, "video1.mp4", fps=15, quality=5)
```

## Computation Graph Compilation

PyTorch 2.0 provides an automatic computation graph compilation interface, [torch.compile](https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), which can just-in-time (JIT) compile PyTorch code into optimized kernels, thereby improving execution speed. Since the inference time of diffusion models is concentrated in the multi-step denoising phase of the DiT, and the DiT is primarily stacked from basic blocks, DiffSynth's compile feature uses a [regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) strategy targeting only the basic Transformer blocks, which reduces compilation time.

### Compile Usage Example

Compared to standard inference, you simply need to execute `pipe.compile_pipeline()` before calling the pipeline to enable compilation acceleration. For the specific function definition, please refer to the [source code](https://github.com/modelscope/DiffSynth-Studio/blob/166e6d2d38764209f66a74dd0fe468226536ad0f/diffsynth/diffusion/base_pipeline.py#L342).

The input parameters for `compile_pipeline` fall into two main groups.

The first group is the compiled model parameters, `compile_models`. Taking the Qwen-Image pipeline as an example, if you only want to compile the DiT model, you can keep this parameter empty. If you need to additionally compile models such as the VAE, you can pass `compile_models=["vae", "dit"]`. Aside from the DiT, all other models use a full-graph compilation strategy, meaning the model's forward function is compiled into a single computation graph.

The second group is the compilation strategy parameters, covering `mode`, `dynamic`, `fullgraph`, and other custom options. These parameters are passed directly to the `torch.compile` interface. If you are not deeply familiar with the mechanics of these parameters, it is recommended to keep the default settings.

* `mode` specifies the compilation mode: `"default"`, `"reduce-overhead"`, `"max-autotune"`, or `"max-autotune-no-cudagraphs"`. Because CUDA graphs impose stricter requirements on computation graphs (for example, they may need to be used in conjunction with `torch.compiler.cudagraph_mark_step_begin()`), the `"reduce-overhead"` and `"max-autotune"` modes might fail to compile.
* `dynamic` determines whether to enable dynamic shapes. For most generative models, modifying the prompt, enabling CFG, or adjusting the resolution changes the shape of the tensors entering the computation graph. Setting `dynamic=True` increases the compilation time of the first run but supports dynamic shapes, so no recompilation is needed when shapes change. With `dynamic=False`, the first compilation is faster, but any operation that alters the input shape triggers a recompilation. For most scenarios, `dynamic=True` is recommended.
* `fullgraph`, when set to `True`, makes the underlying system attempt to compile the target model into a single computation graph, raising an error if it fails. When set to `False`, the underlying system inserts graph breaks where connections cannot be made, compiling the model into multiple independent computation graphs. Developers can set it to `True` to optimize compilation performance, but regular users are advised to keep `False`.
* For other parameter configurations, please consult the [API documentation](https://docs.pytorch.org/docs/stable/generated/torch.compile.html).
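Putting these options together, enabling compilation before the first call looks like the minimal sketch below. It assumes an already-constructed pipeline object `pipe` (e.g., Qwen-Image, built as in the inference documents); the keyword values simply follow the recommendations above.

```python
# Minimal sketch: enable regional compilation acceleration before the first call.
# `pipe` is assumed to be an already-constructed pipeline such as Qwen-Image.
pipe.compile_pipeline(
    compile_models=["vae", "dit"],  # also compile the VAE; leave empty to compile only the DiT
    mode="default",                 # cudagraph-based modes may fail to compile
    dynamic=True,                   # tolerate shape changes from new prompts, CFG, or resolutions
    fullgraph=False,                # allow graph breaks; True is for developers
)
image = pipe(prompt="a cat", seed=0)  # the first call pays the one-time JIT compilation cost
```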
### Compile Feature Developer Documentation

If you need to provide compile support for a newly integrated pipeline, you should configure the `compilable_models` attribute in the pipeline to specify the default models to compile. For the DiT model class of that pipeline, you also need to configure `_repeated_blocks` to specify the types of basic blocks that will participate in regional compilation.

Taking Qwen-Image as an example, its pipeline configuration is as follows:

```python
self.compilable_models = ["dit"]
```

Its DiT configuration is as follows:

```python
class QwenImageDiT(torch.nn.Module):
    _repeated_blocks = ["QwenImageTransformerBlock"]
```
@@ -91,3 +91,4 @@ Set 0 or not set: indicates not enabling the binding function
|----------------|---------------------------|-------------------|
| Wan 14B series | --initialize_model_on_cpu | The 14B model needs to be initialized on the CPU |
| Qwen-Image series | --initialize_model_on_cpu | The model needs to be initialized on the CPU |
+| Z-Image series | --enable_npu_patch | Use NPU fused operators in place of the corresponding operators in the Z-Image model to improve performance on NPU |
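For example, the Z-Image flag can be appended to a training launch as in the hedged sketch below; the script path and the dataset argument are placeholders modeled on the training commands elsewhere in these docs, and only `--enable_npu_patch` is taken from the table above.

```shell
# Hedged sketch: script path and dataset arguments are placeholders;
# --enable_npu_patch is the flag documented in the table above.
accelerate launch examples/z_image/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset \
  --enable_npu_patch
```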
@@ -69,25 +69,11 @@ We have built sample datasets for your testing. To understand how the universal
<details>

-<summary>Sample Image Dataset</summary>
+<summary>Sample Dataset</summary>

> ```shell
-> modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
+> modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
> ```
->
-> Applicable to training of image generation models such as Qwen-Image and FLUX.
-
-</details>
-
-<details>
-
-<summary>Sample Video Dataset</summary>
-
-> ```shell
-> modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset
-> ```
->
-> Applicable to training of video generation models such as Wan.
-
</details>
@@ -123,7 +109,6 @@ Similar to [model loading during inference](../Pipeline_Usage/Model_Inference.md
<details>

-<details>
-
<summary>Load models from local file paths</summary>
@@ -245,3 +230,118 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera
* Some models contain redundant parameters. For example, the text encoding part of the last layer of Qwen-Image's DiT. When training these models, `--find_unused_parameters` needs to be set to avoid errors in multi-GPU training. For compatibility with community models, we do not intend to remove these redundant parameters.
* The loss value of diffusion models bears little relationship to the actual generation quality, so we do not record loss values during training. We recommend setting `--num_epochs` to a sufficiently large value, testing while training, and stopping the training manually once the results converge.
* `--use_gradient_checkpointing` is usually enabled unless GPU VRAM is plentiful; `--use_gradient_checkpointing_offload` is enabled as needed. See [`diffsynth.core.gradient`](../API_Reference/core/gradient.md) for details.

## Low VRAM Training

If you want to complete LoRA model training on GPUs with low VRAM, you can combine [Two-Stage Split Training](../Training/Split_Training.md) with `deepspeed_zero3_offload` training. First, split the preprocessing steps into the first stage and store the computed results on disk. Second, read these results from disk and train the denoising model. With `deepspeed_zero3_offload`, the training parameters and optimizer states are offloaded to the CPU or disk. We provide examples for some models, primarily by specifying the `deepspeed` configuration via `--config_file`.

Please note that the `deepspeed_zero3_offload` mode is incompatible with PyTorch's native gradient checkpointing mechanism. To address this, we have adapted the `checkpointing` interface of `deepspeed`. Users need to fill in the `activation_checkpointing` field of the `deepspeed` configuration to enable gradient checkpointing.

Below is the low VRAM training script for the Qwen-Image model:

```shell
accelerate launch examples/qwen_image/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata.csv \
  --max_pixels 1048576 \
  --dataset_repeat 1 \
  --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Qwen-Image_lora-splited-cache" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
  --lora_rank 32 \
  --task "sft:data_process" \
  --use_gradient_checkpointing \
  --dataset_num_workers 8 \
  --find_unused_parameters

accelerate launch --config_file examples/qwen_image/model_training/special/low_vram_training/deepspeed_zero3_cpuoffload.yaml examples/qwen_image/model_training/train.py \
  --dataset_base_path "./models/train/Qwen-Image_lora-splited-cache" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Qwen-Image_lora" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
  --lora_rank 32 \
  --task "sft:train" \
  --use_gradient_checkpointing \
  --dataset_num_workers 8 \
  --find_unused_parameters \
  --initialize_model_on_cpu
```

The configurations for `accelerate` and `deepspeed` are as follows:

```yaml
compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
  deepspeed_config_file: examples/qwen_image/model_training/special/low_vram_training/ds_z3_cpuoffload.json
  zero3_init_flag: true
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

```json
{
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": false,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": 5e7,
    "stage3_prefetch_bucket_size": 5e7,
    "stage3_param_persistence_threshold": 1e5,
    "stage3_max_live_parameters": 1e8,
    "stage3_max_reuse_distance": 1e8,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "activation_checkpointing": {
    "partition_activations": false,
    "cpu_checkpointing": false,
    "contiguous_memory_optimization": false
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}
```
@@ -37,9 +37,9 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
# aarch64/ARM
-pip install -e .[npu_aarch64] --extra-index-url "https://download.pytorch.org/whl/cpu"
+pip install -e .[npu_aarch64]
# x86
-pip install -e .[npu]
+pip install -e .[npu] --extra-index-url "https://download.pytorch.org/whl/cpu"

When using Ascend NPU, please replace `"cuda"` with `"npu"` in your Python code. For details, see [NPU Support](../Pipeline_Usage/GPU_support.md#ascend-npu).
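In practice, the device swap is a one-line change. A minimal sketch follows, reusing the Wan pipeline construction from the inference docs; the `torch_npu` import is an assumption (the Ascend PyTorch adapter, expected to be provided by the NPU extras above).

```python
import torch
import torch_npu  # Ascend PyTorch adapter; assumed to be installed via the [npu] extras above
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig

# Identical to the CUDA examples except for the device string.
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="npu",  # was "cuda"
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
    ],
    tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
)
```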
@@ -42,6 +42,8 @@ This section introduces the Diffusion models supported by `DiffSynth-Studio`. So
* [Qwen-Image](./Model_Details/Qwen-Image.md)
* [FLUX.2](./Model_Details/FLUX2.md)
* [Z-Image](./Model_Details/Z-Image.md)
+* [Anima](./Model_Details/Anima.md)
+* [LTX-2](./Model_Details/LTX-2.md)

## Section 3: Training Framework
@@ -78,7 +80,7 @@ This section introduces the independent core module `diffsynth.core` in `DiffSyn
This section introduces how to use `DiffSynth-Studio` to train new models, helping researchers explore new model technologies.

* [Training models from scratch](./Research_Tutorial/train_from_scratch.md)
-* Inference improvement techniques 【coming soon】
+* [Inference improvement techniques](./Research_Tutorial/inference_time_scaling.md)
* Designing controllable generation models 【coming soon】
* Creating new training paradigms 【coming soon】
docs/en/Research_Tutorial/inference_time_scaling.ipynb · new file · 236 lines (notebook form of inference_time_scaling.md below)
docs/en/Research_Tutorial/inference_time_scaling.md · new file · 140 lines
@@ -0,0 +1,140 @@
# Inference Optimization Techniques

DiffSynth-Studio aims to drive technological innovation through its foundational framework. This article demonstrates how to build a training-free image generation enhancement solution using DiffSynth-Studio, taking inference-time scaling as an example.

Notebook: https://github.com/modelscope/DiffSynth-Studio/blob/main/docs/en/Research_Tutorial/inference_time_scaling.ipynb

## 1. Image Quality Quantification

First, we need an indicator that quantifies the quality of images produced by generation models. Manual scoring is the most straightforward solution but is too costly for large-scale applications. However, once manual scores have been collected, training an image classification model to predict human scoring is entirely feasible. PickScore [[1]](https://arxiv.org/abs/2305.01569) is such a model. Running the following code will automatically download and load the [PickScore model](https://modelscope.cn/models/AI-ModelScope/PickScore_v1).

```python
from modelscope import AutoProcessor, AutoModel
import torch

class PickScore(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
        self.model = AutoModel.from_pretrained("AI-ModelScope/PickScore_v1").eval().to("cuda")

    def forward(self, image, prompt):
        image_inputs = self.processor(images=image, padding=True, truncation=True, max_length=77, return_tensors="pt").to("cuda")
        text_inputs = self.processor(text=prompt, padding=True, truncation=True, max_length=77, return_tensors="pt").to("cuda")
        with torch.inference_mode():
            image_embs = self.model.get_image_features(**image_inputs).pooler_output
            image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
            text_embs = self.model.get_text_features(**text_inputs).pooler_output
            text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
            score = (text_embs @ image_embs.T).flatten().item()
        return score

reward_model = PickScore()
```

## 2. Inference-time Scaling Techniques

Inference-time scaling [[2]](https://arxiv.org/abs/2504.00294) is an interesting technique that aims to improve generation quality by increasing computational cost during inference. For example, in language models, models such as [Qwen/Qwen3.5-27B](https://modelscope.cn/models/Qwen/Qwen3.5-27B) and [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) use a "thinking mode" to guide the model to spend more time considering results more carefully, producing more accurate answers. Next, we'll use the [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) model as an example to explore how to design inference-time scaling solutions for image generation models.

> Before starting, we slightly modified the `Flux2ImagePipeline` code to allow initialization with specific Gaussian noise matrices for result reproducibility. See `Flux2Unit_NoiseInitializer` in [diffsynth/pipelines/flux2_image.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/pipelines/flux2_image.py).

Run the following code to load the [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) model.

```python
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig

pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
```

Generate a sketch cat image using the prompt `"sketch, a cat"` and score it with the PickScore model.

```python
def evaluate_noise(noise, pipe, reward_model, prompt):
    # Generate an image and compute the score.
    image = pipe(
        prompt=prompt,
        num_inference_steps=4,
        initial_noise=noise,
        progress_bar_cmd=lambda x: x,
    )
    score = reward_model(image, prompt)
    return score

torch.manual_seed(1)
prompt = "sketch, a cat"
noise = pipe.generate_noise((1, 128, 64, 64), rand_device="cuda", rand_torch_dtype=pipe.torch_dtype)

image_1 = pipe(prompt, num_inference_steps=4, initial_noise=noise)
print("Score:", reward_model(image_1, prompt))
image_1
```



### 2.1 Best-of-N Random Search

Model generation results have inherent randomness: different random seeds produce different images, sometimes of high quality and sometimes of low quality. This leads to a simple inference-time scaling solution: generate images using multiple random seeds, score them with PickScore, and retain only the highest-scoring image.

```python
from tqdm import tqdm

def random_search(base_latents, objective_reward_fn, total_eval_budget):
    # Search for the noise randomly.
    best_noise = base_latents
    best_score = objective_reward_fn(base_latents)
    for it in tqdm(range(total_eval_budget - 1)):
        noise = pipe.generate_noise((1, 128, 64, 64), seed=None)
        score = objective_reward_fn(noise)
        if score > best_score:
            best_score, best_noise = score, noise
    return best_noise

best_noise = random_search(
    base_latents=noise,
    objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),
    total_eval_budget=50,
)
image_2 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)
print("Score:", reward_model(image_2, prompt))
image_2
```



We can clearly see that after multiple random searches, the final selected cat image shows richer fur details and a significantly improved PickScore. However, this brute-force random search is extremely inefficient: generation time multiplies while quality quickly hits its ceiling. We therefore need a more efficient search method that achieves higher scores within the same computational budget.

### 2.2 SES Search

To overcome the limitations of random search, we introduce the Spectral Evolution Search (SES) algorithm [[3]](https://arxiv.org/abs/2602.03208). Detailed code is available at [diffsynth/utils/ses](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/utils/ses).

Image generation in diffusion models is largely determined by the low-frequency components of the initial noise. The SES algorithm decomposes the Gaussian noise with wavelet transforms, fixes the high-frequency details, and runs an evolutionary search using the cross-entropy method on the low-frequency components alone, finding good initial noise with higher efficiency.

Run the following code to search efficiently for the best Gaussian noise matrix using SES.

```python
from diffsynth.utils.ses import ses_search

best_noise = ses_search(
    base_latents=noise,
    objective_reward_fn=lambda noise: evaluate_noise(noise, pipe, reward_model, prompt),
    total_eval_budget=50,
)
image_3 = pipe(prompt, num_inference_steps=4, initial_noise=best_noise)
print("Score:", reward_model(image_3, prompt))
image_3
```



Observing the results, under the same computational budget SES achieves a significantly higher PickScore than random search. The "sketch cat" demonstrates a more refined overall composition and more layered contrast between light and shadow.

Inference-time scaling achieves higher image quality at the cost of longer inference time. The generated images can then be used to train the model itself through methods like DPO [[4]](https://arxiv.org/abs/2311.12908) or differential training [[5]](https://arxiv.org/abs/2412.12888), opening another interesting research direction.
@@ -137,7 +137,7 @@ Besides the Diffusion model used for denoising, we also need two other models:
* Text Encoder: Used to encode text into tensors. We adopt the [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) model.
* VAE Encoder-Decoder: The encoder part is used to encode images into tensors, and the decoder part is used to decode image tensors into images. We adopt the VAE model from [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B).

-The architectures of these two models are already integrated in DiffSynth-Studio, located at [/diffsynth/models/z_image_text_encoder.py](/diffsynth/models/z_image_text_encoder.py) and [/diffsynth/models/flux2_vae.py](/diffsynth/models/flux2_vae.py), so we don't need to modify any code.
+The architectures of these two models are already integrated in DiffSynth-Studio, located at [/diffsynth/models/z_image_text_encoder.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/models/z_image_text_encoder.py) and [/diffsynth/models/flux2_vae.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/models/flux2_vae.py), so we don't need to modify any code.

## 2. Building Pipeline
@@ -336,7 +336,7 @@ modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data
### 4. Start Training

-The training process can be quickly implemented using Pipeline. We have placed the complete code at [../Research_Tutorial/train_from_scratch.py](../Research_Tutorial/train_from_scratch.py), which can be directly started with `python docs/en/Research_Tutorial/train_from_scratch.py` for single GPU training.
+The training process can be quickly implemented using Pipeline. We have placed the complete code at [../Research_Tutorial/train_from_scratch.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/docs/en/Research_Tutorial/train_from_scratch.py), which can be directly started with `python docs/en/Research_Tutorial/train_from_scratch.py` for single GPU training.

To enable multi-GPU parallel training, please run `accelerate config` to set relevant parameters, then use the command `accelerate launch docs/en/Research_Tutorial/train_from_scratch.py` to start training.
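Collected as a runnable snippet, the two launch modes from the paragraph above are:

```shell
# Single-GPU training
python docs/en/Research_Tutorial/train_from_scratch.py

# Multi-GPU training: configure accelerate once, then launch
accelerate config
accelerate launch docs/en/Research_Tutorial/train_from_scratch.py
```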
@@ -77,7 +77,7 @@ distill_qwen/image.jpg,"精致肖像,水下少女,蓝裙飘逸,发丝轻
 This sample dataset can be downloaded directly:

 ```shell
-modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
+modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
 ```

 Then start LoRA distillation accelerated training:
@@ -6,7 +6,7 @@ FP8 precision is the only VRAM management strategy that can be enabled during tr

 ## Enabling FP8

-In our provided training scripts, you can quickly set models to be stored in FP8 precision through the `--fp8_models` parameter. Taking Qwen-Image LoRA training as an example, we provide a script for enabling FP8 training located at [`/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh`](/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh). After training is completed, you can verify the training results with the script [`/examples/qwen_image/model_training/special/fp8_training/validate.py`](/examples/qwen_image/model_training/special/fp8_training/validate.py).
+In our provided training scripts, you can quickly set models to be stored in FP8 precision through the `--fp8_models` parameter. Taking Qwen-Image LoRA training as an example, we provide a script for enabling FP8 training located at [`/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh). After training is completed, you can verify the training results with the script [`/examples/qwen_image/model_training/special/fp8_training/validate.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/special/fp8_training/validate.py).

 Please note that this FP8 VRAM management strategy does not support gradient updates. When a model is set to be trainable, FP8 precision cannot be enabled for that model. Models that support FP8 include two types:

@@ -48,9 +48,10 @@ extensions = [
     'sphinx.ext.viewcode',
     'sphinx_markdown_tables',
     'sphinx_copybutton',
+    "sphinx_rtd_theme",
+    'sphinx.ext.mathjax',
     'myst_parser',
 ]

 # build the templated autosummary files
 autosummary_generate = True
 numpydoc_show_class_members = False
@@ -13,6 +13,7 @@ Welcome to DiffSynth-Studio's Documentation

    Pipeline_Usage/Setup
    Pipeline_Usage/Model_Inference
+   Pipeline_Usage/Accelerated_Inference
    Pipeline_Usage/VRAM_management
    Pipeline_Usage/Model_Training
    Pipeline_Usage/Environment_Variables
@@ -27,6 +28,11 @@ Welcome to DiffSynth-Studio's Documentation
    Model_Details/Qwen-Image
    Model_Details/FLUX2
    Model_Details/Z-Image
+   Model_Details/Anima
+   Model_Details/LTX-2
+   Model_Details/ERNIE-Image
+   Model_Details/JoyAI-Image
+   Model_Details/ACE-Step

 .. toctree::
    :maxdepth: 2
@@ -63,6 +69,7 @@ Welcome to DiffSynth-Studio's Documentation
    :caption: Research Guide

    Research_Tutorial/train_from_scratch
+   Research_Tutorial/inference_time_scaling

 .. toctree::
    :maxdepth: 2
@@ -4,6 +4,8 @@ recommonmark
 sphinx>=5.3.0
 sphinx-book-theme
 sphinx-copybutton
+sphinx-autobuild
 sphinx-rtd-theme
 sphinx_markdown_tables
 sphinxcontrib-mermaid
+pymdown-extensions
docs/zh/Model_Details/ACE-Step.md (new file)
@@ -0,0 +1,164 @@

# ACE-Step

ACE-Step 1.5 is an open-source music generation model built on the DiT architecture. It supports text-to-music generation, audio covers, local repainting, and more, and runs efficiently on consumer hardware.

## Installation

Before running model inference or training with this project, please install DiffSynth-Studio first.

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

For more information about installation, see [Installing Dependencies](../Pipeline_Usage/Setup.md).

## Quick Start

Run the following code to quickly load the [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) model and run inference. VRAM management is enabled, so the framework automatically controls how model parameters are loaded based on the remaining VRAM; as little as 3 GB of VRAM is enough.

```python
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch


vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}

pipe = AceStepPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
        ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
        ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)

prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
    prompt=prompt,
    lyrics=lyrics,
    duration=160,
    bpm=100,
    keyscale="B minor",
    timesignature="4",
    vocal_language="zh",
    seed=42,
)

save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
```

## Model Overview

|Model ID|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|

## Model Inference

Models are loaded via `AceStepPipeline.from_pretrained`; see [Loading Models](../Pipeline_Usage/Model_Inference.md#加载模型) for details.

Input parameters for `AceStepPipeline` inference (a usage sketch of the task-specific parameters follows the list):

* `prompt`: Text description of the music.
* `cfg_scale`: Classifier-free guidance scale. Defaults to 1.0.
* `lyrics`: Lyrics text.
* `task_type`: Task type, one of `"text2music"` (text-to-music), `"cover"` (audio cover), or `"repaint"` (local repainting). Defaults to `"text2music"`.
* `reference_audios`: List of reference audios (a list of Tensors) used as a timbre reference.
* `src_audio`: Source audio (Tensor) for the cover and repaint tasks.
* `denoising_strength`: Denoising strength, controlling how strongly the output is influenced by the source audio. Defaults to 1.0.
* `audio_cover_strength`: Ratio of cover steps, controlling how many of the leading steps in the cover task use the cover condition. Defaults to 1.0.
* `audio_code_string`: Input audio code string, used to pass discrete audio codes directly in the cover task.
* `repainting_ranges`: Repainting time ranges (a list of float tuples, in seconds) for the repaint task.
* `repainting_strength`: Repainting strength, controlling how much the repainted regions change. Defaults to 1.0.
* `duration`: Audio duration in seconds. Defaults to 60.
* `bpm`: Beats per minute. Defaults to 100.
* `keyscale`: Key and scale. Defaults to "B minor".
* `timesignature`: Time signature. Defaults to "4".
* `vocal_language`: Vocal language. Defaults to "unknown".
* `seed`: Random seed.
* `rand_device`: Device on which noise is generated. Defaults to "cpu".
* `num_inference_steps`: Number of inference steps. Defaults to 8.
* `shift`: Scheduler time-shift parameter. Defaults to 1.0.
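
As a usage sketch of the task-specific parameters above, the hypothetical call below repaints a 15-second window of an existing track. `pipe` is the pipeline from the quick start, `src_audio` is a waveform tensor you have already loaded, and all values are illustrative rather than recommended settings.

```python
# Hypothetical repaint call assembled from the parameter list above; values are illustrative.
audio = pipe(
    prompt="A mellow lo-fi hip-hop beat with warm piano chords",
    lyrics="[Instrumental]",
    task_type="repaint",               # local repainting instead of text2music
    src_audio=src_audio,               # previously loaded waveform tensor
    repainting_ranges=[(30.0, 45.0)],  # regenerate only seconds 30 to 45
    repainting_strength=0.8,           # below 1.0 keeps some of the original material
    duration=60,
    seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "repainted.wav")
```
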
## Model Training

The ace_step family of models is trained through `examples/ace_step/model_training/train.py`. The script's arguments include (a LoRA invocation sketch follows the list):

* General training arguments
    * Basic dataset configuration
        * `--dataset_base_path`: Root directory of the dataset.
        * `--dataset_metadata_path`: Path to the metadata file of the dataset.
        * `--dataset_repeat`: Number of times the dataset repeats per epoch.
        * `--dataset_num_workers`: Number of worker processes per DataLoader.
        * `--data_file_keys`: Field names in the metadata to load, usually paths to image or video files, separated by `,`.
    * Model loading configuration
        * `--model_paths`: Paths of the models to load, in JSON format.
        * `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
        * `--extra_inputs`: Extra input arguments required by the model pipeline, separated by `,`.
        * `--fp8_models`: Models loaded in FP8 format; currently only supported for models whose parameters are not updated by gradients.
    * Basic training configuration
        * `--learning_rate`: Learning rate.
        * `--num_epochs`: Number of epochs.
        * `--trainable_models`: Models to train, e.g. `dit`, `vae`, `text_encoder`.
        * `--find_unused_parameters`: Whether DDP training has unused parameters.
        * `--weight_decay`: Weight decay.
        * `--task`: Training task. Defaults to `sft`.
    * Output configuration
        * `--output_path`: Model save path.
        * `--remove_prefix_in_ckpt`: Prefix to remove from the state dict of the model file.
        * `--save_steps`: Interval, in training steps, between model checkpoints.
    * LoRA configuration
        * `--lora_base_model`: Which model the LoRA is added to.
        * `--lora_target_modules`: Which layers the LoRA is added to.
        * `--lora_rank`: Rank of the LoRA.
        * `--lora_checkpoint`: Path to a LoRA checkpoint.
        * `--preset_lora_path`: Path to a preset LoRA checkpoint, used for differential LoRA training.
        * `--preset_lora_model`: Model the preset LoRA is merged into, e.g. `dit`.
    * Gradient configuration
        * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
        * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
        * `--gradient_accumulation_steps`: Number of gradient accumulation steps.
    * Resolution configuration
        * `--height`: Height of images/videos. Leave empty to enable dynamic resolution.
        * `--width`: Width of images/videos. Leave empty to enable dynamic resolution.
        * `--max_pixels`: Maximum pixel area; with dynamic resolution enabled, images larger than this are downscaled.
        * `--num_frames`: Number of video frames (video generation models only).
* ACE-Step-specific arguments
    * `--tokenizer_path`: Tokenizer path, in the format model_id:origin_pattern.
    * `--silence_latent_path`: Path to the silence latent, in the format model_id:origin_pattern.
    * `--initialize_model_on_cpu`: Whether to initialize the model on CPU.
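
A minimal sketch of a LoRA run assembled from the arguments above, assuming the sample dataset below. The metadata file name and the LoRA target modules are placeholders, and the hyperparameters are illustrative; prefer the repository's recommended scripts listed in the Model Overview table.

```shell
# Illustrative LoRA invocation; metadata.csv and the target modules are placeholders.
accelerate launch examples/ace_step/model_training/train.py \
  --dataset_base_path ./data/diffsynth_example_dataset \
  --dataset_metadata_path ./data/diffsynth_example_dataset/metadata.csv \
  --model_id_with_origin_paths "ACE-Step/Ace-Step1.5:acestep-v15-turbo/model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
  --lora_base_model dit \
  --lora_target_modules "to_q,to_k,to_v" \
  --lora_rank 32 \
  --learning_rate 1e-4 \
  --num_epochs 1 \
  --output_path ./models/train/Ace-Step1.5_lora
```
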
### Sample Dataset

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```

We provide a recommended training script for each model; see the table in "Model Overview" above. For how to write a model training script, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see the [training framework guide](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/).

docs/zh/Model_Details/Anima.md (new file)
@@ -0,0 +1,139 @@

# Anima

Anima is an image generation model trained and open-sourced by CircleStone Labs and Comfy Org.

## Installation

Before running model inference or training with this project, please install DiffSynth-Studio first.

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

For more information about installation, see [Installing Dependencies](../Pipeline_Usage/Setup.md).

## Quick Start

Run the following code to quickly load the [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) model and run inference. VRAM management is enabled, so the framework automatically controls how model parameters are loaded based on the remaining VRAM; as little as 8 GB of VRAM is enough.

```python
from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
import torch

vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": "disk",
    "onload_device": "disk",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = AnimaImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config),
        ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config),
        ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
    tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
image = pipe(prompt, negative_prompt=negative_prompt, seed=0, num_inference_steps=50)
image.save("image.jpg")
```

## Model Overview

|Model ID|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)|

Special training scripts:

* Differential LoRA training: [doc](../Training/Differential_LoRA.md)
* FP8-precision training: [doc](../Training/FP8_Precision.md)
* Two-stage split training: [doc](../Training/Split_Training.md)
* End-to-end direct distillation: [doc](../Training/Direct_Distill.md)

## Model Inference

Models are loaded via `AnimaImagePipeline.from_pretrained`; see [Loading Models](../Pipeline_Usage/Model_Inference.md#加载模型) for details.

Input parameters for `AnimaImagePipeline` inference (an image-to-image sketch follows below):

* `prompt`: Prompt describing what should appear in the image.
* `negative_prompt`: Negative prompt describing what should not appear in the image. Defaults to `""`.
* `cfg_scale`: Classifier-free guidance scale. Defaults to 4.0.
* `input_image`: Input image for image-to-image generation. Defaults to `None`.
* `denoising_strength`: Denoising strength, controlling how similar the generated image is to the input image. Defaults to 1.0.
* `height`: Image height, which must be a multiple of 16. Defaults to 1024.
* `width`: Image width, which must be a multiple of 16. Defaults to 1024.
* `seed`: Random seed. Defaults to `None`, i.e. fully random.
* `rand_device`: Device on which the random Gaussian noise is generated. Defaults to `"cpu"`. Setting it to `cuda` leads to different results on different GPUs.
* `num_inference_steps`: Number of inference steps. Defaults to 30.
* `sigma_shift`: Sigma shift of the scheduler. Defaults to `None`.
* `progress_bar_cmd`: Progress bar. Defaults to `tqdm.tqdm`. Set it to `lambda x: x` to disable the progress bar.

If you run out of VRAM, enable [VRAM management](../Pipeline_Usage/VRAM_management.md); the example code provides a recommended low-VRAM configuration for each model, see the table in "Model Overview" above.
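
As an image-to-image sketch of the `input_image` and `denoising_strength` parameters above: `pipe` is the pipeline from the quick start, and the input file and strength value are illustrative.

```python
# Hypothetical image-to-image call sketched from the parameter list above.
from PIL import Image

init_image = Image.open("image.jpg")  # e.g. the quick-start output
image = pipe(
    "Masterpiece, best quality, solo, long hair, silver hair, blue dress, underwater, portrait.",
    input_image=init_image,
    denoising_strength=0.6,  # below 1.0 keeps the composition of the input image
    seed=1,
    num_inference_steps=50,
)
image.save("image_img2img.jpg")
```
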
## Model Training

Anima-family models are trained through [`examples/anima/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/train.py). The script's arguments include:

* General training arguments
    * Basic dataset configuration
        * `--dataset_base_path`: Root directory of the dataset.
        * `--dataset_metadata_path`: Path to the metadata file of the dataset.
        * `--dataset_repeat`: Number of times the dataset repeats per epoch.
        * `--dataset_num_workers`: Number of worker processes per DataLoader.
        * `--data_file_keys`: Field names in the metadata to load, usually paths to image or video files, separated by `,`.
    * Model loading configuration
        * `--model_paths`: Paths of the models to load, in JSON format.
        * `--model_id_with_origin_paths`: Model IDs with original paths, e.g. `"anima-team/anima-1B:text_encoder/*.safetensors"`, separated by commas.
        * `--extra_inputs`: Extra input arguments required by the model pipeline, e.g. `controlnet_inputs` when training a ControlNet model, separated by `,`.
        * `--fp8_models`: Models loaded in FP8 format, in the same format as `--model_paths` or `--model_id_with_origin_paths`; currently only supported for models whose parameters are not updated by gradients (no backpropagation, or gradients only update their LoRA).
    * Basic training configuration
        * `--learning_rate`: Learning rate.
        * `--num_epochs`: Number of epochs.
        * `--trainable_models`: Models to train, e.g. `dit`, `vae`, `text_encoder`.
        * `--find_unused_parameters`: Whether DDP training has unused parameters. A few models contain redundant parameters that do not take part in gradient computation; enable this setting for them to avoid errors in multi-GPU training.
        * `--weight_decay`: Weight decay; see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
        * `--task`: Training task. Defaults to `sft`; some models support more training modes, see the documentation of each specific model.
    * Output configuration
        * `--output_path`: Model save path.
        * `--remove_prefix_in_ckpt`: Prefix to remove from the state dict of the model file.
        * `--save_steps`: Interval, in training steps, between model checkpoints. If left empty, a checkpoint is saved once per epoch.
    * LoRA configuration
        * `--lora_base_model`: Which model the LoRA is added to.
        * `--lora_target_modules`: Which layers the LoRA is added to.
        * `--lora_rank`: Rank of the LoRA.
        * `--lora_checkpoint`: Path to a LoRA checkpoint. If provided, the LoRA is loaded from this checkpoint.
        * `--preset_lora_path`: Path to a preset LoRA checkpoint. If provided, this LoRA is loaded merged into the base model. Used for differential LoRA training.
        * `--preset_lora_model`: Model the preset LoRA is merged into, e.g. `dit`.
    * Gradient configuration
        * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
        * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
        * `--gradient_accumulation_steps`: Number of gradient accumulation steps.
    * Image size configuration (for image and video generation models)
        * `--height`: Height of the image or video. Leave both `height` and `width` empty to enable dynamic resolution.
        * `--width`: Width of the image or video. Leave both `height` and `width` empty to enable dynamic resolution.
        * `--max_pixels`: Maximum pixel area of an image or video frame. With dynamic resolution enabled, images above this resolution are downscaled; smaller images stay unchanged.
* Anima-specific arguments
    * `--tokenizer_path`: Path of the tokenizer (for text-to-image models). If left empty, it is downloaded automatically.
    * `--tokenizer_t5xxl_path`: Path of the T5-XXL tokenizer (for text-to-image models). If left empty, it is downloaded automatically.

We provide a sample image dataset for testing; download it with the following command:

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```

We provide a recommended training script for each model; see the table in "Model Overview" above. For how to write a model training script, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see the [training framework guide](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/).

docs/zh/Model_Details/ERNIE-Image.md (new file)
@@ -0,0 +1,134 @@

# ERNIE-Image

ERNIE-Image is an image generation model with 8B parameters released by Baidu, featuring a compact, efficient architecture and excellent instruction-following ability. Built on an 8B DiT backbone, it can match larger 20B+ models in some scenarios while maintaining good parameter efficiency. The model delivers reliable performance in instruction understanding and execution, text generation (e.g. English/Chinese/Japanese), and overall stability.

## Installation

Before running model inference or training with this project, please install DiffSynth-Studio first.

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

For more information about installation, see [Installing Dependencies](../Pipeline_Usage/Setup.md).

## Quick Start

Run the following code to quickly load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model and run inference. VRAM management is enabled, so the framework automatically controls how model parameters are loaded based on the remaining VRAM; as little as 3 GB of VRAM is enough.

```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch

vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device='cuda',
    model_configs=[
        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)

image = pipe(
    prompt="一只黑白相间的中华田园犬",
    negative_prompt="",
    height=1024,
    width=1024,
    seed=42,
    num_inference_steps=50,
    cfg_scale=4.0,
)
image.save("output.jpg")
```

## Model Overview

|Model ID|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|

## Model Inference

Models are loaded via `ErnieImagePipeline.from_pretrained`; see [Loading Models](../Pipeline_Usage/Model_Inference.md#加载模型) for details.

Input parameters for `ErnieImagePipeline` inference (a reproducibility sketch follows below):

* `prompt`: Prompt describing what should appear in the image.
* `negative_prompt`: Negative prompt describing what should not appear in the image. Defaults to `""`.
* `cfg_scale`: Classifier-free guidance scale. Defaults to 4.0.
* `height`: Image height, which must be a multiple of 16. Defaults to 1024.
* `width`: Image width, which must be a multiple of 16. Defaults to 1024.
* `seed`: Random seed. Defaults to `None`, i.e. fully random.
* `rand_device`: Device on which the random Gaussian noise is generated. Defaults to `"cuda"`. Setting it to `cuda` leads to different results on different GPUs.
* `num_inference_steps`: Number of inference steps. Defaults to 50.

If you run out of VRAM, enable [VRAM management](../Pipeline_Usage/VRAM_management.md); the example code provides a recommended low-VRAM configuration for each model, see the table in "Model Overview" above.
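
One practical note on the parameters above: since `rand_device` defaults to `"cuda"`, the initial noise, and therefore the image, can differ across GPU models. Below is a small sketch of a reproducible call, assuming the `pipe` from the quick start; the choice of `"cpu"` here is an assumption based on the parameter description and on other pipelines in this repository, not a documented recipe.

```python
# Sketch: pin noise generation to the CPU so the same seed reproduces across GPUs.
image = pipe(
    prompt="一只黑白相间的中华田园犬",
    negative_prompt="",
    seed=42,
    rand_device="cpu",  # assumed valid here, mirroring other pipelines in this repo
)
image.save("output_reproducible.jpg")
```
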
## Model Training

ERNIE-Image-family models are trained through [`examples/ernie_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/train.py). The script's arguments include:

* General training arguments
    * Basic dataset configuration
        * `--dataset_base_path`: Root directory of the dataset.
        * `--dataset_metadata_path`: Path to the metadata file of the dataset.
        * `--dataset_repeat`: Number of times the dataset repeats per epoch.
        * `--dataset_num_workers`: Number of worker processes per DataLoader.
        * `--data_file_keys`: Field names in the metadata to load, usually paths to image or video files, separated by `,`.
    * Model loading configuration
        * `--model_paths`: Paths of the models to load, in JSON format.
        * `--model_id_with_origin_paths`: Model IDs with original paths, e.g. `"PaddlePaddle/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors"`, separated by commas.
        * `--extra_inputs`: Extra input arguments required by the model pipeline, separated by `,`.
        * `--fp8_models`: Models loaded in FP8 format; currently only supported for models whose parameters are not updated by gradients.
    * Basic training configuration
        * `--learning_rate`: Learning rate.
        * `--num_epochs`: Number of epochs.
        * `--trainable_models`: Models to train, e.g. `dit`, `vae`, `text_encoder`.
        * `--find_unused_parameters`: Whether DDP training has unused parameters.
        * `--weight_decay`: Weight decay.
        * `--task`: Training task. Defaults to `sft`.
    * Output configuration
        * `--output_path`: Model save path.
        * `--remove_prefix_in_ckpt`: Prefix to remove from the state dict of the model file.
        * `--save_steps`: Interval, in training steps, between model checkpoints.
    * LoRA configuration
        * `--lora_base_model`: Which model the LoRA is added to.
        * `--lora_target_modules`: Which layers the LoRA is added to.
        * `--lora_rank`: Rank of the LoRA.
        * `--lora_checkpoint`: Path to a LoRA checkpoint.
        * `--preset_lora_path`: Path to a preset LoRA checkpoint, used for differential LoRA training.
        * `--preset_lora_model`: Model the preset LoRA is merged into, e.g. `dit`.
    * Gradient configuration
        * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
        * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
        * `--gradient_accumulation_steps`: Number of gradient accumulation steps.
    * Resolution configuration
        * `--height`: Image height. Leave empty to enable dynamic resolution.
        * `--width`: Image width. Leave empty to enable dynamic resolution.
        * `--max_pixels`: Maximum pixel area; with dynamic resolution enabled, images larger than this are downscaled.
* ERNIE-Image-specific arguments
    * `--tokenizer_path`: Path of the tokenizer. If left empty, it is downloaded automatically.

We provide a sample image dataset for testing; download it with the following command:

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```

We provide a recommended training script for each model; see the table in "Model Overview" above. For how to write a model training script, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see the [training framework guide](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/).

@@ -81,27 +81,27 @@ graph LR;

 |Model ID|Extra Inputs|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
 |-|-|-|-|-|-|-|-|
-|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
+|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
-|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
+|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
-|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
+|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
-|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
+|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
-|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
+|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
-|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
+|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
-|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
+|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
-|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
+|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
-|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
+|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
-|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
+|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
-|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
+|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
-|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)|
+|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Step1X-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Step1X-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Step1X-Edit.py)|
-|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
+|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLEX.2-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
-|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)|
+|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Nexus-Gen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Nexus-Gen.py)|

 Special training scripts:

-* Differential LoRA training: [doc](../Training/Differential_LoRA.md), [code](/examples/flux/model_training/special/differential_training/)
+* Differential LoRA training: [doc](../Training/Differential_LoRA.md)
-* FP8-precision training: [doc](../Training/FP8_Precision.md), [code](/examples/flux/model_training/special/fp8_training/)
+* FP8-precision training: [doc](../Training/FP8_Precision.md)
-* Two-stage split training: [doc](../Training/Split_Training.md), [code](/examples/flux/model_training/special/split_training/)
+* Two-stage split training: [doc](../Training/Split_Training.md)
-* End-to-end direct distillation: [doc](../Training/Direct_Distill.md), [code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh)
+* End-to-end direct distillation: [doc](../Training/Direct_Distill.md)

 ## Model Inference

@@ -147,7 +147,7 @@ graph LR;

 ## Model Training

-FLUX-family models are trained through [`examples/flux/model_training/train.py`](/examples/flux/model_training/train.py). The script's arguments include:
+FLUX-family models are trained through [`examples/flux/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/train.py). The script's arguments include:

 * General training arguments
     * Basic dataset configuration
@@ -195,7 +195,7 @@ FLUX 系列模型统一通过 [`examples/flux/model_training/train.py`](/example
 We provide a sample image dataset for testing; download it with the following command:

 ```shell
-modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
+modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
 ```

-We provide a recommended training script for each model; see the table in "Model Overview" above. For how to write a model training script, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see the [training framework guide](/docs/Training/).
+We provide a recommended training script for each model; see the table in "Model Overview" above. For how to write a model training script, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see the [training framework guide](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/).
Some files were not shown because too many files have changed in this diff.