diff --git a/README.md b/README.md index 4f3bb97..415eec0 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,8 @@ ## Introduction +> DiffSynth-Studio Documentation: [中文版](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/)、[English version](https://diffsynth-studio-doc.readthedocs.io/en/latest/) + Welcome to the magical world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by the [ModelScope Community](https://www.modelscope.cn/). We hope to foster technological innovation through framework construction, aggregate the power of the open-source community, and explore the boundaries of generative model technology! DiffSynth currently includes two open-source projects: @@ -23,8 +25,6 @@ DiffSynth currently includes two open-source projects: * ModelScope AIGC Zone (for Chinese users): https://modelscope.cn/aigc/home * ModelScope Civision (for global users): https://modelscope.ai/civision/home -> DiffSynth-Studio Documentation: [中文版](/docs/zh/README.md)、[English version](/docs/en/README.md) - We believe that a well-developed open-source code framework can lower the threshold for technical exploration. We have achieved many [interesting technologies](#innovative-achievements) based on this codebase. Perhaps you also have many wild ideas, and with DiffSynth-Studio, you can quickly realize these ideas. For this reason, we have prepared detailed documentation for developers. We hope that through these documents, developers can understand the principles of Diffusion models, and we look forward to expanding the boundaries of technology together with you. ## Update History @@ -32,6 +32,11 @@ We believe that a well-developed open-source code framework can lower the thresh > DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update. > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand. +- **February 10, 2026** Added inference support for the LTX-2 audio-video generation model. See the documentation for details. Support for model training will be implemented in the future. + +- **February 2, 2026** The first document of the Research Tutorial series is now available, guiding you through training a small 0.1B text-to-image model from scratch. For details, see the [documentation](/docs/en/Research_Tutorial/train_from_scratch.md) and [model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). We hope DiffSynth-Studio can evolve into a more powerful training framework for Diffusion models. + +- **January 27, 2026**: [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) is released, and our [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) model is released concurrently. You can use it in [ModelScope Studios](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L). For details, see the [documentation](/docs/zh/Model_Details/Z-Image.md). - **January 19, 2026**: Added support for [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available. @@ -269,9 +274,14 @@ image.save("image.jpg") Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image/) -| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | +|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training| |-|-|-|-|-|-|-| +|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)| +|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-| |[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)| +|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)| +|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)| +|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)| @@ -410,6 +420,7 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q |[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)| |[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| |[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)| +|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-| |[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)| |[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)| |[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| @@ -522,6 +533,102 @@ Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/) https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314 +#### LTX-2: [/docs/en/Model_Details/LTX-2.md](/docs/en/Model_Details/LTX-2.md) + +
+ +Quick Start + +Running the following code will quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8GB of VRAM. + +```python +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\"" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_twostage.mp4', + fps=24, + audio_sample_rate=24000, +) +``` + +
+ +
+ +Examples + +Example code for LTX-2 is available at: [/examples/ltx2/](/examples/ltx2/) + +| Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | +|-|-|-|-|-|-|-|-| +|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-| +|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-| +|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-| +|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-| + +
+ #### Wan: [/docs/en/Model_Details/Wan.md](/docs/en/Model_Details/Wan.md)
@@ -661,6 +768,37 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/) DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements. +
+ +Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation + +- Paper: [Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation +](https://arxiv.org/abs/2602.03208) +- Sample Code: coming soon + +|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES| +|-|-|-|-| +|![Image](https://github.com/user-attachments/assets/5be15dc6-2805-4822-b04c-2573fc0f45f0)|![Image](https://github.com/user-attachments/assets/e71b8c20-1629-41d9-b0ff-185805c1da4e)|![Image](https://github.com/user-attachments/assets/7a73c968-133a-4545-9aa2-205533861cd4)|![Image](https://github.com/user-attachments/assets/c8390b22-14fe-48a0-a6e6-d6556d31235e)| + +
+ + +
+ +VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers + +- Paper: [VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers +](https://arxiv.org/abs/2602.03210) +- Sample code: [/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py) +- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA) + +|Example 1|Example 2|Query|Output| +|-|-|-|-| +|![Image](https://github.com/user-attachments/assets/380d2670-47bf-41cd-b5c9-37110cc4a943)|![Image](https://github.com/user-attachments/assets/7ceaf345-0992-46e6-b38f-394c2065b165)|![Image](https://github.com/user-attachments/assets/f7c26c21-6894-4d9e-b570-f1d44ca7c1de)|![Image](https://github.com/user-attachments/assets/c2bebe3b-5984-41ba-94bf-9509f6a8a990)| + +
+ +
AttriCtrl: Attribute Intensity Control for Image Generation Models @@ -671,7 +809,7 @@ DiffSynth-Studio is not just an engineered model framework, but also an incubato |brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9| |-|-|-|-|-| -|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.5.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.7.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.9.jpg)| +|![Image](https://github.com/user-attachments/assets/e74b32a5-5b2e-4c87-9df8-487c0f8366b7)|![Image](https://github.com/user-attachments/assets/bfe8bec2-9e55-493d-9a26-7e9cce28e03d)|![Image](https://github.com/user-attachments/assets/b099dfe3-ff1f-4b96-894c-d48bbe92db7a)|![Image](https://github.com/user-attachments/assets/0a6b2982-deab-4b0d-91ad-888782de01c9)|![Image](https://github.com/user-attachments/assets/fcecb755-7d03-4020-b83a-13ad2b38705c)|
@@ -686,10 +824,10 @@ DiffSynth-Studio is not just an engineered model framework, but also an incubato ||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)| |-|-|-|-|-| -|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_0.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)| -|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)| -|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)| -|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_3_3.jpg)| +|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![Image](https://github.com/user-attachments/assets/01c54d5a-4f00-4c2e-982a-4ec0a4c6a6e3)|![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)| +|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/43720a9f-aa27-4918-947d-545389375d46)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)| +|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/041a3f9a-c7b4-4311-8582-cb71a7226d80)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)| +|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|![Image](https://github.com/user-attachments/assets/a640fd54-3192-49a0-9281-b43d9ba64f09)|
diff --git a/README_zh.md b/README_zh.md index a464dab..4639997 100644 --- a/README_zh.md +++ b/README_zh.md @@ -12,6 +12,8 @@ ## 简介 +> DiffSynth-Studio 文档:[中文版](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/)、[English version](https://diffsynth-studio-doc.readthedocs.io/en/latest/) + 欢迎来到 Diffusion 模型的魔法世界!DiffSynth-Studio 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望以框架建设孵化技术创新,凝聚开源社区的力量,探索生成式模型技术的边界! DiffSynth 目前包括两个开源项目: @@ -23,8 +25,6 @@ DiffSynth 目前包括两个开源项目: * 魔搭社区 AIGC 专区 (面向中国用户): https://modelscope.cn/aigc/home * ModelScope Civision (for global users): https://modelscope.ai/civision/home -> DiffSynth-Studio 文档:[中文版](/docs/zh/README.md)、[English version](/docs/en/README.md) - 我们相信,一个完善的开源代码框架能够降低技术探索的门槛,我们基于这个代码库搞出了不少[有意思的技术](#创新成果)。或许你也有许多天马行空的构想,借助 DiffSynth-Studio,你可以快速实现这些想法。为此,我们为开发者准备了详细的文档,我们希望通过这些文档,帮助开发者理解 Diffusion 模型的原理,更期待与你一同拓展技术的边界。 ## 更新历史 @@ -32,6 +32,11 @@ DiffSynth 目前包括两个开源项目: > DiffSynth-Studio 经历了大版本更新,部分旧功能已停止维护,如需使用旧版功能,请切换到大版本更新前的[最后一个历史版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3)。 > 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。 +- **2026年2月10日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型的推理支持,详见[文档](docs/zh/Model_Details/LTX-2.md),后续将推进模型训练的支持。 + +- **2026年2月2日** Research Tutorial 的第一篇文档上线,带你从零开始训练一个 0.1B 的小型文生图模型,详见[文档](/docs/zh/Research_Tutorial/train_from_scratch.md)、[模型](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel),我们希望 DiffSynth-Studio 能够成为一个更强大的 Diffusion 模型训练框架。 + +- **2026年1月27日** [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) 发布,我们的 [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) 模型同步发布,在[魔搭创空间](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L)可直接体验,详见[文档](/docs/zh/Model_Details/Z-Image.md)。 - **2026年1月19日** 新增对 [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 和 [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/FLUX2.md)和[示例代码](/examples/flux2/)现已可用。 @@ -271,7 +276,12 @@ Z-Image 的示例代码位于:[/examples/z_image/](/examples/z_image/) |模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| +|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)| +|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-| |[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)| +|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)| +|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)| +|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)| @@ -410,6 +420,7 @@ Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/ |[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)| |[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| |[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)| +|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-| |[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)| |[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)| |[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| @@ -522,6 +533,102 @@ FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/) https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314 +#### LTX-2: [/docs/zh/Model_Details/LTX-2.md](/docs/zh/Model_Details/LTX-2.md) + +
+ +快速开始 + +运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。 + +```python +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\"" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_twostage.mp4', + fps=24, + audio_sample_rate=24000, +) +``` + +
+ +
+ +示例代码 + +LTX-2 的示例代码位于:[/examples/ltx2/](/examples/ltx2/) + +|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-|-| +|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-| +|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-| +|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-| +|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-| + +
+ #### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md)
@@ -661,6 +768,37 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/) DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。 +
+ +Spectral Evolution Search: 用于奖励对齐图像生成的高效推理阶段缩放 + +- 论文:[Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation +](https://arxiv.org/abs/2602.03208) +- 代码样例:coming soon + +|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES| +|-|-|-|-| +|![Image](https://github.com/user-attachments/assets/5be15dc6-2805-4822-b04c-2573fc0f45f0)|![Image](https://github.com/user-attachments/assets/e71b8c20-1629-41d9-b0ff-185805c1da4e)|![Image](https://github.com/user-attachments/assets/7a73c968-133a-4545-9aa2-205533861cd4)|![Image](https://github.com/user-attachments/assets/c8390b22-14fe-48a0-a6e6-d6556d31235e)| + +
+ + +
+ +VIRAL:基于DiT模型的类比视觉上下文推理 + +- 论文:[VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers +](https://arxiv.org/abs/2602.03210) +- 代码样例:[/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py) +- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA) + +|Example 1|Example 2|Query|Output| +|-|-|-|-| +|![Image](https://github.com/user-attachments/assets/380d2670-47bf-41cd-b5c9-37110cc4a943)|![Image](https://github.com/user-attachments/assets/7ceaf345-0992-46e6-b38f-394c2065b165)|![Image](https://github.com/user-attachments/assets/f7c26c21-6894-4d9e-b570-f1d44ca7c1de)|![Image](https://github.com/user-attachments/assets/c2bebe3b-5984-41ba-94bf-9509f6a8a990)| + +
+ +
AttriCtrl: 图像生成模型的属性强度控制 @@ -672,7 +810,7 @@ DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果 |brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9| |-|-|-|-|-| -|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.5.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.7.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.9.jpg)| +|![Image](https://github.com/user-attachments/assets/e74b32a5-5b2e-4c87-9df8-487c0f8366b7)|![Image](https://github.com/user-attachments/assets/bfe8bec2-9e55-493d-9a26-7e9cce28e03d)|![Image](https://github.com/user-attachments/assets/b099dfe3-ff1f-4b96-894c-d48bbe92db7a)|![Image](https://github.com/user-attachments/assets/0a6b2982-deab-4b0d-91ad-888782de01c9)|![Image](https://github.com/user-attachments/assets/fcecb755-7d03-4020-b83a-13ad2b38705c)|
@@ -688,10 +826,10 @@ DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果 ||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)| |-|-|-|-|-| -|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_0.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)| -|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)| -|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)| -|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_3_3.jpg)| +|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![Image](https://github.com/user-attachments/assets/01c54d5a-4f00-4c2e-982a-4ec0a4c6a6e3)|![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)| +|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/43720a9f-aa27-4918-947d-545389375d46)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)| +|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/041a3f9a-c7b4-4311-8582-cb71a7226d80)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)| +|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|![Image](https://github.com/user-attachments/assets/a640fd54-3192-49a0-9281-b43d9ba64f09)|
diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index c93f5e9..9ff7ea6 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -589,6 +589,78 @@ z_image_series = [ "model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel", "extra_kwargs": {"compress_dim": 128}, }, + { + # Example: ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors") + "model_hash": "1392adecee344136041e70553f875f31", + "model_name": "z_image_text_encoder", + "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder", + "extra_kwargs": {"model_size": "0.6B"}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter", + }, ] -MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series +ltx2_series = [ + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_dit", + "model_class": "diffsynth.models.ltx2_dit.LTXModel", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_video_vae_encoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_video_vae_decoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_audio_vae_decoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_audio_vocoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter", + }, + # { # not used currently + # # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + # "model_hash": "aca7b0bbf8415e9c98360750268915fc", + # "model_name": "ltx2_audio_vae_encoder", + # "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder", + # "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter", + # }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_text_encoder_post_modules", + "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter", + }, + { + # Example: ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors") + "model_hash": "33917f31c4a79196171154cca39f165e", + "model_name": "ltx2_text_encoder", + "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "c79c458c6e99e0e14d47e676761732d2", + "model_name": "ltx2_latent_upsampler", + "model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler", + }, +] +MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py index a1813fb..0f360ef 100644 --- a/diffsynth/configs/vram_management_module_maps.py +++ b/diffsynth/configs/vram_management_module_maps.py @@ -210,4 +210,37 @@ VRAM_MANAGEMENT_MODULE_MAPS = { "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", }, + "diffsynth.models.ltx2_dit.LTXModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": { + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": { + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": { + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": { + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_audio_vae.LTX2Vocoder": { + "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + }, } diff --git a/diffsynth/core/attention/attention.py b/diffsynth/core/attention/attention.py index 15b55a4..630d375 100644 --- a/diffsynth/core/attention/attention.py +++ b/diffsynth/core/attention/attention.py @@ -52,7 +52,7 @@ def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern=" if k_pattern != required_in_pattern: k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims) if v_pattern != required_in_pattern: - v = rearrange(v, f"{q_pattern} -> {required_in_pattern}", **dims) + v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims) return q, k, v diff --git a/diffsynth/core/loader/config.py b/diffsynth/core/loader/config.py index 88b46a0..d4ce83c 100644 --- a/diffsynth/core/loader/config.py +++ b/diffsynth/core/loader/config.py @@ -1,5 +1,5 @@ import torch, glob, os -from typing import Optional, Union +from typing import Optional, Union, Dict from dataclasses import dataclass from modelscope import snapshot_download from huggingface_hub import snapshot_download as hf_snapshot_download @@ -23,13 +23,14 @@ class ModelConfig: computation_device: Optional[Union[str, torch.device]] = None computation_dtype: Optional[torch.dtype] = None clear_parameters: bool = False + state_dict: Dict[str, torch.Tensor] = None def check_input(self): if self.path is None and self.model_id is None: raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""") def parse_original_file_pattern(self): - if self.origin_file_pattern is None or self.origin_file_pattern == "": + if self.origin_file_pattern in [None, "", "./"]: return "*" elif self.origin_file_pattern.endswith("/"): return self.origin_file_pattern + "*" @@ -98,7 +99,7 @@ class ModelConfig: if self.require_downloading(): self.download() if self.path is None: - if self.origin_file_pattern is None or self.origin_file_pattern == "": + if self.origin_file_pattern in [None, "", "./"]: self.path = os.path.join(self.local_model_path, self.model_id) else: self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern)) diff --git a/diffsynth/core/loader/file.py b/diffsynth/core/loader/file.py index 8f66961..67d8815 100644 --- a/diffsynth/core/loader/file.py +++ b/diffsynth/core/loader/file.py @@ -2,16 +2,25 @@ from safetensors import safe_open import torch, hashlib -def load_state_dict(file_path, torch_dtype=None, device="cpu"): +def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0): if isinstance(file_path, list): state_dict = {} for file_path_ in file_path: - state_dict.update(load_state_dict(file_path_, torch_dtype, device)) - return state_dict - if file_path.endswith(".safetensors"): - return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device) + state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose)) else: - return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device) + if verbose >= 1: + print(f"Loading file [started]: {file_path}") + if file_path.endswith(".safetensors"): + state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device) + else: + state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device) + # If load state dict in CPU memory, `pin_memory=True` will make `model.to("cuda")` faster. + if pin_memory: + for i in state_dict: + state_dict[i] = state_dict[i].pin_memory() + if verbose >= 1: + print(f"Loading file [done]: {file_path}") + return state_dict def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"): diff --git a/diffsynth/core/loader/model.py b/diffsynth/core/loader/model.py index 56fa7d3..c244cd0 100644 --- a/diffsynth/core/loader/model.py +++ b/diffsynth/core/loader/model.py @@ -3,14 +3,14 @@ from ..vram.disk_map import DiskMap from ..vram.layers import enable_vram_management from .file import load_state_dict import torch +from contextlib import contextmanager +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.utils import ContextManagers -def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None): +def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None): config = {} if config is None else config - # Why do we use `skip_model_initialization`? - # It skips the random initialization of model parameters, - # thereby speeding up model loading and avoiding excessive memory usage. - with skip_model_initialization(): + with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)): model = model_class(**config) # What is `module_map`? # This is a module mapping table for VRAM management. @@ -20,7 +20,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]] dtype = [d for d in dtypes if d != "disk"][0] if vram_config["offload_device"] != "disk": - state_dict = DiskMap(path, device, torch_dtype=dtype) + if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype) if state_dict_converter is not None: state_dict = state_dict_converter(state_dict) else: @@ -35,7 +35,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic # Sometimes a model file contains multiple models, # and DiskMap can load only the parameters of a single model, # avoiding the need to load all parameters in the file. - if use_disk_map: + if state_dict is not None: + pass + elif use_disk_map: state_dict = DiskMap(path, device, torch_dtype=torch_dtype) else: state_dict = load_state_dict(path, torch_dtype, device) @@ -46,7 +48,14 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic state_dict = state_dict_converter(state_dict) else: state_dict = {i: state_dict[i] for i in state_dict} - model.load_state_dict(state_dict, assign=True) + # Why does DeepSpeed ZeRO Stage 3 need to be handled separately? + # Because at this stage, model parameters are partitioned across multiple GPUs. + # Loading them directly could lead to excessive GPU memory consumption. + if is_deepspeed_zero3_enabled(): + from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model + _load_state_dict_into_zero3_model(model, state_dict) + else: + model.load_state_dict(state_dict, assign=True) # Why do we call `to()`? # Because some models override the behavior of `to()`, # especially those from libraries like Transformers. @@ -77,3 +86,20 @@ def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=tor } enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80) return model + + +def get_init_context(torch_dtype, device): + if is_deepspeed_zero3_enabled(): + from transformers.modeling_utils import set_zero3_state + import deepspeed + # Why do we use "deepspeed.zero.Init"? + # Weight segmentation of the model can be performed on the CPU side + # and loading the segmented weights onto the computing card + init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()] + else: + # Why do we use `skip_model_initialization`? + # It skips the random initialization of model parameters, + # thereby speeding up model loading and avoiding excessive memory usage. + init_contexts = [skip_model_initialization()] + + return init_contexts diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py index d4731fd..7d41cac 100644 --- a/diffsynth/diffusion/base_pipeline.py +++ b/diffsynth/diffusion/base_pipeline.py @@ -296,6 +296,7 @@ class BasePipeline(torch.nn.Module): vram_config=vram_config, vram_limit=vram_limit, clear_parameters=model_config.clear_parameters, + state_dict=model_config.state_dict, ) return model_pool @@ -317,7 +318,14 @@ class BasePipeline(torch.nn.Module): if inputs_shared.get("positive_only_lora", None) is not None: self.clear_lora(verbose=0) noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others) - noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + if isinstance(noise_pred_posi, tuple): + # Separately handling different output types of latents, eg. video and audio latents. + noise_pred = tuple( + n_nega + cfg_scale * (n_posi - n_nega) + for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega) + ) + else: + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: noise_pred = noise_pred_posi return noise_pred diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py index 2d6b367..208fb1e 100644 --- a/diffsynth/diffusion/flow_match.py +++ b/diffsynth/diffusion/flow_match.py @@ -4,13 +4,15 @@ from typing_extensions import Literal class FlowMatchScheduler(): - def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"): + def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"): self.set_timesteps_fn = { "FLUX.1": FlowMatchScheduler.set_timesteps_flux, "Wan": FlowMatchScheduler.set_timesteps_wan, "Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image, "FLUX.2": FlowMatchScheduler.set_timesteps_flux2, "Z-Image": FlowMatchScheduler.set_timesteps_z_image, + "LTX-2": FlowMatchScheduler.set_timesteps_ltx2, + "Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning, }.get(template, FlowMatchScheduler.set_timesteps_flux) self.num_train_timesteps = 1000 @@ -70,6 +72,28 @@ class FlowMatchScheduler(): timesteps = sigmas * num_train_timesteps return sigmas, timesteps + @staticmethod + def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None): + sigma_min = 0.0 + sigma_max = 1.0 + num_train_timesteps = 1000 + base_shift = math.log(3) + max_shift = math.log(3) + # Sigmas + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + # Mu + if exponential_shift_mu is not None: + mu = exponential_shift_mu + elif dynamic_shift_len is not None: + mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift) + else: + mu = 0.8 + sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1)) + # Timesteps + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + @staticmethod def compute_empirical_mu(image_seq_len, num_steps): a1, b1 = 8.73809524e-05, 1.89833333 @@ -121,7 +145,35 @@ class FlowMatchScheduler(): timestep_id = torch.argmin((timesteps - timestep).abs()) timesteps[timestep_id] = timestep return sigmas, timesteps - + + @staticmethod + def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None): + num_train_timesteps = 1000 + if special_case == "stage2": + sigmas = torch.Tensor([0.909375, 0.725, 0.421875]) + elif special_case == "ditilled_stage1": + sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]) + else: + dynamic_shift_len = dynamic_shift_len or 4096 + sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image( + image_seq_len=dynamic_shift_len, + base_seq_len=1024, + max_seq_len=4096, + base_shift=0.95, + max_shift=2.05, + ) + sigma_min = 0.0 + sigma_max = 1.0 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1)) + # Shift terminal + one_minus_z = 1.0 - sigmas + scale_factor = one_minus_z[-1] / (1 - terminal) + sigmas = 1.0 - (one_minus_z / scale_factor) + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + def set_training_weight(self): steps = 1000 x = self.timesteps diff --git a/diffsynth/diffusion/logger.py b/diffsynth/diffusion/logger.py index 6d2792f..ab6bdb9 100644 --- a/diffsynth/diffusion/logger.py +++ b/diffsynth/diffusion/logger.py @@ -18,8 +18,8 @@ class ModelLogger: def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id): accelerator.wait_for_everyone() + state_dict = accelerator.get_state_dict(model) if accelerator.is_main_process: - state_dict = accelerator.get_state_dict(model) state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) state_dict = self.state_dict_converter(state_dict) os.makedirs(self.output_path, exist_ok=True) @@ -34,8 +34,8 @@ class ModelLogger: def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name): accelerator.wait_for_everyone() + state_dict = accelerator.get_state_dict(model) if accelerator.is_main_process: - state_dict = accelerator.get_state_dict(model) state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) state_dict = self.state_dict_converter(state_dict) os.makedirs(self.output_path, exist_ok=True) diff --git a/diffsynth/diffusion/loss.py b/diffsynth/diffusion/loss.py index ae44bb6..14fdfd3 100644 --- a/diffsynth/diffusion/loss.py +++ b/diffsynth/diffusion/loss.py @@ -13,9 +13,16 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs): inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep) training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep) + if "first_frame_latents" in inputs: + inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"] + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep) + if "first_frame_latents" in inputs: + noise_pred = noise_pred[:, :, 1:] + training_target = training_target[:, :, 1:] + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) loss = loss * pipe.scheduler.training_weight(timestep) return loss diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py index f6e2263..6e26035 100644 --- a/diffsynth/diffusion/runner.py +++ b/diffsynth/diffusion/runner.py @@ -27,7 +27,7 @@ def launch_training_task( optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay) scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) - + model.to(device=accelerator.device) model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) for epoch_id in range(num_epochs): @@ -59,6 +59,7 @@ def launch_data_process_task( num_workers = args.dataset_num_workers dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers) + model.to(device=accelerator.device) model, dataloader = accelerator.prepare(model, dataloader) for data_id, data in enumerate(tqdm(dataloader)): diff --git a/diffsynth/models/flux2_dit.py b/diffsynth/models/flux2_dit.py index 316cf08..a1bd02a 100644 --- a/diffsynth/models/flux2_dit.py +++ b/diffsynth/models/flux2_dit.py @@ -407,6 +407,7 @@ class Flux2AttnProcessor: query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype) hidden_states = attention_forward( query, key, @@ -536,6 +537,7 @@ class Flux2ParallelSelfAttnProcessor: query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype) hidden_states = attention_forward( query, key, diff --git a/diffsynth/models/ltx2_audio_vae.py b/diffsynth/models/ltx2_audio_vae.py new file mode 100644 index 0000000..708ded7 --- /dev/null +++ b/diffsynth/models/ltx2_audio_vae.py @@ -0,0 +1,1351 @@ +from typing import Set, Tuple, Optional, List +from enum import Enum +import math +import einops +import torch +import torch.nn as nn +import torch.nn.functional as F +from .ltx2_common import VideoLatentShape, AudioLatentShape, Patchifier, NormType, build_normalization_layer + +class AudioPatchifier(Patchifier): + def __init__( + self, + patch_size: int, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + is_causal: bool = True, + shift: int = 0, + ): + """ + Patchifier tailored for spectrogram/audio latents. + Args: + patch_size: Number of mel bins combined into a single patch. This + controls the resolution along the frequency axis. + sample_rate: Original waveform sampling rate. Used to map latent + indices back to seconds so downstream consumers can align audio + and video cues. + hop_length: Window hop length used for the spectrogram. Determines + how many real-time samples separate two consecutive latent frames. + audio_latent_downsample_factor: Ratio between spectrogram frames and + latent frames; compensates for additional downsampling inside the + VAE encoder. + is_causal: When True, timing is shifted to account for causal + receptive fields so timestamps do not peek into the future. + shift: Integer offset applied to the latent indices. Enables + constructing overlapping windows from the same latent sequence. + """ + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self.shift = shift + self._patch_size = (1, patch_size, patch_size) + + @property + def patch_size(self) -> Tuple[int, int, int]: + return self._patch_size + + def get_token_count(self, tgt_shape: AudioLatentShape) -> int: + return tgt_shape.frames + + def _get_audio_latent_time_in_sec( + self, + start_latent: int, + end_latent: int, + dtype: torch.dtype, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Converts latent indices into real-time seconds while honoring causal + offsets and the configured hop length. + Args: + start_latent: Inclusive start index inside the latent sequence. This + sets the first timestamp returned. + end_latent: Exclusive end index. Determines how many timestamps get + generated. + dtype: Floating-point dtype used for the returned tensor, allowing + callers to control precision. + device: Target device for the timestamp tensor. When omitted the + computation occurs on CPU to avoid surprising GPU allocations. + """ + if device is None: + device = torch.device("cpu") + + audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device) + + audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor + + if self.is_causal: + # Frame offset for causal alignment. + # The "+1" ensures the timestamp corresponds to the first sample that is fully available. + causal_offset = 1 + audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0) + + return audio_mel_frame * self.hop_length / self.sample_rate + + def _compute_audio_timings( + self, + batch_size: int, + num_steps: int, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame. + This helper method underpins `get_patch_grid_bounds` for the audio patchifier. + Args: + batch_size: Number of sequences to broadcast the timings over. + num_steps: Number of latent frames (time steps) to convert into timestamps. + device: Device on which the resulting tensor should reside. + """ + resolved_device = device + if resolved_device is None: + resolved_device = torch.device("cpu") + + start_timings = self._get_audio_latent_time_in_sec( + self.shift, + num_steps + self.shift, + torch.float32, + resolved_device, + ) + start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) + + end_timings = self._get_audio_latent_time_in_sec( + self.shift + 1, + num_steps + self.shift + 1, + torch.float32, + resolved_device, + ) + end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) + + return torch.stack([start_timings, end_timings], dim=-1) + + def patchify( + self, + audio_latents: torch.Tensor, + ) -> torch.Tensor: + """ + Flattens the audio latent tensor along time. Use `get_patch_grid_bounds` + to derive timestamps for each latent frame based on the configured hop + length and downsampling. + Args: + audio_latents: Latent tensor to patchify. + Returns: + Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute the + corresponding timing metadata when needed. + """ + audio_latents = einops.rearrange( + audio_latents, + "b c t f -> b t (c f)", + ) + + return audio_latents + + def unpatchify( + self, + audio_latents: torch.Tensor, + output_shape: AudioLatentShape, + ) -> torch.Tensor: + """ + Restores the `(B, C, T, F)` spectrogram tensor from flattened patches. + Use `get_patch_grid_bounds` to recompute the timestamps that describe each + frame's position in real time. + Args: + audio_latents: Latent tensor to unpatchify. + output_shape: Shape of the unpatched output tensor. + Returns: + Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing + metadata associated with the restored latents. + """ + # audio_latents shape: (batch, time, freq * channels) + audio_latents = einops.rearrange( + audio_latents, + "b t (c f) -> b c t f", + c=output_shape.channels, + f=output_shape.mel_bins, + ) + + return audio_latents + + def unpatchify_audio( + self, + audio_latents: torch.Tensor, + channels: int, + mel_bins: int + ) -> torch.Tensor: + audio_latents = einops.rearrange( + audio_latents, + "b t (c f) -> b c t f", + c=channels, + f=mel_bins, + ) + return audio_latents + + def get_patch_grid_bounds( + self, + output_shape: AudioLatentShape | VideoLatentShape, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Return the temporal bounds `[inclusive start, exclusive end)` for every + patch emitted by `patchify`. For audio this corresponds to timestamps in + seconds aligned with the original spectrogram grid. + The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where: + - axis 1 (size 1) represents the temporal dimension + - axis 3 (size 2) stores the `[start, end)` timestamps per patch + Args: + output_shape: Audio grid specification describing the number of time steps. + device: Target device for the returned tensor. + """ + if not isinstance(output_shape, AudioLatentShape): + raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates") + + return self._compute_audio_timings(output_shape.batch, output_shape.frames, device) + + +class AttentionType(Enum): + """Enum for specifying the attention mechanism type.""" + + VANILLA = "vanilla" + LINEAR = "linear" + NONE = "none" + + +class AttnBlock(torch.nn.Module): + def __init__( + self, + in_channels: int, + norm_type: NormType = NormType.GROUP, + ) -> None: + super().__init__() + self.in_channels = in_channels + + self.norm = build_normalization_layer(in_channels, normtype=norm_type) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn( + in_channels: int, + attn_type: AttentionType = AttentionType.VANILLA, + norm_type: NormType = NormType.GROUP, +) -> torch.nn.Module: + match attn_type: + case AttentionType.VANILLA: + return AttnBlock(in_channels, norm_type=norm_type) + case AttentionType.NONE: + return torch.nn.Identity() + case AttentionType.LINEAR: + raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.") + case _: + raise ValueError(f"Unknown attention type: {attn_type}") + + +class CausalityAxis(Enum): + """Enum for specifying the causality axis in causal convolutions.""" + + NONE = None + WIDTH = "width" + HEIGHT = "height" + WIDTH_COMPATIBILITY = "width-compatibility" + + +class CausalConv2d(torch.nn.Module): + """ + A causal 2D convolution. + This layer ensures that the output at time `t` only depends on inputs + at time `t` and earlier. It achieves this by applying asymmetric padding + to the time dimension (width) before the convolution. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int = 1, + dilation: int | tuple[int, int] = 1, + groups: int = 1, + bias: bool = True, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + + self.causality_axis = causality_axis + + # Ensure kernel_size and dilation are tuples + kernel_size = torch.nn.modules.utils._pair(kernel_size) + dilation = torch.nn.modules.utils._pair(dilation) + + # Calculate padding dimensions + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + # The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY: + self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.HEIGHT: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + case _: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + # The internal convolution layer uses no padding, as we handle it manually + self.conv = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Apply causal padding before convolution + x = F.pad(x, self.padding) + return self.conv(x) + + +def make_conv2d( + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int = 1, + padding: tuple[int, int, int, int] | None = None, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causality_axis: CausalityAxis | None = None, +) -> torch.nn.Module: + """ + Create a 2D convolution layer that can be either causal or non-causal. + Args: + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Size of the convolution kernel + stride: Convolution stride + padding: Padding (if None, will be calculated based on causal flag) + dilation: Dilation rate + groups: Number of groups for grouped convolution + bias: Whether to use bias + causality_axis: Dimension along which to apply causality. + Returns: + Either a regular Conv2d or CausalConv2d layer + """ + if causality_axis is not None: + # For causal convolution, padding is handled internally by CausalConv2d + return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis) + else: + # For non-causal convolution, use symmetric padding if not specified + if padding is None: + padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size) + + return torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding="same", + ), + ] + ) + + self.convs2 = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding="same", + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv1, conv2 in zip(self.convs1, self.convs2, strict=True): + xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) + xt = conv1(xt) + xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE) + xt = conv2(xt) + x = xt + x + return x + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)): + super(ResBlock2, self).__init__() + self.convs = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding="same", + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv in self.convs: + xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) + xt = conv(xt) + x = xt + x + return x + + +class ResnetBlock(torch.nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int | None = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + norm_type: NormType = NormType.GROUP, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP: + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = build_normalization_layer(in_channels, normtype=norm_type) + self.non_linearity = torch.nn.SiLU() + self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = build_normalization_layer(out_channels, normtype=norm_type) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + + def forward( + self, + x: torch.Tensor, + temb: torch.Tensor | None = None, + ) -> torch.Tensor: + h = x + h = self.norm1(h) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x) + + return x + h + + +class Downsample(torch.nn.Module): + """ + A downsampling layer that can use either a strided convolution + or average pooling. Supports standard and causal padding for the + convolutional mode. + """ + + def __init__( + self, + in_channels: int, + with_conv: bool, + causality_axis: CausalityAxis = CausalityAxis.WIDTH, + ) -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and not self.with_conv: + raise ValueError("causality is only supported when `with_conv=True`.") + + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.with_conv: + # Padding tuple is in the order: (left, right, top, bottom). + match self.causality_axis: + case CausalityAxis.NONE: + pad = (0, 1, 0, 1) + case CausalityAxis.WIDTH: + pad = (2, 0, 0, 1) + case CausalityAxis.HEIGHT: + pad = (0, 1, 2, 0) + case CausalityAxis.WIDTH_COMPATIBILITY: + pad = (1, 0, 0, 1) + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # This branch is only taken if with_conv=False, which implies causality_axis is NONE. + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + + return x + + +def build_downsampling_path( # noqa: PLR0913 + *, + ch: int, + ch_mult: Tuple[int, ...], + num_resolutions: int, + num_res_blocks: int, + resolution: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + attn_resolutions: Set[int], + resamp_with_conv: bool, +) -> tuple[torch.nn.ModuleList, int]: + """Build the downsampling path with residual blocks, attention, and downsampling layers.""" + down_modules = torch.nn.ModuleList() + curr_res = resolution + in_ch_mult = (1, *tuple(ch_mult)) + block_in = ch + + for i_level in range(num_resolutions): + block = torch.nn.ModuleList() + attn = torch.nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + + for _ in range(num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type)) + + down = torch.nn.Module() + down.block = block + down.attn = attn + if i_level != num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res // 2 + down_modules.append(down) + + return down_modules, block_in + + +class Upsample(torch.nn.Module): + def __init__( + self, + in_channels: int, + with_conv: bool, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n. + # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2]. + # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2], + # So the output elements rely on the following windows: + # 0: [-,-,0] + # 1: [-,0,0] + # 2: [0,0,1] + # 3: [0,1,1] + # 4: [1,1,2] + # 5: [1,2,2] + # Notice that the first and second elements in the output rely only on the first element in the input, + # while all other elements rely on two elements in the input. + # So we can drop the first element to undo the padding (rather than the last element). + # This is a no-op for non-causal convolutions. + match self.causality_axis: + case CausalityAxis.NONE: + pass # x remains unchanged + case CausalityAxis.HEIGHT: + x = x[:, :, 1:, :] + case CausalityAxis.WIDTH: + x = x[:, :, :, 1:] + case CausalityAxis.WIDTH_COMPATIBILITY: + pass # x remains unchanged + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +def build_upsampling_path( # noqa: PLR0913 + *, + ch: int, + ch_mult: Tuple[int, ...], + num_resolutions: int, + num_res_blocks: int, + resolution: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + attn_resolutions: Set[int], + resamp_with_conv: bool, + initial_block_channels: int, +) -> tuple[torch.nn.ModuleList, int]: + """Build the upsampling path with residual blocks, attention, and upsampling layers.""" + up_modules = torch.nn.ModuleList() + block_in = initial_block_channels + curr_res = resolution // (2 ** (num_resolutions - 1)) + + for level in reversed(range(num_resolutions)): + stage = torch.nn.Module() + stage.block = torch.nn.ModuleList() + stage.attn = torch.nn.ModuleList() + block_out = ch * ch_mult[level] + + for _ in range(num_res_blocks + 1): + stage.block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + stage.attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type)) + + if level != 0: + stage.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res *= 2 + + up_modules.insert(0, stage) + + return up_modules, block_in + + +class PerChannelStatistics(nn.Module): + """ + Per-channel statistics for normalizing and denormalizing the latent representation. + This statics is computed over the entire dataset and stored in model's checkpoint under AudioVAE state_dict. + """ + + def __init__(self, latent_channels: int = 128) -> None: + super().__init__() + self.register_buffer("std-of-means", torch.empty(latent_channels)) + self.register_buffer("mean-of-means", torch.empty(latent_channels)) + + def un_normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x) + + +LATENT_DOWNSAMPLE_FACTOR = 4 + + +def build_mid_block( + channels: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + add_attention: bool, +) -> torch.nn.Module: + """Build the middle block with two ResNet blocks and optional attention.""" + mid = torch.nn.Module() + mid.block_1 = ResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else torch.nn.Identity() + mid.block_2 = ResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + return mid + + +def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor: + """Run features through the middle block.""" + features = mid.block_1(features, temb=None) + features = mid.attn_1(features) + return mid.block_2(features, temb=None) + + +class LTX2AudioEncoder(torch.nn.Module): + """ + Encoder that compresses audio spectrograms into latent representations. + The encoder uses a series of downsampling blocks with residual connections, + attention mechanisms, and configurable causal convolutions. + """ + + def __init__( # noqa: PLR0913 + self, + *, + ch: int = 128, + ch_mult: Tuple[int, ...] = (1, 2, 4), + num_res_blocks: int = 2, + attn_resolutions: Set[int] = set(), + dropout: float = 0.0, + resamp_with_conv: bool = True, + in_channels: int = 2, + resolution: int = 256, + z_channels: int = 8, + double_z: bool = True, + attn_type: AttentionType = AttentionType.VANILLA, + mid_block_add_attention: bool = False, + norm_type: NormType = NormType.PIXEL, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + sample_rate: int = 16000, + mel_hop_length: int = 160, + n_fft: int = 1024, + is_causal: bool = True, + mel_bins: int = 64, + **_ignore_kwargs, + ) -> None: + """ + Initialize the Encoder. + Args: + Arguments are configuration parameters, loaded from the audio VAE checkpoint config + (audio_vae.model.params.ddconfig): + ch: Base number of feature channels used in the first convolution layer. + ch_mult: Multiplicative factors for the number of channels at each resolution level. + num_res_blocks: Number of residual blocks to use at each resolution level. + attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention. + resolution: Input spatial resolution of the spectrogram (height, width). + z_channels: Number of channels in the latent representation. + norm_type: Normalization layer type to use within the network (e.g., group, batch). + causality_axis: Axis along which convolutions should be causal (e.g., time axis). + sample_rate: Audio sample rate in Hz for the input signals. + mel_hop_length: Hop length used when computing the mel spectrogram. + n_fft: FFT size used to compute the spectrogram. + mel_bins: Number of mel-frequency bins in the input spectrogram. + in_channels: Number of channels in the input spectrogram tensor. + double_z: If True, predict both mean and log-variance (doubling latent channels). + is_causal: If True, use causal convolutions suitable for streaming setups. + dropout: Dropout probability used in residual and mid blocks. + attn_type: Type of attention mechanism to use in attention blocks. + resamp_with_conv: If True, perform resolution changes using strided convolutions. + mid_block_add_attention: If True, add an attention block in the mid-level of the encoder. + """ + super().__init__() + + self.per_channel_statistics = PerChannelStatistics(latent_channels=ch) + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.n_fft = n_fft + self.is_causal = is_causal + self.mel_bins = mel_bins + + self.patchifier = AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.z_channels = z_channels + self.double_z = double_z + self.norm_type = norm_type + self.causality_axis = causality_axis + self.attn_type = attn_type + + # downsampling + self.conv_in = make_conv2d( + in_channels, + self.ch, + kernel_size=3, + stride=1, + causality_axis=self.causality_axis, + ) + + self.non_linearity = torch.nn.SiLU() + + self.down, block_in = build_downsampling_path( + ch=ch, + ch_mult=ch_mult, + num_resolutions=self.num_resolutions, + num_res_blocks=num_res_blocks, + resolution=resolution, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + attn_resolutions=attn_resolutions, + resamp_with_conv=resamp_with_conv, + ) + + self.mid = build_mid_block( + channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + add_attention=mid_block_add_attention, + ) + + self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type) + self.conv_out = make_conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + causality_axis=self.causality_axis, + ) + + def forward(self, spectrogram: torch.Tensor) -> torch.Tensor: + """ + Encode audio spectrogram into latent representations. + Args: + spectrogram: Input spectrogram of shape (batch, channels, time, frequency) + Returns: + Encoded latent representation of shape (batch, channels, frames, mel_bins) + """ + h = self.conv_in(spectrogram) + h = self._run_downsampling_path(h) + h = run_mid_block(self.mid, h) + h = self._finalize_output(h) + + return self._normalize_latents(h) + + def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor: + for level in range(self.num_resolutions): + stage = self.down[level] + for block_idx in range(self.num_res_blocks): + h = stage.block[block_idx](h, temb=None) + if stage.attn: + h = stage.attn[block_idx](h) + + if level != self.num_resolutions - 1: + h = stage.downsample(h) + + return h + + def _finalize_output(self, h: torch.Tensor) -> torch.Tensor: + h = self.norm_out(h) + h = self.non_linearity(h) + return self.conv_out(h) + + def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor: + """ + Normalize encoder latents using per-channel statistics. + When the encoder is configured with ``double_z=True``, the final + convolution produces twice the number of latent channels, typically + interpreted as two concatenated tensors along the channel dimension + (e.g., mean and variance or other auxiliary parameters). + This method intentionally uses only the first half of the channels + (the "mean" component) as input to the patchifier and normalization + logic. The remaining channels are left unchanged by this method and + are expected to be consumed elsewhere in the VAE pipeline. + If ``double_z=False``, the encoder output already contains only the + mean latents and the chunking operation simply returns that tensor. + """ + means = torch.chunk(latent_output, 2, dim=1)[0] + latent_shape = AudioLatentShape( + batch=means.shape[0], + channels=means.shape[1], + frames=means.shape[2], + mel_bins=means.shape[3], + ) + latent_patched = self.patchifier.patchify(means) + latent_normalized = self.per_channel_statistics.normalize(latent_patched) + return self.patchifier.unpatchify(latent_normalized, latent_shape) + + +class LTX2AudioDecoder(torch.nn.Module): + """ + Symmetric decoder that reconstructs audio spectrograms from latent features. + The decoder mirrors the encoder structure with configurable channel multipliers, + attention resolutions, and causal convolutions. + """ + + def __init__( # noqa: PLR0913 + self, + *, + ch: int = 128, + out_ch: int = 2, + ch_mult: Tuple[int, ...] = (1, 2, 4), + num_res_blocks: int = 2, + attn_resolutions: Set[int] = set(), + resolution: int=256, + z_channels: int=8, + norm_type: NormType = NormType.PIXEL, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: int | None = 64, + ) -> None: + """ + Initialize the Decoder. + Args: + Arguments are configuration parameters, loaded from the audio VAE checkpoint config + (audio_vae.model.params.ddconfig): + - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions + - resolution, z_channels + - norm_type, causality_axis + """ + super().__init__() + + # Internal behavioural defaults that are not driven by the checkpoint. + resamp_with_conv = True + attn_type = AttentionType.VANILLA + + # Per-channel statistics for denormalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=ch) + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + self.patchifier = AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.out_ch = out_ch + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.z_channels = z_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = causality_axis + self.attn_type = attn_type + + base_block_channels = ch * self.channel_multipliers[-1] + base_resolution = resolution // (2 ** (self.num_resolutions - 1)) + self.z_shape = (1, z_channels, base_resolution, base_resolution) + + self.conv_in = make_conv2d( + z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + self.non_linearity = torch.nn.SiLU() + self.mid = build_mid_block( + channels=base_block_channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + add_attention=mid_block_add_attention, + ) + self.up, final_block_channels = build_upsampling_path( + ch=ch, + ch_mult=ch_mult, + num_resolutions=self.num_resolutions, + num_res_blocks=num_res_blocks, + resolution=resolution, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + attn_resolutions=attn_resolutions, + resamp_with_conv=resamp_with_conv, + initial_block_channels=base_block_channels, + ) + + self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type) + self.conv_out = make_conv2d( + final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + """ + Decode latent features back to audio spectrograms. + Args: + sample: Encoded latent representation of shape (batch, channels, frames, mel_bins) + Returns: + Reconstructed audio spectrogram of shape (batch, channels, time, frequency) + """ + sample, target_shape = self._denormalize_latents(sample) + + h = self.conv_in(sample) + h = run_mid_block(self.mid, h) + h = self._run_upsampling_path(h) + h = self._finalize_output(h) + + return self._adjust_output_shape(h, target_shape) + + def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]: + latent_shape = AudioLatentShape( + batch=sample.shape[0], + channels=sample.shape[1], + frames=sample.shape[2], + mel_bins=sample.shape[3], + ) + + sample_patched = self.patchifier.patchify(sample) + sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched) + sample = self.patchifier.unpatchify(sample_denormalized, latent_shape) + + target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR + if self.causality_axis != CausalityAxis.NONE: + target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) + + target_shape = AudioLatentShape( + batch=latent_shape.batch, + channels=self.out_ch, + frames=target_frames, + mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins, + ) + + return sample, target_shape + + def _adjust_output_shape( + self, + decoded_output: torch.Tensor, + target_shape: AudioLatentShape, + ) -> torch.Tensor: + """ + Adjust output shape to match target dimensions for variable-length audio. + This function handles the common case where decoded audio spectrograms need to be + resized to match a specific target shape. + Args: + decoded_output: Tensor of shape (batch, channels, time, frequency) + target_shape: AudioLatentShape describing (batch, channels, time, mel bins) + Returns: + Tensor adjusted to match target_shape exactly + """ + # Current output shape: (batch, channels, time, frequency) + _, _, current_time, current_freq = decoded_output.shape + target_channels = target_shape.channels + target_time = target_shape.frames + target_freq = target_shape.mel_bins + + # Step 1: Crop first to avoid exceeding target dimensions + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + # Step 2: Calculate padding needed for time and frequency dimensions + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + # Step 3: Apply padding if needed + if time_padding_needed > 0 or freq_padding_needed > 0: + # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom) + # For audio: pad_left/right = frequency, pad_top/bottom = time + padding = ( + 0, + max(freq_padding_needed, 0), # frequency padding (left, right) + 0, + max(time_padding_needed, 0), # time padding (top, bottom) + ) + decoded_output = F.pad(decoded_output, padding) + + # Step 4: Final safety crop to ensure exact target shape + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor: + for level in reversed(range(self.num_resolutions)): + stage = self.up[level] + for block_idx, block in enumerate(stage.block): + h = block(h, temb=None) + if stage.attn: + h = stage.attn[block_idx](h) + + if level != 0 and hasattr(stage, "upsample"): + h = stage.upsample(h) + + return h + + def _finalize_output(self, h: torch.Tensor) -> torch.Tensor: + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = self.non_linearity(h) + h = self.conv_out(h) + return torch.tanh(h) if self.tanh_out else h + + +class LTX2Vocoder(torch.nn.Module): + """ + Vocoder model for synthesizing audio from Mel spectrograms. + Args: + resblock_kernel_sizes: List of kernel sizes for the residual blocks. + This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`. + upsample_rates: List of upsampling rates. + This value is read from the checkpoint at `config.vocoder.upsample_rates`. + upsample_kernel_sizes: List of kernel sizes for the upsampling layers. + This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`. + resblock_dilation_sizes: List of dilation sizes for the residual blocks. + This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`. + upsample_initial_channel: Initial number of channels for the upsampling layers. + This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`. + stereo: Whether to use stereo output. + This value is read from the checkpoint at `config.vocoder.stereo`. + resblock: Type of residual block to use. + This value is read from the checkpoint at `config.vocoder.resblock`. + output_sample_rate: Waveform sample rate. + This value is read from the checkpoint at `config.vocoder.output_sample_rate`. + """ + + def __init__( + self, + resblock_kernel_sizes: List[int] | None = [3, 7, 11], + upsample_rates: List[int] | None = [6, 5, 2, 2, 2], + upsample_kernel_sizes: List[int] | None = [16, 15, 8, 4, 4], + resblock_dilation_sizes: List[List[int]] | None = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_initial_channel: int = 1024, + stereo: bool = True, + resblock: str = "1", + output_sample_rate: int = 24000, + ): + super().__init__() + + # Initialize default values if not provided. Note that mutable default values are not supported. + if resblock_kernel_sizes is None: + resblock_kernel_sizes = [3, 7, 11] + if upsample_rates is None: + upsample_rates = [6, 5, 2, 2, 2] + if upsample_kernel_sizes is None: + upsample_kernel_sizes = [16, 15, 8, 4, 4] + if resblock_dilation_sizes is None: + resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + + self.output_sample_rate = output_sample_rate + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + in_channels = 128 if stereo else 64 + self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) + resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True)): + self.ups.append( + nn.ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + kernel_size, + stride, + padding=(kernel_size - stride) // 2, + ) + ) + + self.resblocks = nn.ModuleList() + for i, _ in enumerate(self.ups): + ch = upsample_initial_channel // (2 ** (i + 1)) + for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True): + self.resblocks.append(resblock_class(ch, kernel_size, dilations)) + + out_channels = 2 if stereo else 1 + final_channels = upsample_initial_channel // (2**self.num_upsamples) + self.conv_post = nn.Conv1d(final_channels, out_channels, 7, 1, padding=3) + + self.upsample_factor = math.prod(layer.stride[0] for layer in self.ups) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the vocoder. + Args: + x: Input Mel spectrogram tensor. Can be either: + - 3D: (batch_size, time, mel_bins) for mono + - 4D: (batch_size, 2, time, mel_bins) for stereo + Returns: + Audio waveform tensor of shape (batch_size, out_channels, audio_length) + """ + x = x.transpose(2, 3) # (batch, channels, time, mel_bins) -> (batch, channels, mel_bins, time) + + if x.dim() == 4: # stereo + assert x.shape[1] == 2, "Input must have 2 channels for stereo" + x = einops.rearrange(x, "b s c t -> b (s c) t") + + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + start = i * self.num_kernels + end = start + self.num_kernels + + # Evaluate all resblocks with the same input tensor so they can run + # independently (and thus in parallel on accelerator hardware) before + # aggregating their outputs via mean. + block_outputs = torch.stack( + [self.resblocks[idx](x) for idx in range(start, end)], + dim=0, + ) + + x = block_outputs.mean(dim=0) + + x = self.conv_post(F.leaky_relu(x)) + return torch.tanh(x) + + +def decode_audio(latent: torch.Tensor, audio_decoder: "LTX2AudioDecoder", vocoder: "LTX2Vocoder") -> torch.Tensor: + """ + Decode an audio latent representation using the provided audio decoder and vocoder. + Args: + latent: Input audio latent tensor. + audio_decoder: Model to decode the latent to waveform features. + vocoder: Model to convert decoded features to audio waveform. + Returns: + Decoded audio as a float tensor. + """ + decoded_audio = audio_decoder(latent) + decoded_audio = vocoder(decoded_audio).squeeze(0).float() + return decoded_audio diff --git a/diffsynth/models/ltx2_common.py b/diffsynth/models/ltx2_common.py new file mode 100644 index 0000000..a06ccd6 --- /dev/null +++ b/diffsynth/models/ltx2_common.py @@ -0,0 +1,371 @@ +from dataclasses import dataclass +from typing import NamedTuple, Protocol, Tuple +import torch +from torch import nn +from enum import Enum + + +class VideoPixelShape(NamedTuple): + """ + Shape of the tensor representing the video pixel array. Assumes BGR channel format. + """ + + batch: int + frames: int + height: int + width: int + fps: float + + +class SpatioTemporalScaleFactors(NamedTuple): + """ + Describes the spatiotemporal downscaling between decoded video space and + the corresponding VAE latent grid. + """ + + time: int + width: int + height: int + + @classmethod + def default(cls) -> "SpatioTemporalScaleFactors": + return cls(time=8, width=32, height=32) + + +VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default() + + +class VideoLatentShape(NamedTuple): + """ + Shape of the tensor representing video in VAE latent space. + The latent representation is a 5D tensor with dimensions ordered as + (batch, channels, frames, height, width). Spatial and temporal dimensions + are downscaled relative to pixel space according to the VAE's scale factors. + """ + + batch: int + channels: int + frames: int + height: int + width: int + + def to_torch_shape(self) -> torch.Size: + return torch.Size([self.batch, self.channels, self.frames, self.height, self.width]) + + @staticmethod + def from_torch_shape(shape: torch.Size) -> "VideoLatentShape": + return VideoLatentShape( + batch=shape[0], + channels=shape[1], + frames=shape[2], + height=shape[3], + width=shape[4], + ) + + def mask_shape(self) -> "VideoLatentShape": + return self._replace(channels=1) + + @staticmethod + def from_pixel_shape( + shape: VideoPixelShape, + latent_channels: int = 128, + scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS, + ) -> "VideoLatentShape": + frames = (shape.frames - 1) // scale_factors[0] + 1 + height = shape.height // scale_factors[1] + width = shape.width // scale_factors[2] + + return VideoLatentShape( + batch=shape.batch, + channels=latent_channels, + frames=frames, + height=height, + width=width, + ) + + def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape": + return self._replace( + channels=3, + frames=(self.frames - 1) * scale_factors.time + 1, + height=self.height * scale_factors.height, + width=self.width * scale_factors.width, + ) + + +class AudioLatentShape(NamedTuple): + """ + Shape of audio in VAE latent space: (batch, channels, frames, mel_bins). + mel_bins is the number of frequency bins from the mel-spectrogram encoding. + """ + + batch: int + channels: int + frames: int + mel_bins: int + + def to_torch_shape(self) -> torch.Size: + return torch.Size([self.batch, self.channels, self.frames, self.mel_bins]) + + def mask_shape(self) -> "AudioLatentShape": + return self._replace(channels=1, mel_bins=1) + + @staticmethod + def from_torch_shape(shape: torch.Size) -> "AudioLatentShape": + return AudioLatentShape( + batch=shape[0], + channels=shape[1], + frames=shape[2], + mel_bins=shape[3], + ) + + @staticmethod + def from_duration( + batch: int, + duration: float, + channels: int = 8, + mel_bins: int = 16, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + ) -> "AudioLatentShape": + latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor) + + return AudioLatentShape( + batch=batch, + channels=channels, + frames=round(duration * latents_per_second), + mel_bins=mel_bins, + ) + + @staticmethod + def from_video_pixel_shape( + shape: VideoPixelShape, + channels: int = 8, + mel_bins: int = 16, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + ) -> "AudioLatentShape": + return AudioLatentShape.from_duration( + batch=shape.batch, + duration=float(shape.frames) / float(shape.fps), + channels=channels, + mel_bins=mel_bins, + sample_rate=sample_rate, + hop_length=hop_length, + audio_latent_downsample_factor=audio_latent_downsample_factor, + ) + + +@dataclass(frozen=True) +class LatentState: + """ + State of latents during the diffusion denoising process. + Attributes: + latent: The current noisy latent tensor being denoised. + denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising). + positions: Positional indices for each latent element, used for positional embeddings. + clean_latent: Initial state of the latent before denoising, may include conditioning latents. + """ + + latent: torch.Tensor + denoise_mask: torch.Tensor + positions: torch.Tensor + clean_latent: torch.Tensor + + def clone(self) -> "LatentState": + return LatentState( + latent=self.latent.clone(), + denoise_mask=self.denoise_mask.clone(), + positions=self.positions.clone(), + clean_latent=self.clean_latent.clone(), + ) + + +class NormType(Enum): + """Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm).""" + + GROUP = "group" + PIXEL = "pixel" + + +class PixelNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + For each element along the chosen dimension, this layer normalizes the tensor + by the root-mean-square of its values across that dimension: + y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps) + """ + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + """ + Args: + dim: Dimension along which to compute the RMS (typically channels). + eps: Small constant added for numerical stability. + """ + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply RMS normalization along the configured dimension. + """ + # Compute mean of squared values along `dim`, keep dimensions for broadcasting. + mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True) + # Normalize by the root-mean-square (RMS). + rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +def build_normalization_layer( + in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP +) -> nn.Module: + """ + Create a normalization layer based on the normalization type. + Args: + in_channels: Number of input channels + num_groups: Number of groups for group normalization + normtype: Type of normalization: "group" or "pixel" + Returns: + A normalization layer + """ + if normtype == NormType.GROUP: + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if normtype == NormType.PIXEL: + return PixelNorm(dim=1, eps=1e-6) + raise ValueError(f"Invalid normalization type: {normtype}") + + +def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor: + """Root-mean-square (RMS) normalize `x` over its last dimension. + Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized + shape and forwards `weight` and `eps`. + """ + return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps) + + +@dataclass(frozen=True) +class Modality: + """ + Input data for a single modality (video or audio) in the transformer. + Bundles the latent tokens, timestep embeddings, positional information, + and text conditioning context for processing by the diffusion transformer. + """ + + latent: ( + torch.Tensor + ) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension + timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps + positions: ( + torch.Tensor + ) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens + context: torch.Tensor + enabled: bool = True + context_mask: torch.Tensor | None = None + + +def to_denoised( + sample: torch.Tensor, + velocity: torch.Tensor, + sigma: float | torch.Tensor, + calc_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Convert the sample and its denoising velocity to denoised sample. + Returns: + Denoised sample + """ + if isinstance(sigma, torch.Tensor): + sigma = sigma.to(calc_dtype) + return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype) + + + +class Patchifier(Protocol): + """ + Protocol for patchifiers that convert latent tensors into patches and assemble them back. + """ + + def patchify( + self, + latents: torch.Tensor, + ) -> torch.Tensor: + ... + """ + Convert latent tensors into flattened patch tokens. + Args: + latents: Latent tensor to patchify. + Returns: + Flattened patch tokens tensor. + """ + + def unpatchify( + self, + latents: torch.Tensor, + output_shape: AudioLatentShape | VideoLatentShape, + ) -> torch.Tensor: + """ + Converts latent tensors between spatio-temporal formats and flattened sequence representations. + Args: + latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`. + output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or + VideoLatentShape. + Returns: + Dense latent tensor restored from the flattened representation. + """ + + @property + def patch_size(self) -> Tuple[int, int, int]: + ... + """ + Returns the patch size as a tuple of (temporal, height, width) dimensions + """ + + def get_patch_grid_bounds( + self, + output_shape: AudioLatentShape | VideoLatentShape, + device: torch.device | None = None, + ) -> torch.Tensor: + ... + """ + Compute metadata describing where each latent patch resides within the + grid specified by `output_shape`. + Args: + output_shape: Target grid layout for the patches. + device: Target device for the returned tensor. + Returns: + Tensor containing patch coordinate metadata such as spatial or temporal intervals. + """ + + +def get_pixel_coords( + latent_coords: torch.Tensor, + scale_factors: SpatioTemporalScaleFactors, + causal_fix: bool = False, +) -> torch.Tensor: + """ + Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling + each axis (frame/time, height, width) with the corresponding VAE downsampling factors. + Optionally compensate for causal encoding that keeps the first frame at unit temporal scale. + Args: + latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`. + scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied + per axis. + causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs + that treat frame zero differently still yield non-negative timestamps. + """ + # Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout. + broadcast_shape = [1] * latent_coords.ndim + broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width) + scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape) + + # Apply per-axis scaling to convert latent bounds into pixel-space coordinates. + pixel_coords = latent_coords * scale_tensor + + if causal_fix: + # VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`. + # Shift and clamp to keep the first-frame timestamps causal and non-negative. + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0) + + return pixel_coords diff --git a/diffsynth/models/ltx2_dit.py b/diffsynth/models/ltx2_dit.py new file mode 100644 index 0000000..2e3c958 --- /dev/null +++ b/diffsynth/models/ltx2_dit.py @@ -0,0 +1,1451 @@ +import math +import functools +from dataclasses import dataclass, replace +from enum import Enum +from typing import Optional, Tuple, Callable +import numpy as np +import torch +from einops import rearrange +from .ltx2_common import rms_norm, Modality +from ..core.attention.attention import attention_forward + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(torch.nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + out_dim: int | None = None, + post_act_fn: str | None = None, + cond_proj_dim: int | None = None, + sample_proj_bias: bool = True, + ): + super().__init__() + + self.linear_1 = torch.nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = torch.nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = torch.nn.SiLU() + time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim + + self.linear_2 = torch.nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + + def forward(self, sample: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor: + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(torch.nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class PixArtAlphaCombinedTimestepSizeEmbeddings(torch.nn.Module): + """ + For PixArt-Alpha. + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__( + self, + embedding_dim: int, + size_emb_dim: int, + ): + super().__init__() + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward( + self, + timestep: torch.Tensor, + hidden_dtype: torch.dtype, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + return timesteps_emb + + +class PerturbationType(Enum): + """Types of attention perturbations for STG (Spatio-Temporal Guidance).""" + + SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn" + SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn" + SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn" + SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn" + + +@dataclass(frozen=True) +class Perturbation: + """A single perturbation specifying which attention type to skip and in which blocks.""" + + type: PerturbationType + blocks: list[int] | None # None means all blocks + + def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool: + if self.type != perturbation_type: + return False + + if self.blocks is None: + return True + + return block in self.blocks + + +@dataclass(frozen=True) +class PerturbationConfig: + """Configuration holding a list of perturbations for a single sample.""" + + perturbations: list[Perturbation] | None + + def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool: + if self.perturbations is None: + return False + + return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + @staticmethod + def empty() -> "PerturbationConfig": + return PerturbationConfig([]) + + +@dataclass(frozen=True) +class BatchedPerturbationConfig: + """Perturbation configurations for a batch, with utilities for generating attention masks.""" + + perturbations: list[PerturbationConfig] + + def mask( + self, perturbation_type: PerturbationType, block: int, device, dtype: torch.dtype + ) -> torch.Tensor: + mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype) + for batch_idx, perturbation in enumerate(self.perturbations): + if perturbation.is_perturbed(perturbation_type, block): + mask[batch_idx] = 0 + + return mask + + def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor: + mask = self.mask(perturbation_type, block, values.device, values.dtype) + return mask.view(mask.numel(), *([1] * len(values.shape[1:]))) + + def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool: + return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool: + return all(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + @staticmethod + def empty(batch_size: int) -> "BatchedPerturbationConfig": + return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)]) + + +class AdaLayerNormSingle(torch.nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, embedding_coefficient: int = 6): + super().__init__() + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, + size_emb_dim=embedding_dim // 3, + ) + + self.silu = torch.nn.SiLU() + self.linear = torch.nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +class LTXRopeType(Enum): + INTERLEAVED = "interleaved" + SPLIT = "split" + + +def apply_rotary_emb( + input_tensor: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, +) -> torch.Tensor: + if rope_type == LTXRopeType.INTERLEAVED: + return apply_interleaved_rotary_emb(input_tensor, *freqs_cis) + elif rope_type == LTXRopeType.SPLIT: + return apply_split_rotary_emb(input_tensor, *freqs_cis) + else: + raise ValueError(f"Invalid rope type: {rope_type}") + + + +def apply_interleaved_rotary_emb( + input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor +) -> torch.Tensor: + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + + +def apply_split_rotary_emb( + input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor +) -> torch.Tensor: + needs_reshape = False + if input_tensor.ndim != 4 and cos_freqs.ndim == 4: + b, h, t, _ = cos_freqs.shape + input_tensor = input_tensor.reshape(b, t, h, -1).swapaxes(1, 2) + needs_reshape = True + + split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2) + first_half_input = split_input[..., :1, :] + second_half_input = split_input[..., 1:, :] + + output = split_input * cos_freqs.unsqueeze(-2) + first_half_output = output[..., :1, :] + second_half_output = output[..., 1:, :] + + first_half_output.addcmul_(-sin_freqs.unsqueeze(-2), second_half_input) + second_half_output.addcmul_(sin_freqs.unsqueeze(-2), first_half_input) + + output = rearrange(output, "... d r -> ... (d r)") + if needs_reshape: + output = output.swapaxes(1, 2).reshape(b, t, -1) + + return output + + +@functools.lru_cache(maxsize=5) +def generate_freq_grid_np( + positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int +) -> torch.Tensor: + theta = positional_embedding_theta + start = 1 + end = theta + + n_elem = 2 * positional_embedding_max_pos_count + pow_indices = np.power( + theta, + np.linspace( + np.log(start) / np.log(theta), + np.log(end) / np.log(theta), + inner_dim // n_elem, + dtype=np.float64, + ), + ) + return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32) + + +@functools.lru_cache(maxsize=5) +def generate_freq_grid_pytorch( + positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int +) -> torch.Tensor: + theta = positional_embedding_theta + start = 1 + end = theta + n_elem = 2 * positional_embedding_max_pos_count + + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + inner_dim // n_elem, + dtype=torch.float32, + ) + ) + indices = indices.to(dtype=torch.float32) + + indices = indices * math.pi / 2 + + return indices + + +def get_fractional_positions(indices_grid: torch.Tensor, max_pos: list[int]) -> torch.Tensor: + n_pos_dims = indices_grid.shape[1] + assert n_pos_dims == len(max_pos), ( + f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})" + ) + fractional_positions = torch.stack( + [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)], + dim=-1, + ) + return fractional_positions + + +def generate_freqs( + indices: torch.Tensor, indices_grid: torch.Tensor, max_pos: list[int], use_middle_indices_grid: bool +) -> torch.Tensor: + if use_middle_indices_grid: + assert len(indices_grid.shape) == 4 + assert indices_grid.shape[-1] == 2 + indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1] + indices_grid = (indices_grid_start + indices_grid_end) / 2.0 + elif len(indices_grid.shape) == 4: + indices_grid = indices_grid[..., 0] + + fractional_positions = get_fractional_positions(indices_grid, max_pos) + indices = indices.to(device=fractional_positions.device) + + freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) + return freqs + + +def split_freqs_cis(freqs: torch.Tensor, pad_size: int, num_attention_heads: int) -> tuple[torch.Tensor, torch.Tensor]: + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + # Reshape freqs to be compatible with multi-head attention + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1) + + cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + return cos_freq, sin_freq + + +def interleaved_freqs_cis(freqs: torch.Tensor, pad_size: int) -> tuple[torch.Tensor, torch.Tensor]: + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(cos_freq[:, :, :pad_size]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq, sin_freq + + +def precompute_freqs_cis( + indices_grid: torch.Tensor, + dim: int, + out_dtype: torch.dtype, + theta: float = 10000.0, + max_pos: list[int] | None = None, + use_middle_indices_grid: bool = False, + num_attention_heads: int = 32, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + freq_grid_generator: Callable[[float, int, int, torch.device], torch.Tensor] = generate_freq_grid_pytorch, +) -> tuple[torch.Tensor, torch.Tensor]: + if max_pos is None: + max_pos = [20, 2048, 2048] + + indices = freq_grid_generator(theta, indices_grid.shape[1], dim) + freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) + + if rope_type == LTXRopeType.SPLIT: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads) + else: + # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only + n_elem = 2 * indices_grid.shape[1] + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(out_dtype), sin_freq.to(out_dtype) + + +class Attention(torch.nn.Module): + def __init__( + self, + query_dim: int, + context_dim: int | None = None, + heads: int = 8, + dim_head: int = 64, + norm_eps: float = 1e-6, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + ) -> None: + super().__init__() + self.rope_type = rope_type + + inner_dim = dim_head * heads + context_dim = query_dim if context_dim is None else context_dim + + self.heads = heads + self.dim_head = dim_head + + self.q_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) + self.k_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) + + self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True) + self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True) + self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True) + + self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity()) + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + pe: torch.Tensor | None = None, + k_pe: torch.Tensor | None = None, + ) -> torch.Tensor: + q = self.to_q(x) + context = x if context is None else context + k = self.to_k(context) + v = self.to_v(context) + + q = self.q_norm(q) + k = self.k_norm(k) + + if pe is not None: + q = apply_rotary_emb(q, pe, self.rope_type) + k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type) + + # Reshape for attention_forward using unflatten + q = q.unflatten(-1, (self.heads, self.dim_head)) + k = k.unflatten(-1, (self.heads, self.dim_head)) + v = v.unflatten(-1, (self.heads, self.dim_head)) + + out = attention_forward( + q=q, + k=k, + v=v, + q_pattern="b s n d", + k_pattern="b s n d", + v_pattern="b s n d", + out_pattern="b s n d", + attn_mask=mask + ) + + # Reshape back to original format + out = out.flatten(2, 3) + return self.to_out(out) + + +class PixArtAlphaTextProjection(torch.nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_features: int, hidden_size: int, out_features: int | None = None, act_fn: str = "gelu_tanh"): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = torch.nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + if act_fn == "gelu_tanh": + self.act_1 = torch.nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = torch.nn.SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = torch.nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) + + def forward(self, caption: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@dataclass(frozen=True) +class TransformerArgs: + x: torch.Tensor + context: torch.Tensor + context_mask: torch.Tensor + timesteps: torch.Tensor + embedded_timestep: torch.Tensor + positional_embeddings: torch.Tensor + cross_positional_embeddings: torch.Tensor | None + cross_scale_shift_timestep: torch.Tensor | None + cross_gate_timestep: torch.Tensor | None + enabled: bool + + + +class TransformerArgsPreprocessor: + def __init__( # noqa: PLR0913 + self, + patchify_proj: torch.nn.Linear, + adaln: AdaLayerNormSingle, + caption_projection: PixArtAlphaTextProjection, + inner_dim: int, + max_pos: list[int], + num_attention_heads: int, + use_middle_indices_grid: bool, + timestep_scale_multiplier: int, + double_precision_rope: bool, + positional_embedding_theta: float, + rope_type: LTXRopeType, + ) -> None: + self.patchify_proj = patchify_proj + self.adaln = adaln + self.caption_projection = caption_projection + self.inner_dim = inner_dim + self.max_pos = max_pos + self.num_attention_heads = num_attention_heads + self.use_middle_indices_grid = use_middle_indices_grid + self.timestep_scale_multiplier = timestep_scale_multiplier + self.double_precision_rope = double_precision_rope + self.positional_embedding_theta = positional_embedding_theta + self.rope_type = rope_type + + def _prepare_timestep( + self, timestep: torch.Tensor, batch_size: int, hidden_dtype: torch.dtype + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare timestep embeddings.""" + + timestep = timestep * self.timestep_scale_multiplier + timestep, embedded_timestep = self.adaln( + timestep.flatten(), + hidden_dtype=hidden_dtype, + ) + + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) + return timestep, embedded_timestep + + def _prepare_context( + self, + context: torch.Tensor, + x: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Prepare context for transformer blocks.""" + batch_size = x.shape[0] + context = self.caption_projection(context) + context = context.view(batch_size, -1, x.shape[-1]) + + return context, attention_mask + + def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None: + """Prepare attention mask.""" + if attention_mask is None or torch.is_floating_point(attention_mask): + return attention_mask + + return (attention_mask - 1).to(x_dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + ) * torch.finfo(x_dtype).max + + def _prepare_positional_embeddings( + self, + positions: torch.Tensor, + inner_dim: int, + max_pos: list[int], + use_middle_indices_grid: bool, + num_attention_heads: int, + x_dtype: torch.dtype, + ) -> torch.Tensor: + """Prepare positional embeddings.""" + freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch + pe = precompute_freqs_cis( + positions, + dim=inner_dim, + out_dtype=x_dtype, + theta=self.positional_embedding_theta, + max_pos=max_pos, + use_middle_indices_grid=use_middle_indices_grid, + num_attention_heads=num_attention_heads, + rope_type=self.rope_type, + freq_grid_generator=freq_grid_generator, + ) + return pe + + def prepare( + self, + modality: Modality, + ) -> TransformerArgs: + x = self.patchify_proj(modality.latent) + timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], modality.latent.dtype) + context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask) + attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype) + pe = self._prepare_positional_embeddings( + positions=modality.positions, + inner_dim=self.inner_dim, + max_pos=self.max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.num_attention_heads, + x_dtype=modality.latent.dtype, + ) + return TransformerArgs( + x=x, + context=context, + context_mask=attention_mask, + timesteps=timestep, + embedded_timestep=embedded_timestep, + positional_embeddings=pe, + cross_positional_embeddings=None, + cross_scale_shift_timestep=None, + cross_gate_timestep=None, + enabled=modality.enabled, + ) + + +class MultiModalTransformerArgsPreprocessor: + def __init__( # noqa: PLR0913 + self, + patchify_proj: torch.nn.Linear, + adaln: AdaLayerNormSingle, + caption_projection: PixArtAlphaTextProjection, + cross_scale_shift_adaln: AdaLayerNormSingle, + cross_gate_adaln: AdaLayerNormSingle, + inner_dim: int, + max_pos: list[int], + num_attention_heads: int, + cross_pe_max_pos: int, + use_middle_indices_grid: bool, + audio_cross_attention_dim: int, + timestep_scale_multiplier: int, + double_precision_rope: bool, + positional_embedding_theta: float, + rope_type: LTXRopeType, + av_ca_timestep_scale_multiplier: int, + ) -> None: + self.simple_preprocessor = TransformerArgsPreprocessor( + patchify_proj=patchify_proj, + adaln=adaln, + caption_projection=caption_projection, + inner_dim=inner_dim, + max_pos=max_pos, + num_attention_heads=num_attention_heads, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + double_precision_rope=double_precision_rope, + positional_embedding_theta=positional_embedding_theta, + rope_type=rope_type, + ) + self.cross_scale_shift_adaln = cross_scale_shift_adaln + self.cross_gate_adaln = cross_gate_adaln + self.cross_pe_max_pos = cross_pe_max_pos + self.audio_cross_attention_dim = audio_cross_attention_dim + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + + def prepare( + self, + modality: Modality, + ) -> TransformerArgs: + transformer_args = self.simple_preprocessor.prepare(modality) + cross_pe = self.simple_preprocessor._prepare_positional_embeddings( + positions=modality.positions[:, 0:1, :], + inner_dim=self.audio_cross_attention_dim, + max_pos=[self.cross_pe_max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.simple_preprocessor.num_attention_heads, + x_dtype=modality.latent.dtype, + ) + + cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep( + timestep=modality.timesteps, + timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, + batch_size=transformer_args.x.shape[0], + hidden_dtype=modality.latent.dtype, + ) + + return replace( + transformer_args, + cross_positional_embeddings=cross_pe, + cross_scale_shift_timestep=cross_scale_shift_timestep, + cross_gate_timestep=cross_gate_timestep, + ) + + def _prepare_cross_attention_timestep( + self, + timestep: torch.Tensor, + timestep_scale_multiplier: int, + batch_size: int, + hidden_dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare cross attention timestep embeddings.""" + timestep = timestep * timestep_scale_multiplier + + av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier + + scale_shift_timestep, _ = self.cross_scale_shift_adaln( + timestep.flatten(), + hidden_dtype=hidden_dtype, + ) + scale_shift_timestep = scale_shift_timestep.view(batch_size, -1, scale_shift_timestep.shape[-1]) + gate_noise_timestep, _ = self.cross_gate_adaln( + timestep.flatten() * av_ca_factor, + hidden_dtype=hidden_dtype, + ) + gate_noise_timestep = gate_noise_timestep.view(batch_size, -1, gate_noise_timestep.shape[-1]) + + return scale_shift_timestep, gate_noise_timestep + + +@dataclass +class TransformerConfig: + dim: int + heads: int + d_head: int + context_dim: int + + +class BasicAVTransformerBlock(torch.nn.Module): + def __init__( + self, + idx: int, + video: TransformerConfig | None = None, + audio: TransformerConfig | None = None, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + norm_eps: float = 1e-6, + ): + super().__init__() + + self.idx = idx + if video is not None: + self.attn1 = Attention( + query_dim=video.dim, + heads=video.heads, + dim_head=video.d_head, + context_dim=None, + rope_type=rope_type, + norm_eps=norm_eps, + ) + self.attn2 = Attention( + query_dim=video.dim, + context_dim=video.context_dim, + heads=video.heads, + dim_head=video.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + ) + self.ff = FeedForward(video.dim, dim_out=video.dim) + self.scale_shift_table = torch.nn.Parameter(torch.empty(6, video.dim)) + + if audio is not None: + self.audio_attn1 = Attention( + query_dim=audio.dim, + heads=audio.heads, + dim_head=audio.d_head, + context_dim=None, + rope_type=rope_type, + norm_eps=norm_eps, + ) + self.audio_attn2 = Attention( + query_dim=audio.dim, + context_dim=audio.context_dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + ) + self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim) + self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(6, audio.dim)) + + if audio is not None and video is not None: + # Q: Video, K,V: Audio + self.audio_to_video_attn = Attention( + query_dim=video.dim, + context_dim=audio.dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + ) + + # Q: Audio, K,V: Video + self.video_to_audio_attn = Attention( + query_dim=audio.dim, + context_dim=video.dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + ) + + self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim)) + self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim)) + + self.norm_eps = norm_eps + + def get_ada_values( + self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice + ) -> tuple[torch.Tensor, ...]: + num_ada_params = scale_shift_table.shape[0] + + ada_values = ( + scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype) + + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :] + ).unbind(dim=2) + return ada_values + + def get_av_ca_ada_values( + self, + scale_shift_table: torch.Tensor, + batch_size: int, + scale_shift_timestep: torch.Tensor, + gate_timestep: torch.Tensor, + num_scale_shift_values: int = 4, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + scale_shift_ada_values = self.get_ada_values( + scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, slice(None, None) + ) + gate_ada_values = self.get_ada_values( + scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None) + ) + + scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values] + gate_ada_values = [t.squeeze(2) for t in gate_ada_values] + + return (*scale_shift_chunks, *gate_ada_values) + + def forward( # noqa: PLR0915 + self, + video: TransformerArgs | None, + audio: TransformerArgs | None, + perturbations: BatchedPerturbationConfig | None = None, + ) -> tuple[TransformerArgs | None, TransformerArgs | None]: + batch_size = video.x.shape[0] + if perturbations is None: + perturbations = BatchedPerturbationConfig.empty(batch_size) + + vx = video.x if video is not None else None + ax = audio.x if audio is not None else None + + run_vx = video is not None and video.enabled and vx.numel() > 0 + run_ax = audio is not None and audio.enabled and ax.numel() > 0 + + run_a2v = run_vx and (audio is not None and ax.numel() > 0) + run_v2a = run_ax and (video is not None and vx.numel() > 0) + + if run_vx: + vshift_msa, vscale_msa, vgate_msa = self.get_ada_values( + self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3) + ) + if not perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx): + norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa + v_mask = perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx) + vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa * v_mask + + vx = vx + self.attn2(rms_norm(vx, eps=self.norm_eps), context=video.context, mask=video.context_mask) + + del vshift_msa, vscale_msa, vgate_msa + + if run_ax: + ashift_msa, ascale_msa, agate_msa = self.get_ada_values( + self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3) + ) + + if not perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx): + norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa + a_mask = perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax) + ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa * a_mask + + ax = ax + self.audio_attn2(rms_norm(ax, eps=self.norm_eps), context=audio.context, mask=audio.context_mask) + + del ashift_msa, ascale_msa, agate_msa + + # Audio - Video cross attention. + if run_a2v or run_v2a: + vx_norm3 = rms_norm(vx, eps=self.norm_eps) + ax_norm3 = rms_norm(ax, eps=self.norm_eps) + + ( + scale_ca_audio_hidden_states_a2v, + shift_ca_audio_hidden_states_a2v, + scale_ca_audio_hidden_states_v2a, + shift_ca_audio_hidden_states_v2a, + gate_out_v2a, + ) = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_audio, + ax.shape[0], + audio.cross_scale_shift_timestep, + audio.cross_gate_timestep, + ) + + ( + scale_ca_video_hidden_states_a2v, + shift_ca_video_hidden_states_a2v, + scale_ca_video_hidden_states_v2a, + shift_ca_video_hidden_states_v2a, + gate_out_a2v, + ) = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_video, + vx.shape[0], + video.cross_scale_shift_timestep, + video.cross_gate_timestep, + ) + + if run_a2v: + vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) + shift_ca_video_hidden_states_a2v + ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v + a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx) + vx = vx + ( + self.audio_to_video_attn( + vx_scaled, + context=ax_scaled, + pe=video.cross_positional_embeddings, + k_pe=audio.cross_positional_embeddings, + ) + * gate_out_a2v + * a2v_mask + ) + + if run_v2a: + ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a + vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a + v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax) + ax = ax + ( + self.video_to_audio_attn( + ax_scaled, + context=vx_scaled, + pe=audio.cross_positional_embeddings, + k_pe=video.cross_positional_embeddings, + ) + * gate_out_v2a + * v2a_mask + ) + + del gate_out_a2v, gate_out_v2a + del ( + scale_ca_video_hidden_states_a2v, + shift_ca_video_hidden_states_a2v, + scale_ca_audio_hidden_states_a2v, + shift_ca_audio_hidden_states_a2v, + scale_ca_video_hidden_states_v2a, + shift_ca_video_hidden_states_v2a, + scale_ca_audio_hidden_states_v2a, + shift_ca_audio_hidden_states_v2a, + ) + + if run_vx: + vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values( + self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None) + ) + vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp + vx = vx + self.ff(vx_scaled) * vgate_mlp + + del vshift_mlp, vscale_mlp, vgate_mlp + + if run_ax: + ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values( + self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None) + ) + ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp + ax = ax + self.audio_ff(ax_scaled) * agate_mlp + + del ashift_mlp, ascale_mlp, agate_mlp + + return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None + + +class GELUApprox(torch.nn.Module): + def __init__(self, dim_in: int, dim_out: int) -> None: + super().__init__() + self.proj = torch.nn.Linear(dim_in, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(self.proj(x), approximate="tanh") + + +class FeedForward(torch.nn.Module): + def __init__(self, dim: int, dim_out: int, mult: int = 4) -> None: + super().__init__() + inner_dim = int(dim * mult) + project_in = GELUApprox(dim, inner_dim) + + self.net = torch.nn.Sequential(project_in, torch.nn.Identity(), torch.nn.Linear(inner_dim, dim_out)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class LTXModelType(Enum): + AudioVideo = "ltx av model" + VideoOnly = "ltx video only model" + AudioOnly = "ltx audio only model" + + def is_video_enabled(self) -> bool: + return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly) + + def is_audio_enabled(self) -> bool: + return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly) + + +class LTXModel(torch.nn.Module): + """ + LTX model transformer implementation. + This class implements the transformer blocks for the LTX model. + """ + + def __init__( # noqa: PLR0913 + self, + *, + model_type: LTXModelType = LTXModelType.AudioVideo, + num_attention_heads: int = 32, + attention_head_dim: int = 128, + in_channels: int = 128, + out_channels: int = 128, + num_layers: int = 48, + cross_attention_dim: int = 4096, + norm_eps: float = 1e-06, + caption_channels: int = 3840, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list[int] | None = [20, 2048, 2048], + timestep_scale_multiplier: int = 1000, + use_middle_indices_grid: bool = True, + audio_num_attention_heads: int = 32, + audio_attention_head_dim: int = 64, + audio_in_channels: int = 128, + audio_out_channels: int = 128, + audio_cross_attention_dim: int = 2048, + audio_positional_embedding_max_pos: list[int] | None = [20], + av_ca_timestep_scale_multiplier: int = 1000, + rope_type: LTXRopeType = LTXRopeType.SPLIT, + double_precision_rope: bool = True, + ): + super().__init__() + self._enable_gradient_checkpointing = False + self.use_middle_indices_grid = use_middle_indices_grid + self.rope_type = rope_type + self.double_precision_rope = double_precision_rope + self.timestep_scale_multiplier = timestep_scale_multiplier + self.positional_embedding_theta = positional_embedding_theta + self.model_type = model_type + cross_pe_max_pos = None + if model_type.is_video_enabled(): + if positional_embedding_max_pos is None: + positional_embedding_max_pos = [20, 2048, 2048] + self.positional_embedding_max_pos = positional_embedding_max_pos + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self._init_video( + in_channels=in_channels, + out_channels=out_channels, + caption_channels=caption_channels, + norm_eps=norm_eps, + ) + + if model_type.is_audio_enabled(): + if audio_positional_embedding_max_pos is None: + audio_positional_embedding_max_pos = [20] + self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + self.audio_num_attention_heads = audio_num_attention_heads + self.audio_inner_dim = self.audio_num_attention_heads * audio_attention_head_dim + self._init_audio( + in_channels=audio_in_channels, + out_channels=audio_out_channels, + caption_channels=caption_channels, + norm_eps=norm_eps, + ) + + if model_type.is_video_enabled() and model_type.is_audio_enabled(): + cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]) + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + self.audio_cross_attention_dim = audio_cross_attention_dim + self._init_audio_video(num_scale_shift_values=4) + + self._init_preprocessors(cross_pe_max_pos) + # Initialize transformer blocks + self._init_transformer_blocks( + num_layers=num_layers, + attention_head_dim=attention_head_dim if model_type.is_video_enabled() else 0, + cross_attention_dim=cross_attention_dim, + audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0, + audio_cross_attention_dim=audio_cross_attention_dim, + norm_eps=norm_eps, + ) + + def _init_video( + self, + in_channels: int, + out_channels: int, + caption_channels: int, + norm_eps: float, + ) -> None: + """Initialize video-specific components.""" + # Video input components + self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True) + + self.adaln_single = AdaLayerNormSingle(self.inner_dim) + + # Video caption projection + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, + hidden_size=self.inner_dim, + ) + + # Video output components + self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim)) + self.norm_out = torch.nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=norm_eps) + self.proj_out = torch.nn.Linear(self.inner_dim, out_channels) + + def _init_audio( + self, + in_channels: int, + out_channels: int, + caption_channels: int, + norm_eps: float, + ) -> None: + """Initialize audio-specific components.""" + + # Audio input components + self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True) + + self.audio_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + ) + + # Audio caption projection + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, + hidden_size=self.audio_inner_dim, + ) + + # Audio output components + self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim)) + self.audio_norm_out = torch.nn.LayerNorm(self.audio_inner_dim, elementwise_affine=False, eps=norm_eps) + self.audio_proj_out = torch.nn.Linear(self.audio_inner_dim, out_channels) + + def _init_audio_video( + self, + num_scale_shift_values: int, + ) -> None: + """Initialize audio-video cross-attention components.""" + self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=num_scale_shift_values, + ) + + self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=num_scale_shift_values, + ) + + self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=1, + ) + + self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=1, + ) + + def _init_preprocessors( + self, + cross_pe_max_pos: int | None = None, + ) -> None: + """Initialize preprocessors for LTX.""" + + if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled(): + self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor( + patchify_proj=self.patchify_proj, + adaln=self.adaln_single, + caption_projection=self.caption_projection, + cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single, + cross_gate_adaln=self.av_ca_a2v_gate_adaln_single, + inner_dim=self.inner_dim, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + cross_pe_max_pos=cross_pe_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + audio_cross_attention_dim=self.audio_cross_attention_dim, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier, + ) + self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor( + patchify_proj=self.audio_patchify_proj, + adaln=self.audio_adaln_single, + caption_projection=self.audio_caption_projection, + cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single, + cross_gate_adaln=self.av_ca_v2a_gate_adaln_single, + inner_dim=self.audio_inner_dim, + max_pos=self.audio_positional_embedding_max_pos, + num_attention_heads=self.audio_num_attention_heads, + cross_pe_max_pos=cross_pe_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + audio_cross_attention_dim=self.audio_cross_attention_dim, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier, + ) + elif self.model_type.is_video_enabled(): + self.video_args_preprocessor = TransformerArgsPreprocessor( + patchify_proj=self.patchify_proj, + adaln=self.adaln_single, + caption_projection=self.caption_projection, + inner_dim=self.inner_dim, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + use_middle_indices_grid=self.use_middle_indices_grid, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + ) + elif self.model_type.is_audio_enabled(): + self.audio_args_preprocessor = TransformerArgsPreprocessor( + patchify_proj=self.audio_patchify_proj, + adaln=self.audio_adaln_single, + caption_projection=self.audio_caption_projection, + inner_dim=self.audio_inner_dim, + max_pos=self.audio_positional_embedding_max_pos, + num_attention_heads=self.audio_num_attention_heads, + use_middle_indices_grid=self.use_middle_indices_grid, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + ) + + def _init_transformer_blocks( + self, + num_layers: int, + attention_head_dim: int, + cross_attention_dim: int, + audio_attention_head_dim: int, + audio_cross_attention_dim: int, + norm_eps: float, + ) -> None: + """Initialize transformer blocks for LTX.""" + video_config = ( + TransformerConfig( + dim=self.inner_dim, + heads=self.num_attention_heads, + d_head=attention_head_dim, + context_dim=cross_attention_dim, + ) + if self.model_type.is_video_enabled() + else None + ) + audio_config = ( + TransformerConfig( + dim=self.audio_inner_dim, + heads=self.audio_num_attention_heads, + d_head=audio_attention_head_dim, + context_dim=audio_cross_attention_dim, + ) + if self.model_type.is_audio_enabled() + else None + ) + self.transformer_blocks = torch.nn.ModuleList( + [ + BasicAVTransformerBlock( + idx=idx, + video=video_config, + audio=audio_config, + rope_type=self.rope_type, + norm_eps=norm_eps, + ) + for idx in range(num_layers) + ] + ) + + def set_gradient_checkpointing(self, enable: bool) -> None: + """Enable or disable gradient checkpointing for transformer blocks. + Gradient checkpointing trades compute for memory by recomputing activations + during the backward pass instead of storing them. This can significantly + reduce memory usage at the cost of ~20-30% slower training. + Args: + enable: Whether to enable gradient checkpointing + """ + self._enable_gradient_checkpointing = enable + + def _process_transformer_blocks( + self, + video: TransformerArgs | None, + audio: TransformerArgs | None, + perturbations: BatchedPerturbationConfig, + ) -> tuple[TransformerArgs, TransformerArgs]: + """Process transformer blocks for LTXAV.""" + + # Process transformer blocks + for block in self.transformer_blocks: + if self._enable_gradient_checkpointing and self.training: + # Use gradient checkpointing to save memory during training. + # With use_reentrant=False, we can pass dataclasses directly - + # PyTorch will track all tensor leaves in the computation graph. + video, audio = torch.utils.checkpoint.checkpoint( + block, + video, + audio, + perturbations, + use_reentrant=False, + ) + else: + video, audio = block( + video=video, + audio=audio, + perturbations=perturbations, + ) + + return video, audio + + def _process_output( + self, + scale_shift_table: torch.Tensor, + norm_out: torch.nn.LayerNorm, + proj_out: torch.nn.Linear, + x: torch.Tensor, + embedded_timestep: torch.Tensor, + ) -> torch.Tensor: + """Process output for LTXV.""" + # Apply scale-shift modulation + scale_shift_values = ( + scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + x = norm_out(x) + x = x * (1 + scale) + shift + x = proj_out(x) + return x + + def _forward( + self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for LTX models. + Returns: + Processed output tensors + """ + if not self.model_type.is_video_enabled() and video is not None: + raise ValueError("Video is not enabled for this model") + if not self.model_type.is_audio_enabled() and audio is not None: + raise ValueError("Audio is not enabled for this model") + + video_args = self.video_args_preprocessor.prepare(video) if video is not None else None + audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None + # Process transformer blocks + video_out, audio_out = self._process_transformer_blocks( + video=video_args, + audio=audio_args, + perturbations=perturbations, + ) + + # Process output + vx = ( + self._process_output( + self.scale_shift_table, self.norm_out, self.proj_out, video_out.x, video_out.embedded_timestep + ) + if video_out is not None + else None + ) + ax = ( + self._process_output( + self.audio_scale_shift_table, + self.audio_norm_out, + self.audio_proj_out, + audio_out.x, + audio_out.embedded_timestep, + ) + if audio_out is not None + else None + ) + return vx, ax + + def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps): + cross_pe_max_pos = None + if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled(): + cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]) + self._init_preprocessors(cross_pe_max_pos) + video = Modality(video_latents, video_timesteps, video_positions, video_context) + audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context) + vx, ax = self._forward(video=video, audio=audio, perturbations=None) + return vx, ax diff --git a/diffsynth/models/ltx2_text_encoder.py b/diffsynth/models/ltx2_text_encoder.py new file mode 100644 index 0000000..7fb94a7 --- /dev/null +++ b/diffsynth/models/ltx2_text_encoder.py @@ -0,0 +1,366 @@ +import torch +from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer +from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention, + FeedForward) +from .ltx2_common import rms_norm + + +class LTX2TextEncoder(Gemma3ForConditionalGeneration): + def __init__(self): + config = Gemma3Config( + **{ + "architectures": ["Gemma3ForConditionalGeneration"], + "boi_token_index": 255999, + "dtype": "bfloat16", + "eoi_token_index": 256000, + "eos_token_id": [1, 106], + "image_token_index": 262144, + "initializer_range": 0.02, + "mm_tokens_per_image": 256, + "model_type": "gemma3", + "text_config": { + "_sliding_window_pattern": 6, + "attention_bias": False, + "attention_dropout": 0.0, + "attn_logit_softcapping": None, + "cache_implementation": "hybrid", + "dtype": "bfloat16", + "final_logit_softcapping": None, + "head_dim": 256, + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 3840, + "initializer_range": 0.02, + "intermediate_size": 15360, + "layer_types": [ + "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "full_attention" + ], + "max_position_embeddings": 131072, + "model_type": "gemma3_text", + "num_attention_heads": 16, + "num_hidden_layers": 48, + "num_key_value_heads": 8, + "query_pre_attn_scalar": 256, + "rms_norm_eps": 1e-06, + "rope_local_base_freq": 10000, + "rope_scaling": { + "factor": 8.0, + "rope_type": "linear" + }, + "rope_theta": 1000000, + "sliding_window": 1024, + "sliding_window_pattern": 6, + "use_bidirectional_attention": False, + "use_cache": True, + "vocab_size": 262208 + }, + "transformers_version": "4.57.3", + "vision_config": { + "attention_dropout": 0.0, + "dtype": "bfloat16", + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": 896, + "intermediate_size": 4304, + "layer_norm_eps": 1e-06, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 27, + "patch_size": 14, + "vision_use_head": False + } + }) + super().__init__(config) + + +class LTXVGemmaTokenizer: + """ + Tokenizer wrapper for Gemma models compatible with LTXV processes. + This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders, + ensuring correct settings and output formatting for downstream consumption. + """ + + def __init__(self, tokenizer_path: str, max_length: int = 1024): + """ + Initialize the tokenizer. + Args: + tokenizer_path (str): Path to the pretrained tokenizer files or model directory. + max_length (int, optional): Max sequence length for encoding. Defaults to 256. + """ + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, local_files_only=True, model_max_length=max_length + ) + # Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much. + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.max_length = max_length + + def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]: + """ + Tokenize the given text and return token IDs and attention weights. + Args: + text (str): The input string to tokenize. + return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples. + If False (default), omits the indices. + Returns: + dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]: + A dictionary with a "gemma" key mapping to: + - a list of (token_id, attention_mask) tuples if return_word_ids is False; + - a list of (token_id, attention_mask, index) tuples if return_word_ids is True. + Example: + >>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8) + >>> tokenizer.tokenize_with_weights("hello world") + {'gemma': [(1234, 1), (5678, 1), (2, 0), ...]} + """ + text = text.strip() + encoded = self.tokenizer( + text, + padding="max_length", + max_length=self.max_length, + truncation=True, + return_tensors="pt", + ) + input_ids = encoded.input_ids + attention_mask = encoded.attention_mask + tuples = [ + (token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0], strict=True)) + ] + out = {"gemma": tuples} + + if not return_word_ids: + # Return only (token_id, attention_mask) pairs, omitting token position + out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()} + + return out + + +class GemmaFeaturesExtractorProjLinear(torch.nn.Module): + """ + Feature extractor module for Gemma models. + This module applies a single linear projection to the input tensor. + It expects a flattened feature tensor of shape (batch_size, 3840*49). + The linear layer maps this to a (batch_size, 3840) embedding. + Attributes: + aggregate_embed (torch.nn.Linear): Linear projection layer. + """ + + def __init__(self) -> None: + """ + Initialize the GemmaFeaturesExtractorProjLinear module. + The input dimension is expected to be 3840 * 49, and the output is 3840. + """ + super().__init__() + self.aggregate_embed = torch.nn.Linear(3840 * 49, 3840, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the feature extractor. + Args: + x (torch.Tensor): Input tensor of shape (batch_size, 3840 * 49). + Returns: + torch.Tensor: Output tensor of shape (batch_size, 3840). + """ + return self.aggregate_embed(x) + + +class _BasicTransformerBlock1D(torch.nn.Module): + def __init__( + self, + dim: int, + heads: int, + dim_head: int, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + ): + super().__init__() + + self.attn1 = Attention( + query_dim=dim, + heads=heads, + dim_head=dim_head, + rope_type=rope_type, + ) + + self.ff = FeedForward( + dim, + dim_out=dim, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + pe: torch.Tensor | None = None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + + # 1. Normalization Before Self-Attention + norm_hidden_states = rms_norm(hidden_states) + + norm_hidden_states = norm_hidden_states.squeeze(1) + + # 2. Self-Attention + attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe) + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 3. Normalization before Feed-Forward + norm_hidden_states = rms_norm(hidden_states) + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class Embeddings1DConnector(torch.nn.Module): + """ + Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or + other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can + substitute padded positions with learnable registers. The module is highly configurable for head size, number of + layers, and register usage. + Args: + attention_head_dim (int): Dimension of each attention head (default=128). + num_attention_heads (int): Number of attention heads (default=30). + num_layers (int): Number of transformer layers (default=2). + positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0). + positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[1]). + causal_temporal_positioning (bool): If True, uses causal attention (default=False). + num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables + register replacement. (default=128) + rope_type (LTXRopeType): The RoPE variant to use (default=DEFAULT_ROPE_TYPE). + double_precision_rope (bool): Use double precision rope calculation (default=False). + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + attention_head_dim: int = 128, + num_attention_heads: int = 30, + num_layers: int = 2, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list[int] | None = [4096], + causal_temporal_positioning: bool = False, + num_learnable_registers: int | None = 128, + rope_type: LTXRopeType = LTXRopeType.SPLIT, + double_precision_rope: bool = True, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = ( + positional_embedding_max_pos if positional_embedding_max_pos is not None else [1] + ) + self.rope_type = rope_type + self.double_precision_rope = double_precision_rope + self.transformer_1d_blocks = torch.nn.ModuleList( + [ + _BasicTransformerBlock1D( + dim=self.inner_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + rope_type=rope_type, + ) + for _ in range(num_layers) + ] + ) + + self.num_learnable_registers = num_learnable_registers + if self.num_learnable_registers: + self.learnable_registers = torch.nn.Parameter( + torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0 + ) + + def _replace_padded_with_learnable_registers( + self, hidden_states: torch.Tensor, attention_mask: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.shape[1] % self.num_learnable_registers == 0, ( + f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers " + f"{self.num_learnable_registers}." + ) + + num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers + learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1)) + attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int() + + non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :] + non_zero_nums = non_zero_hidden_states.shape[1] + pad_length = hidden_states.shape[1] - non_zero_nums + adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0) + flipped_mask = torch.flip(attention_mask_binary, dims=[1]) + hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers + + attention_mask = torch.full_like( + attention_mask, + 0.0, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + return hidden_states, attention_mask + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass of Embeddings1DConnector. + Args: + hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]). + attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states). + Returns: + tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask. + """ + if self.num_learnable_registers: + hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask) + + indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device) + indices_grid = indices_grid[None, None, :] + freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch + freqs_cis = precompute_freqs_cis( + indices_grid=indices_grid, + dim=self.inner_dim, + out_dtype=hidden_states.dtype, + theta=self.positional_embedding_theta, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + rope_type=self.rope_type, + freq_grid_generator=freq_grid_generator, + ) + + for block in self.transformer_1d_blocks: + hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis) + + hidden_states = rms_norm(hidden_states) + + return hidden_states, attention_mask + + +class LTX2TextEncoderPostModules(torch.nn.Module): + def __init__(self,): + super().__init__() + self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear() + self.embeddings_connector = Embeddings1DConnector() + self.audio_embeddings_connector = Embeddings1DConnector() diff --git a/diffsynth/models/ltx2_upsampler.py b/diffsynth/models/ltx2_upsampler.py new file mode 100644 index 0000000..862ca14 --- /dev/null +++ b/diffsynth/models/ltx2_upsampler.py @@ -0,0 +1,313 @@ +import math +from typing import Optional, Tuple +import torch +from einops import rearrange +import torch.nn.functional as F +from .ltx2_video_vae import LTX2VideoEncoder + +class PixelShuffleND(torch.nn.Module): + """ + N-dimensional pixel shuffle operation for upsampling tensors. + Args: + dims (int): Number of dimensions to apply pixel shuffle to. + - 1: Temporal (e.g., frames) + - 2: Spatial (e.g., height and width) + - 3: Spatiotemporal (e.g., depth, height, width) + upscale_factors (tuple[int, int, int], optional): Upscaling factors for each dimension. + For dims=1, only the first value is used. + For dims=2, the first two values are used. + For dims=3, all three values are used. + The input tensor is rearranged so that the channel dimension is split into + smaller channels and upscaling factors, and the upscaling factors are moved + into the corresponding spatial/temporal dimensions. + Note: + This operation is equivalent to the patchifier operation in for the models. Consider + using this class instead. + """ + + def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) + else: + raise ValueError(f"Unsupported dims: {self.dims}") + + +class ResBlock(torch.nn.Module): + """ + Residual block with two convolutional layers, group normalization, and SiLU activation. + Args: + channels (int): Number of input and output channels. + mid_channels (Optional[int]): Number of channels in the intermediate convolution layer. Defaults to `channels` + if not specified. + dims (int): Dimensionality of the convolution (2 for Conv2d, 3 for Conv3d). Defaults to 3. + """ + + def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.conv1 = conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = torch.nn.GroupNorm(32, mid_channels) + self.conv2 = conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = torch.nn.GroupNorm(32, channels) + self.activation = torch.nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x + + +class BlurDownsample(torch.nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. + Applies only on H,W. Works for dims=2 or dims=3 (per-frame). + """ + + def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None: + super().__init__() + assert dims in (2, 3) + assert isinstance(stride, int) + assert stride >= 1 + assert kernel_size >= 3 + assert kernel_size % 2 == 1 + self.dims = dims + self.stride = stride + self.kernel_size = kernel_size + + # 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from + # the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and + # provides a smooth approximation of a Gaussian filter (often called a "binomial filter"). + # The 2D kernel is constructed as the outer product and normalized. + k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + if self.dims == 2: + return self._apply_2d(x) + else: + # dims == 3: apply per-frame on H,W + b, _, f, _, _ = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self._apply_2d(x) + h2, w2 = x.shape[-2:] + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2) + return x + + def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor: + c = x2d.shape[1] + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x2d = F.conv2d(x2d, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + return x2d + + +def _rational_for_scale(scale: float) -> Tuple[int, int]: + mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)} + if float(scale) not in mapping: + raise ValueError(f"Unsupported scale {scale}. Choose from {list(mapping.keys())}") + return mapping[float(scale)] + + +class SpatialRationalResampler(torch.nn.Module): + """ + Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased + downsample by 'den' using fixed blur + stride. Operates on H,W only. + For dims==3, work per-frame for spatial scaling (temporal axis untouched). + Args: + mid_channels (`int`): Number of intermediate channels for the convolution layer + scale (`float`): Spatial scaling factor. Supported values are: + - 0.75: Downsample by 3/4 (reduce spatial size) + - 1.5: Upsample by 3/2 (increase spatial size) + - 2.0: Upsample by 2x (double spatial size) + - 4.0: Upsample by 4x (quadruple spatial size) + Any other value will raise a ValueError. + """ + + def __init__(self, mid_channels: int, scale: float): + super().__init__() + self.scale = float(scale) + self.num, self.den = _rational_for_scale(self.scale) + self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, _, f, _, _ = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + return x + + +class LTX2LatentUpsampler(torch.nn.Module): + """ + Model to upsample VAE latents spatially and/or temporally. + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + spatial_scale (`float`): Scale factor for spatial upsampling + rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling + """ + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 1024, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + spatial_scale: float = 2.0, + rational_resampler: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + self.spatial_scale = float(spatial_scale) + self.rational_resampler = rational_resampler + + conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = torch.nn.GroupNorm(32, mid_channels) + self.initial_activation = torch.nn.SiLU() + + self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + if spatial_upsample and temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + if rational_resampler: + self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale) + else: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError("Either spatial_upsample or temporal_upsample must be True") + + self.post_upsample_res_blocks = torch.nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, _, f, _, _ = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if self.temporal_upsample: + x = self.upsampler(x) + # remove the first frame after upsampling. + # This is done because the first frame encodes one pixel frame. + x = x[:, :, 1:, :, :] + elif isinstance(self.upsampler, SpatialRationalResampler): + x = self.upsampler(x) + else: + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x + + +def upsample_video(latent: torch.Tensor, video_encoder: LTX2VideoEncoder, upsampler: "LTX2LatentUpsampler") -> torch.Tensor: + """ + Apply upsampling to the latent representation using the provided upsampler, + with normalization and un-normalization based on the video encoder's per-channel statistics. + Args: + latent: Input latent tensor of shape [B, C, F, H, W]. + video_encoder: VideoEncoder with per_channel_statistics for normalization. + upsampler: LTX2LatentUpsampler module to perform upsampling. + Returns: + torch.Tensor: Upsampled and re-normalized latent tensor. + """ + latent = video_encoder.per_channel_statistics.un_normalize(latent) + latent = upsampler(latent) + latent = video_encoder.per_channel_statistics.normalize(latent) + return latent diff --git a/diffsynth/models/ltx2_video_vae.py b/diffsynth/models/ltx2_video_vae.py new file mode 100644 index 0000000..0c99432 --- /dev/null +++ b/diffsynth/models/ltx2_video_vae.py @@ -0,0 +1,2317 @@ +import itertools +import math +import einops +from dataclasses import replace, dataclass +from typing import Any, Callable, Iterator, List, NamedTuple, Tuple, Union, Optional +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional as F +from enum import Enum +from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape, Patchifier, AudioLatentShape +from .ltx2_dit import PixArtAlphaCombinedTimestepSizeEmbeddings + +VAE_SPATIAL_FACTOR = 32 +VAE_TEMPORAL_FACTOR = 8 + + +class VideoLatentPatchifier(Patchifier): + def __init__(self, patch_size: int): + # Patch sizes for video latents. + self._patch_size = ( + 1, # temporal dimension + patch_size, # height dimension + patch_size, # width dimension + ) + + @property + def patch_size(self) -> Tuple[int, int, int]: + return self._patch_size + + def get_token_count(self, tgt_shape: VideoLatentShape) -> int: + return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size) + + def patchify( + self, + latents: torch.Tensor, + ) -> torch.Tensor: + latents = einops.rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + + return latents + + def unpatchify( + self, + latents: torch.Tensor, + output_shape: VideoLatentShape, + ) -> torch.Tensor: + assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier" + + patch_grid_frames = output_shape.frames // self._patch_size[0] + patch_grid_height = output_shape.height // self._patch_size[1] + patch_grid_width = output_shape.width // self._patch_size[2] + + latents = einops.rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + f=patch_grid_frames, + h=patch_grid_height, + w=patch_grid_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + + return latents + + def unpatchify_video( + self, + latents: torch.Tensor, + frames: int, + height: int, + width: int, + ) -> torch.Tensor: + latents = einops.rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + f=frames, + h=height // self._patch_size[1], + w=width // self._patch_size[2], + p=self._patch_size[1], + q=self._patch_size[2], + ) + return latents + + def get_patch_grid_bounds( + self, + output_shape: AudioLatentShape | VideoLatentShape, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Return the per-dimension bounds [inclusive start, exclusive end) for every + patch produced by `patchify`. The bounds are expressed in the original + video grid coordinates: frame/time, height, and width. + The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where: + - axis 1 (size 3) enumerates (frame/time, height, width) dimensions + - axis 3 (size 2) stores `[start, end)` indices within each dimension + Args: + output_shape: Video grid description containing frames, height, and width. + device: Device of the latent tensor. + """ + if not isinstance(output_shape, VideoLatentShape): + raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates") + + frames = output_shape.frames + height = output_shape.height + width = output_shape.width + batch_size = output_shape.batch + + # Validate inputs to ensure positive dimensions + assert frames > 0, f"frames must be positive, got {frames}" + assert height > 0, f"height must be positive, got {height}" + assert width > 0, f"width must be positive, got {width}" + assert batch_size > 0, f"batch_size must be positive, got {batch_size}" + + # Generate grid coordinates for each dimension (frame, height, width) + # We use torch.arange to create the starting coordinates for each patch. + # indexing='ij' ensures the dimensions are in the order (frame, height, width). + grid_coords = torch.meshgrid( + torch.arange(start=0, end=frames, step=self._patch_size[0], device=device), + torch.arange(start=0, end=height, step=self._patch_size[1], device=device), + torch.arange(start=0, end=width, step=self._patch_size[2], device=device), + indexing="ij", + ) + + # Stack the grid coordinates to create the start coordinates tensor. + # Shape becomes (3, grid_f, grid_h, grid_w) + patch_starts = torch.stack(grid_coords, dim=0) + + # Create a tensor containing the size of a single patch: + # (frame_patch_size, height_patch_size, width_patch_size). + # Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates. + patch_size_delta = torch.tensor( + self._patch_size, + device=patch_starts.device, + dtype=patch_starts.dtype, + ).view(3, 1, 1, 1) + + # Calculate end coordinates: start + patch_size + # Shape becomes (3, grid_f, grid_h, grid_w) + patch_ends = patch_starts + patch_size_delta + + # Stack start and end coordinates together along the last dimension + # Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end] + latent_coords = torch.stack((patch_starts, patch_ends), dim=-1) + + # Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence. + # Final Shape: (batch_size, 3, num_patches, 2) + latent_coords = einops.repeat( + latent_coords, + "c f h w bounds -> b c (f h w) bounds", + b=batch_size, + bounds=2, + ) + + return latent_coords + + +class NormLayerType(Enum): + GROUP_NORM = "group_norm" + PIXEL_NORM = "pixel_norm" + + +class LogVarianceType(Enum): + PER_CHANNEL = "per_channel" + UNIFORM = "uniform" + CONSTANT = "constant" + NONE = "none" + + +class PaddingModeType(Enum): + ZEROS = "zeros" + REFLECT = "reflect" + REPLICATE = "replicate" + CIRCULAR = "circular" + + +class DualConv3d(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + ) -> None: + super(DualConv3d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.padding_mode = padding_mode + # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if kernel_size == (1, 1, 1): + raise ValueError("kernel_size must be greater than 1. Use make_linear_nd instead.") + if isinstance(stride, int): + stride = (stride, stride, stride) + if isinstance(padding, int): + padding = (padding, padding, padding) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + + # Set parameters for convolutions + self.groups = groups + self.bias = bias + + # Define the size of the channels after the first convolution + intermediate_channels = out_channels if in_channels < out_channels else in_channels + + # Define parameters for the first convolution + self.weight1 = nn.Parameter( + torch.Tensor( + intermediate_channels, + in_channels // groups, + 1, + kernel_size[1], + kernel_size[2], + )) + self.stride1 = (1, stride[1], stride[2]) + self.padding1 = (0, padding[1], padding[2]) + self.dilation1 = (1, dilation[1], dilation[2]) + if bias: + self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) + else: + self.register_parameter("bias1", None) + + # Define parameters for the second convolution + self.weight2 = nn.Parameter(torch.Tensor(out_channels, intermediate_channels // groups, kernel_size[0], 1, 1)) + self.stride2 = (stride[0], 1, 1) + self.padding2 = (padding[0], 0, 0) + self.dilation2 = (dilation[0], 1, 1) + if bias: + self.bias2 = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias2", None) + + # Initialize weights and biases + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.kaiming_uniform_(self.weight1, a=torch.sqrt(5)) + nn.init.kaiming_uniform_(self.weight2, a=torch.sqrt(5)) + if self.bias: + fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) + bound1 = 1 / torch.sqrt(fan_in1) + nn.init.uniform_(self.bias1, -bound1, bound1) + fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) + bound2 = 1 / torch.sqrt(fan_in2) + nn.init.uniform_(self.bias2, -bound2, bound2) + + def forward( + self, + x: torch.Tensor, + use_conv3d: bool = False, + skip_time_conv: bool = False, + ) -> torch.Tensor: + if use_conv3d: + return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) + else: + return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) + + def forward_with_3d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor: + # First convolution + x = F.conv3d( + x, + self.weight1, + self.bias1, + self.stride1, + self.padding1, + self.dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + if skip_time_conv: + return x + + # Second convolution + x = F.conv3d( + x, + self.weight2, + self.bias2, + self.stride2, + self.padding2, + self.dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + + return x + + def forward_with_2d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor: + b, _, _, h, w = x.shape + + # First 2D convolution + x = rearrange(x, "b c d h w -> (b d) c h w") + # Squeeze the depth dimension out of weight1 since it's 1 + weight1 = self.weight1.squeeze(2) + # Select stride, padding, and dilation for the 2D convolution + stride1 = (self.stride1[1], self.stride1[2]) + padding1 = (self.padding1[1], self.padding1[2]) + dilation1 = (self.dilation1[1], self.dilation1[2]) + x = F.conv2d( + x, + weight1, + self.bias1, + stride1, + padding1, + dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + _, _, h, w = x.shape + + if skip_time_conv: + x = rearrange(x, "(b d) c h w -> b c d h w", b=b) + return x + + # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) + + # Reshape weight2 to match the expected dimensions for conv1d + weight2 = self.weight2.squeeze(-1).squeeze(-1) + # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution + stride2 = self.stride2[0] + padding2 = self.padding2[0] + dilation2 = self.dilation2[0] + x = F.conv1d( + x, + weight2, + self.bias2, + stride2, + padding2, + dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) + + return x + + @property + def weight(self) -> torch.Tensor: + return self.weight2 + + +class CausalConv3d(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: Union[int, Tuple[int]] = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + kernel_size = (kernel_size, kernel_size, kernel_size) + self.time_kernel_size = kernel_size[0] + + dilation = (dilation, 1, 1) + + height_pad = kernel_size[1] // 2 + width_pad = kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + padding_mode=spatial_padding_mode.value, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor, causal: bool = True) -> torch.Tensor: + if causal: + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) + x = torch.concatenate((first_frame_pad, x), dim=2) + else: + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)) + last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)) + x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) + x = self.conv(x) + return x + + @property + def weight(self) -> torch.Tensor: + return self.conv.weight + + +def make_conv_nd( # noqa: PLR0913 + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + temporal_padding_mode: PaddingModeType = PaddingModeType.ZEROS, +) -> nn.Module: + if not (spatial_padding_mode == temporal_padding_mode or causal): + raise NotImplementedError("spatial and temporal padding modes must be equal") + if dims == 2: + return nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode.value, + ) + elif dims == 3: + if causal: + return CausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + bias=bias, + spatial_padding_mode=spatial_padding_mode, + ) + return nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode.value, + ) + elif dims == (2, 1): + return DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + padding_mode=spatial_padding_mode.value, + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def make_linear_nd( + dims: int, + in_channels: int, + out_channels: int, + bias: bool = True, +) -> nn.Module: + if dims == 2: + return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) + elif dims in (3, (2, 1)): + return nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def patchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor: + """ + Rearrange spatial dimensions into channels. Divides image into patch_size x patch_size blocks + and moves pixels from each block into separate channels (space-to-depth). + Args: + x: Input tensor (4D or 5D) + patch_size_hw: Spatial patch size for height and width. With patch_size_hw=4, divides HxW into 4x4 blocks. + patch_size_t: Temporal patch size for frames. Default=1 (no temporal patching). + For 5D: (B, C, F, H, W) -> (B, Cx(patch_size_hw^2)x(patch_size_t), F/patch_size_t, H/patch_size_hw, W/patch_size_hw) + Example: (B, 3, 33, 512, 512) with patch_size_hw=4, patch_size_t=1 -> (B, 48, 33, 128, 128) + """ + if patch_size_hw == 1 and patch_size_t == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor: + """ + Rearrange channels back into spatial dimensions. Inverse of patchify - moves pixels from + channels back into patch_size x patch_size blocks (depth-to-space). + Args: + x: Input tensor (4D or 5D) + patch_size_hw: Spatial patch size for height and width. With patch_size_hw=4, expands HxW by 4x. + patch_size_t: Temporal patch size for frames. Default=1 (no temporal expansion). + For 5D: (B, Cx(patch_size_hw^2)x(patch_size_t), F, H, W) -> (B, C, Fxpatch_size_t, Hxpatch_size_hw, Wxpatch_size_hw) + Example: (B, 48, 33, 128, 128) with patch_size_hw=4, patch_size_t=1 -> (B, 3, 33, 512, 512) + """ + if patch_size_hw == 1 and patch_size_t == 1: + return x + + if x.dim() == 4: + x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + + return x + + +class PerChannelStatistics(nn.Module): + """ + Per-channel statistics for normalizing and denormalizing the latent representation. + This statics is computed over the entire dataset and stored in model's checkpoint under VAE state_dict. + """ + + def __init__(self, latent_channels: int = 128): + super().__init__() + self.register_buffer("std-of-means", torch.empty(latent_channels)) + self.register_buffer("mean-of-means", torch.empty(latent_channels)) + self.register_buffer("mean-of-stds", torch.empty(latent_channels)) + self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(latent_channels)) + self.register_buffer("channel", torch.empty(latent_channels)) + + def un_normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view( + 1, -1, 1, 1, 1).to(x) + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view( + 1, -1, 1, 1, 1).to(x) + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.inject_noise = inject_noise + + if norm_layer == NormLayerType.GROUP_NORM: + self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.norm1 = PixelNorm() + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + if norm_layer == NormLayerType.GROUP_NORM: + self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.norm2 = PixelNorm() + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + self.conv_shortcut = (make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels) + if in_channels != out_channels else nn.Identity()) + + # Using GroupNorm with 1 group is equivalent to LayerNorm but works with (B, C, ...) layout + # avoiding the need for dimension rearrangement used in standard nn.LayerNorm + self.norm3 = (nn.GroupNorm(num_groups=1, num_channels=in_channels, eps=eps, affine=True) + if in_channels != out_channels else nn.Identity()) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.zeros(4, in_channels)) + + def _feed_spatial_noise( + self, + hidden_states: torch.Tensor, + per_channel_scale: torch.Tensor, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + spatial_shape = hidden_states.shape[-2:] + device = hidden_states.device + dtype = hidden_states.dtype + + # similar to the "explicit noise inputs" method in style-gan + spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype, generator=generator)[None] + scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] + hidden_states = hidden_states + scaled_noise + + return hidden_states + + def forward( + self, + input_tensor: torch.Tensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + hidden_states = input_tensor + batch_size = hidden_states.shape[0] + + hidden_states = self.norm1(hidden_states) + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + ada_values = self.scale_shift_table[None, ..., None, None, None].to( + device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape( + batch_size, + 4, + -1, + timestep.shape[-3], + timestep.shape[-2], + timestep.shape[-1], + ) + shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) + + hidden_states = hidden_states * (1 + scale1) + shift1 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, + self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype), + generator=generator, + ) + + hidden_states = self.norm2(hidden_states) + + if self.timestep_conditioning: + hidden_states = hidden_states * (1 + scale2) + shift2 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, + self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype), + generator=generator, + ) + + input_tensor = self.norm3(input_tensor) + + batch_size = input_tensor.shape[0] + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + inject_noise (`bool`, *optional*, defaults to `False`): + Whether to inject noise into the hidden states. + timestep_conditioning (`bool`, *optional*, defaults to `False`): + Whether to condition the hidden states on the timestep. + Returns: + `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: NormLayerType = NormLayerType.GROUP_NORM, + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim=in_channels * 4, + size_emb_dim=0) + + self.res_blocks = nn.ModuleList([ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) for _ in range(num_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + timestep_embed = None + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + batch_size = hidden_states.shape[0] + timestep_embed = self.time_embedder( + timestep=timestep.flatten(), + hidden_dtype=hidden_states.dtype, + ) + timestep_embed = timestep_embed.view(batch_size, timestep_embed.shape[-1], 1, 1, 1) + + for resnet in self.res_blocks: + hidden_states = resnet( + hidden_states, + causal=causal, + timestep=timestep_embed, + generator=generator, + ) + + return hidden_states + + +class SpaceToDepthDownsample(nn.Module): + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + stride: Tuple[int, int, int], + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.stride = stride + self.group_size = in_channels * math.prod(stride) // out_channels + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels // math.prod(stride), + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward( + self, + x: torch.Tensor, + causal: bool = True, + ) -> torch.Tensor: + if self.stride[0] == 2: + x = torch.cat([x[:, :, :1, :, :], x], dim=2) # duplicate first frames for padding + + # skip connection + x_in = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) + x_in = x_in.mean(dim=2) + + # conv + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + x = x + x_in + + return x + + +class DepthToSpaceUpsample(nn.Module): + + def __init__( + self, + dims: int | Tuple[int, int], + in_channels: int, + stride: Tuple[int, int, int], + residual: bool = False, + out_channels_reduction_factor: int = 1, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.stride = stride + self.out_channels = math.prod(stride) * in_channels // out_channels_reduction_factor + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + self.residual = residual + self.out_channels_reduction_factor = out_channels_reduction_factor + + def forward( + self, + x: torch.Tensor, + causal: bool = True, + ) -> torch.Tensor: + if self.residual: + # Reshape and duplicate the input to match the output shape + x_in = rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor + x_in = x_in.repeat(1, num_repeat, 1, 1, 1) + if self.stride[0] == 2: + x_in = x_in[:, :, 1:, :, :] + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + if self.stride[0] == 2: + x = x[:, :, 1:, :, :] + if self.residual: + x = x + x_in + return x + + +def compute_trapezoidal_mask_1d( + length: int, + ramp_left: int, + ramp_right: int, + left_starts_from_0: bool = False, +) -> torch.Tensor: + """ + Generate a 1D trapezoidal blending mask with linear ramps. + Args: + length: Output length of the mask. + ramp_left: Fade-in length on the left. + ramp_right: Fade-out length on the right. + left_starts_from_0: Whether the ramp starts from 0 or first non-zero value. + Useful for temporal tiles where the first tile is causal. + Returns: + A 1D tensor of shape `(length,)` with values in [0, 1]. + """ + if length <= 0: + raise ValueError("Mask length must be positive.") + + ramp_left = max(0, min(ramp_left, length)) + ramp_right = max(0, min(ramp_right, length)) + + mask = torch.ones(length) + + if ramp_left > 0: + interval_length = ramp_left + 1 if left_starts_from_0 else ramp_left + 2 + fade_in = torch.linspace(0.0, 1.0, interval_length)[:-1] + if not left_starts_from_0: + fade_in = fade_in[1:] + mask[:ramp_left] *= fade_in + + if ramp_right > 0: + fade_out = torch.linspace(1.0, 0.0, steps=ramp_right + 2)[1:-1] + mask[-ramp_right:] *= fade_out + + return mask.clamp_(0, 1) + + +@dataclass(frozen=True) +class SpatialTilingConfig: + """Configuration for dividing each frame into spatial tiles with optional overlap. + Args: + tile_size_in_pixels (int): Size of each tile in pixels. Must be at least 64 and divisible by 32. + tile_overlap_in_pixels (int, optional): Overlap between tiles in pixels. Must be divisible by 32. Defaults to 0. + """ + + tile_size_in_pixels: int + tile_overlap_in_pixels: int = 0 + + def __post_init__(self) -> None: + if self.tile_size_in_pixels < 64: + raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}") + if self.tile_size_in_pixels % 32 != 0: + raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}") + if self.tile_overlap_in_pixels % 32 != 0: + raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}") + if self.tile_overlap_in_pixels >= self.tile_size_in_pixels: + raise ValueError( + f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}" + ) + + +@dataclass(frozen=True) +class TemporalTilingConfig: + """Configuration for dividing a video into temporal tiles (chunks of frames) with optional overlap. + Args: + tile_size_in_frames (int): Number of frames in each tile. Must be at least 16 and divisible by 8. + tile_overlap_in_frames (int, optional): Number of overlapping frames between consecutive tiles. + Must be divisible by 8. Defaults to 0. + """ + + tile_size_in_frames: int + tile_overlap_in_frames: int = 0 + + def __post_init__(self) -> None: + if self.tile_size_in_frames < 16: + raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}") + if self.tile_size_in_frames % 8 != 0: + raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}") + if self.tile_overlap_in_frames % 8 != 0: + raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}") + if self.tile_overlap_in_frames >= self.tile_size_in_frames: + raise ValueError( + f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}" + ) + + +@dataclass(frozen=True) +class TilingConfig: + """Configuration for splitting video into tiles with optional overlap. + Attributes: + spatial_config: Configuration for splitting spatial dimensions into tiles. + temporal_config: Configuration for splitting temporal dimension into tiles. + """ + + spatial_config: SpatialTilingConfig | None = None + temporal_config: TemporalTilingConfig | None = None + + @classmethod + def default(cls) -> "TilingConfig": + return cls( + spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64), + temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24), + ) + + +@dataclass(frozen=True) +class DimensionIntervals: + """Intervals which a single dimension of the latent space is split into. + Each interval is defined by its start, end, left ramp, and right ramp. + The start and end are the indices of the first and last element (exclusive) in the interval. + Ramps are regions of the interval where the value of the mask tensor is + interpolated between 0 and 1 for blending with neighboring intervals. + The left ramp and right ramp values are the lengths of the left and right ramps. + """ + + starts: List[int] + ends: List[int] + left_ramps: List[int] + right_ramps: List[int] + + +@dataclass(frozen=True) +class LatentIntervals: + """Intervals which the latent tensor of given shape is split into. + Each dimension of the latent space is split into intervals based on the length along said dimension. + """ + + original_shape: torch.Size + dimension_intervals: Tuple[DimensionIntervals, ...] + + +# Operation to split a single dimension of the tensor into intervals based on the length along the dimension. +SplitOperation = Callable[[int], DimensionIntervals] +# Operation to map the intervals in input dimension to slices and masks along a corresponding output dimension. +MappingOperation = Callable[[DimensionIntervals], tuple[list[slice], list[torch.Tensor | None]]] + + +def default_split_operation(length: int) -> DimensionIntervals: + return DimensionIntervals(starts=[0], ends=[length], left_ramps=[0], right_ramps=[0]) + + +DEFAULT_SPLIT_OPERATION: SplitOperation = default_split_operation + + +def default_mapping_operation(_intervals: DimensionIntervals,) -> tuple[list[slice], list[torch.Tensor | None]]: + return [slice(0, None)], [None] + + +DEFAULT_MAPPING_OPERATION: MappingOperation = default_mapping_operation + + +class Tile(NamedTuple): + """ + Represents a single tile. + Attributes: + in_coords: + Tuple of slices specifying where to cut the tile from the INPUT tensor. + out_coords: + Tuple of slices specifying where this tile's OUTPUT should be placed in the reconstructed OUTPUT tensor. + masks_1d: + Per-dimension masks in OUTPUT units. + These are used to create all-dimensional blending mask. + Methods: + blend_mask: + Create a single N-D mask from the per-dimension masks. + """ + + in_coords: Tuple[slice, ...] + out_coords: Tuple[slice, ...] + masks_1d: Tuple[Tuple[torch.Tensor, ...]] + + @property + def blend_mask(self) -> torch.Tensor: + num_dims = len(self.out_coords) + per_dimension_masks: List[torch.Tensor] = [] + + for dim_idx in range(num_dims): + mask_1d = self.masks_1d[dim_idx] + view_shape = [1] * num_dims + if mask_1d is None: + # Broadcast mask along this dimension (length 1). + one = torch.ones(1) + + view_shape[dim_idx] = 1 + per_dimension_masks.append(one.view(*view_shape)) + continue + + # Reshape (L,) -> (1, ..., L, ..., 1) so masks across dimensions broadcast-multiply. + view_shape[dim_idx] = mask_1d.shape[0] + per_dimension_masks.append(mask_1d.view(*view_shape)) + + # Multiply per-dimension masks to form the full N-D mask (separable blending window). + combined_mask = per_dimension_masks[0] + for mask in per_dimension_masks[1:]: + combined_mask = combined_mask * mask + + return combined_mask + + +def create_tiles_from_intervals_and_mappers( + intervals: LatentIntervals, + mappers: List[MappingOperation], +) -> List[Tile]: + full_dim_input_slices = [] + full_dim_output_slices = [] + full_dim_masks_1d = [] + for axis_index in range(len(intervals.original_shape)): + dimension_intervals = intervals.dimension_intervals[axis_index] + starts = dimension_intervals.starts + ends = dimension_intervals.ends + input_slices = [slice(s, e) for s, e in zip(starts, ends, strict=True)] + output_slices, masks_1d = mappers[axis_index](dimension_intervals) + full_dim_input_slices.append(input_slices) + full_dim_output_slices.append(output_slices) + full_dim_masks_1d.append(masks_1d) + + tiles = [] + tile_in_coords = list(itertools.product(*full_dim_input_slices)) + tile_out_coords = list(itertools.product(*full_dim_output_slices)) + tile_mask_1ds = list(itertools.product(*full_dim_masks_1d)) + for in_coord, out_coord, mask_1d in zip(tile_in_coords, tile_out_coords, tile_mask_1ds, strict=True): + tiles.append(Tile( + in_coords=in_coord, + out_coords=out_coord, + masks_1d=mask_1d, + )) + return tiles + + +def create_tiles( + latent_shape: torch.Size, + splitters: List[SplitOperation], + mappers: List[MappingOperation], +) -> List[Tile]: + if len(splitters) != len(latent_shape): + raise ValueError(f"Number of splitters must be equal to number of dimensions in latent shape, " + f"got {len(splitters)} and {len(latent_shape)}") + if len(mappers) != len(latent_shape): + raise ValueError(f"Number of mappers must be equal to number of dimensions in latent shape, " + f"got {len(mappers)} and {len(latent_shape)}") + intervals = [splitter(length) for splitter, length in zip(splitters, latent_shape, strict=True)] + latent_intervals = LatentIntervals(original_shape=latent_shape, dimension_intervals=tuple(intervals)) + return create_tiles_from_intervals_and_mappers(latent_intervals, mappers) + + +def _make_encoder_block( + block_name: str, + block_config: dict[str, Any], + in_channels: int, + convolution_dimensions: int, + norm_layer: NormLayerType, + norm_num_groups: int, + spatial_padding_mode: PaddingModeType, +) -> Tuple[nn.Module, int]: + out_channels = in_channels + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + out_channels = in_channels * block_config.get("multiplier", 2) + block = ResnetBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 1, 1), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(1, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_x_y": + out_channels = in_channels * block_config.get("multiplier", 2) + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown block: {block_name}") + + return block, out_channels + + +class LTX2VideoEncoder(nn.Module): + _DEFAULT_NORM_NUM_GROUPS = 32 + """ + Variational Autoencoder Encoder. Encodes video frames into a latent representation. + The encoder compresses the input video through a series of downsampling operations controlled by + patch_size and encoder_blocks. The output is a normalized latent tensor with shape (B, 128, F', H', W'). + Compression Behavior: + The total compression is determined by: + 1. Initial spatial compression via patchify: H -> H/4, W -> W/4 (patch_size=4) + 2. Sequential compression through encoder_blocks based on their stride patterns + Compression blocks apply 2x compression in specified dimensions: + - "compress_time" / "compress_time_res": temporal only + - "compress_space" / "compress_space_res": spatial only (H and W) + - "compress_all" / "compress_all_res": all dimensions (F, H, W) + - "res_x" / "res_x_y": no compression + Standard LTX Video configuration: + - patch_size=4 + - encoder_blocks: 1x compress_space_res, 1x compress_time_res, 2x compress_all_res + - Final dimensions: F' = 1 + (F-1)/8, H' = H/32, W' = W/32 + - Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16) + - Note: Input must have 1 + 8*k frames (e.g., 1, 9, 17, 25, 33...) + Args: + convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D). + in_channels: The number of input channels. For RGB images, this is 3. + out_channels: The number of output channels (latent channels). For latent channels, this is 128. + encoder_blocks: The list of blocks to construct the encoder. Each block is a tuple of (block_name, params) + where params is either an int (num_layers) or a dict with configuration. + patch_size: The patch size for initial spatial compression. Should be a power of 2. + norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var: The log variance mode. Can be either `per_channel`, `uniform`, `constant` or `none`. + """ + + def __init__( + self, + convolution_dimensions: int = 3, + in_channels: int = 3, + out_channels: int = 128, + patch_size: int = 4, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + latent_log_var: LogVarianceType = LogVarianceType.UNIFORM, + encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + encoder_blocks = [['res_x', { + 'num_layers': 4 + }], ['compress_space_res', { + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 6 + }], ['compress_time_res', { + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 6 + }], ['compress_all_res', { + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 2 + }], ['compress_all_res', { + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 2 + }]] + self.patch_size = patch_size + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS + + # Per-channel statistics for normalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels) + + in_channels = in_channels * patch_size**2 + feature_channels = out_channels + + self.conv_in = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=feature_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + self.down_blocks = nn.ModuleList([]) + + for block_name, block_params in encoder_blocks: + # Convert int to dict format for uniform handling + block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params + + block, feature_channels = _make_encoder_block( + block_name=block_name, + block_config=block_config, + in_channels=feature_channels, + convolution_dimensions=convolution_dimensions, + norm_layer=norm_layer, + norm_num_groups=self._norm_num_groups, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + self.down_blocks.append(block) + + # out + if norm_layer == NormLayerType.GROUP_NORM: + self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == LogVarianceType.PER_CHANNEL: + conv_out_channels *= 2 + elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}: + conv_out_channels += 1 + elif latent_log_var != LogVarianceType.NONE: + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + + self.conv_out = make_conv_nd( + dims=convolution_dimensions, + in_channels=feature_channels, + out_channels=conv_out_channels, + kernel_size=3, + padding=1, + causal=True, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + r""" + Encode video frames into normalized latent representation. + Args: + sample: Input video (B, C, F, H, W). F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...). + Returns: + Normalized latent means (B, 128, F', H', W') where F' = 1+(F-1)/8, H' = H/32, W' = W/32. + Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16). + """ + # Validate frame count + frames_count = sample.shape[2] + if ((frames_count - 1) % 8) != 0: + raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames " + "(e.g., 1, 9, 17, ...). Please check your input.") + + # Initial spatial compression: trade spatial resolution for channel depth + # This reduces H,W by patch_size and increases channels, making convolutions more efficient + # Example: (B, 3, F, 512, 512) -> (B, 48, F, 128, 128) with patch_size=4 + sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + sample = self.conv_in(sample) + + for down_block in self.down_blocks: + sample = down_block(sample) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == LogVarianceType.UNIFORM: + # Uniform Variance: model outputs N means and 1 shared log-variance channel. + # We need to expand the single logvar to match the number of means channels + # to create a format compatible with PER_CHANNEL (means + logvar, each with N channels). + # Sample shape: (B, N+1, ...) where N = latent_channels (e.g., 128 means + 1 logvar = 129) + # Target shape: (B, 2*N, ...) where first N are means, last N are logvar + + if sample.shape[1] < 2: + raise ValueError(f"Invalid channel count for UNIFORM mode: expected at least 2 channels " + f"(N means + 1 logvar), got {sample.shape[1]}") + + # Extract means (first N channels) and logvar (last 1 channel) + means = sample[:, :-1, ...] # (B, N, ...) + logvar = sample[:, -1:, ...] # (B, 1, ...) + + # Repeat logvar N times to match means channels + # Use expand/repeat pattern that works for both 4D and 5D tensors + num_channels = means.shape[1] + repeat_shape = [1, num_channels] + [1] * (sample.ndim - 2) + repeated_logvar = logvar.repeat(*repeat_shape) # (B, N, ...) + + # Concatenate to create (B, 2*N, ...) format: [means, repeated_logvar] + sample = torch.cat([means, repeated_logvar], dim=1) + elif self.latent_log_var == LogVarianceType.CONSTANT: + sample = sample[:, :-1, ...] + approx_ln_0 = -30 # this is the minimal clamp value in DiagonalGaussianDistribution objects + sample = torch.cat( + [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], + dim=1, + ) + + # Split into means and logvar, then normalize means + means, _ = torch.chunk(sample, 2, dim=1) + return self.per_channel_statistics.normalize(means) + + + def tiled_encode_video( + self, + video: torch.Tensor, + tile_size: int = 512, + tile_overlap: int = 128, + ) -> torch.Tensor: + """Encode video using spatial tiling for memory efficiency. + Splits the video into overlapping spatial tiles, encodes each tile separately, + and blends the results using linear feathering in the overlap regions. + Args: + video: Input tensor of shape [B, C, F, H, W] + tile_size: Tile size in pixels (must be divisible by 32) + tile_overlap: Overlap between tiles in pixels (must be divisible by 32) + Returns: + Encoded latent tensor [B, C_latent, F_latent, H_latent, W_latent] + """ + batch, _channels, frames, height, width = video.shape + device = video.device + dtype = video.dtype + + # Validate tile parameters + if tile_size % VAE_SPATIAL_FACTOR != 0: + raise ValueError(f"tile_size must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_size}") + if tile_overlap % VAE_SPATIAL_FACTOR != 0: + raise ValueError(f"tile_overlap must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_overlap}") + if tile_overlap >= tile_size: + raise ValueError(f"tile_overlap ({tile_overlap}) must be less than tile_size ({tile_size})") + + # If video fits in a single tile, use regular encoding + if height <= tile_size and width <= tile_size: + return self.forward(video) + + # Calculate output dimensions + # VAE compresses: H -> H/32, W -> W/32, F -> 1 + (F-1)/8 + output_height = height // VAE_SPATIAL_FACTOR + output_width = width // VAE_SPATIAL_FACTOR + output_frames = 1 + (frames - 1) // VAE_TEMPORAL_FACTOR + + # Latent channels (128 for LTX-2) + # Get from a small test encode or assume 128 + latent_channels = 128 + + # Initialize output and weight tensors + output = torch.zeros( + (batch, latent_channels, output_frames, output_height, output_width), + device=device, + dtype=dtype, + ) + weights = torch.zeros( + (batch, 1, output_frames, output_height, output_width), + device=device, + dtype=dtype, + ) + + # Calculate tile positions with overlap + # Step size is tile_size - tile_overlap + step_h = tile_size - tile_overlap + step_w = tile_size - tile_overlap + + h_positions = list(range(0, max(1, height - tile_overlap), step_h)) + w_positions = list(range(0, max(1, width - tile_overlap), step_w)) + + # Ensure last tile covers the edge + if h_positions[-1] + tile_size < height: + h_positions.append(height - tile_size) + if w_positions[-1] + tile_size < width: + w_positions.append(width - tile_size) + + # Remove duplicates and sort + h_positions = sorted(set(h_positions)) + w_positions = sorted(set(w_positions)) + + # Overlap in latent space + overlap_out_h = tile_overlap // VAE_SPATIAL_FACTOR + overlap_out_w = tile_overlap // VAE_SPATIAL_FACTOR + + # Process each tile + for h_pos in h_positions: + for w_pos in w_positions: + # Calculate tile boundaries in input space + h_start = max(0, h_pos) + w_start = max(0, w_pos) + h_end = min(h_start + tile_size, height) + w_end = min(w_start + tile_size, width) + + # Ensure tile dimensions are divisible by VAE_SPATIAL_FACTOR + tile_h = ((h_end - h_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR + tile_w = ((w_end - w_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR + + if tile_h < VAE_SPATIAL_FACTOR or tile_w < VAE_SPATIAL_FACTOR: + continue + + # Adjust end positions + h_end = h_start + tile_h + w_end = w_start + tile_w + + # Extract tile + tile = video[:, :, :, h_start:h_end, w_start:w_end] + + # Encode tile + encoded_tile = self.forward(tile) + + # Get actual encoded dimensions + _, _, tile_out_frames, tile_out_height, tile_out_width = encoded_tile.shape + + # Calculate output positions + out_h_start = h_start // VAE_SPATIAL_FACTOR + out_w_start = w_start // VAE_SPATIAL_FACTOR + out_h_end = min(out_h_start + tile_out_height, output_height) + out_w_end = min(out_w_start + tile_out_width, output_width) + + # Trim encoded tile if necessary + actual_tile_h = out_h_end - out_h_start + actual_tile_w = out_w_end - out_w_start + encoded_tile = encoded_tile[:, :, :, :actual_tile_h, :actual_tile_w] + + # Create blending mask with linear feathering at edges + mask = torch.ones( + (1, 1, tile_out_frames, actual_tile_h, actual_tile_w), + device=device, + dtype=dtype, + ) + + # Apply feathering at edges (linear blend in overlap regions) + # Left edge + if h_pos > 0 and overlap_out_h > 0 and overlap_out_h < actual_tile_h: + fade_in = torch.linspace(0.0, 1.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1] + mask[:, :, :, :overlap_out_h, :] *= fade_in.view(1, 1, 1, -1, 1) + + # Right edge (bottom in height dimension) + if h_end < height and overlap_out_h > 0 and overlap_out_h < actual_tile_h: + fade_out = torch.linspace(1.0, 0.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1] + mask[:, :, :, -overlap_out_h:, :] *= fade_out.view(1, 1, 1, -1, 1) + + # Top edge (left in width dimension) + if w_pos > 0 and overlap_out_w > 0 and overlap_out_w < actual_tile_w: + fade_in = torch.linspace(0.0, 1.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1] + mask[:, :, :, :, :overlap_out_w] *= fade_in.view(1, 1, 1, 1, -1) + + # Bottom edge (right in width dimension) + if w_end < width and overlap_out_w > 0 and overlap_out_w < actual_tile_w: + fade_out = torch.linspace(1.0, 0.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1] + mask[:, :, :, :, -overlap_out_w:] *= fade_out.view(1, 1, 1, 1, -1) + + # Accumulate weighted results + output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += encoded_tile * mask + weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += mask + + # Normalize by weights (avoid division by zero) + output = output / (weights + 1e-8) + + return output + + def encode( + self, + video: torch.Tensor, + tiled=False, + tile_size_in_pixels: Optional[int] = 512, + tile_overlap_in_pixels: Optional[int] = 128, + **kwargs, + ) -> torch.Tensor: + if video.ndim == 4: + video = video.unsqueeze(0) # [C, F, H, W] -> [B, C, F, H, W] + # Choose encoding method based on tiling flag + if tiled: + latents = self.tiled_encode_video( + video=video, + tile_size=tile_size_in_pixels, + tile_overlap=tile_overlap_in_pixels, + ) + else: + # Encode video - VAE expects [B, C, F, H, W], returns [B, C, F', H', W'] + latents = self.forward(video) + return latents + + +def _make_decoder_block( + block_name: str, + block_config: dict[str, Any], + in_channels: int, + convolution_dimensions: int, + norm_layer: NormLayerType, + timestep_conditioning: bool, + norm_num_groups: int, + spatial_padding_mode: PaddingModeType, +) -> Tuple[nn.Module, int]: + out_channels = in_channels + if block_name == "res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "attn_res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + attention_head_dim=block_config["attention_head_dim"], + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + out_channels = in_channels // block_config.get("multiplier", 2) + block = ResnetBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=False, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + out_channels = in_channels // block_config.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(2, 2, 2), + residual=block_config.get("residual", False), + out_channels_reduction_factor=block_config.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown layer: {block_name}") + + return block, out_channels + + +class LTX2VideoDecoder(nn.Module): + _DEFAULT_NORM_NUM_GROUPS = 32 + """ + Variational Autoencoder Decoder. Decodes latent representation into video frames. + The decoder upsamples latents through a series of upsampling operations (inverse of encoder). + Output dimensions: F = 8x(F'-1) + 1, H = 32xH', W = 32xW' for standard LTX Video configuration. + Upsampling blocks expand dimensions by 2x in specified dimensions: + - "compress_time": temporal only + - "compress_space": spatial only (H and W) + - "compress_all": all dimensions (F, H, W) + - "res_x" / "res_x_y" / "attn_res_x": no upsampling + Causal Mode: + causal=False (standard): Symmetric padding, allows future frame dependencies. + causal=True: Causal padding, each frame depends only on past/current frames. + First frame removed after temporal upsampling in both modes. Output shape unchanged. + Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512) for both modes. + Args: + convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D). + in_channels: The number of input channels (latent channels). Default is 128. + out_channels: The number of output channels. For RGB images, this is 3. + decoder_blocks: The list of blocks to construct the decoder. Each block is a tuple of (block_name, params) + where params is either an int (num_layers) or a dict with configuration. + patch_size: Final spatial expansion factor. For standard LTX Video, use 4 for 4x spatial expansion: + H -> Hx4, W -> Wx4. Should be a power of 2. + norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + causal: Whether to use causal convolutions. For standard LTX Video, use False for symmetric padding. + When True, uses causal padding (past/current frames only). + timestep_conditioning: Whether to condition the decoder on timestep for denoising. + """ + + def __init__( + self, + convolution_dimensions: int = 3, + in_channels: int = 128, + out_channels: int = 3, + decoder_blocks: List[Tuple[str, int | dict]] = [], # noqa: B006 + patch_size: int = 4, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + causal: bool = False, + timestep_conditioning: bool = False, + decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT, + ): + super().__init__() + + # Spatiotemporal downscaling between decoded video space and VAE latents. + # According to the LTXV paper, the standard configuration downsamples + # video inputs by a factor of 8 in the temporal dimension and 32 in + # each spatial dimension (height and width). This parameter determines how + # many video frames and pixels correspond to a single latent cell. + decoder_blocks = [['res_x', { + 'num_layers': 5, + 'inject_noise': False + }], ['compress_all', { + 'residual': True, + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 5, + 'inject_noise': False + }], ['compress_all', { + 'residual': True, + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 5, + 'inject_noise': False + }], ['compress_all', { + 'residual': True, + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 5, + 'inject_noise': False + }]] + self.video_downscale_factors = SpatioTemporalScaleFactors( + time=8, + width=32, + height=32, + ) + + self.patch_size = patch_size + out_channels = out_channels * patch_size**2 + self.causal = causal + self.timestep_conditioning = timestep_conditioning + self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS + + # Per-channel statistics for denormalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels) + + # Noise and timestep parameters for decoder conditioning + self.decode_noise_scale = 0.025 + self.decode_timestep = 0.05 + + # Compute initial feature_channels by going through blocks in reverse + # This determines the channel width at the start of the decoder + feature_channels = in_channels + for block_name, block_params in list(reversed(decoder_blocks)): + block_config = block_params if isinstance(block_params, dict) else {} + if block_name == "res_x_y": + feature_channels = feature_channels * block_config.get("multiplier", 2) + if block_name == "compress_all": + feature_channels = feature_channels * block_config.get("multiplier", 1) + + self.conv_in = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=feature_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + self.up_blocks = nn.ModuleList([]) + + for block_name, block_params in list(reversed(decoder_blocks)): + # Convert int to dict format for uniform handling + block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params + + block, feature_channels = _make_decoder_block( + block_name=block_name, + block_config=block_config, + in_channels=feature_channels, + convolution_dimensions=convolution_dimensions, + norm_layer=norm_layer, + timestep_conditioning=timestep_conditioning, + norm_num_groups=self._norm_num_groups, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + self.up_blocks.append(block) + + if norm_layer == NormLayerType.GROUP_NORM: + self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims=convolution_dimensions, + in_channels=feature_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + causal=True, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0)) + self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim=feature_channels * 2, + size_emb_dim=0) + self.last_scale_shift_table = nn.Parameter(torch.empty(2, feature_channels)) + + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + r""" + Decode latent representation into video frames. + Args: + sample: Latent tensor (B, 128, F', H', W'). + timestep: Timestep for conditioning (if timestep_conditioning=True). Uses default 0.05 if None. + generator: Random generator for deterministic noise injection (if inject_noise=True in blocks). + Returns: + Decoded video (B, 3, F, H, W) where F = 8x(F'-1) + 1, H = 32xH', W = 32xW'. + Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512). + Note: First frame is removed after temporal upsampling regardless of causal mode. + When causal=False, allows future frame dependencies in convolutions but maintains same output shape. + """ + batch_size = sample.shape[0] + + # Add noise if timestep conditioning is enabled + if self.timestep_conditioning: + noise = (torch.randn( + sample.size(), + generator=generator, + dtype=sample.dtype, + device=sample.device, + ) * self.decode_noise_scale) + + sample = noise + (1.0 - self.decode_noise_scale) * sample + + # Denormalize latents + sample = self.per_channel_statistics.un_normalize(sample) + + # Use default decode_timestep if timestep not provided + if timestep is None and self.timestep_conditioning: + timestep = torch.full((batch_size,), self.decode_timestep, device=sample.device, dtype=sample.dtype) + + sample = self.conv_in(sample, causal=self.causal) + + scaled_timestep = None + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + scaled_timestep = timestep * self.timestep_scale_multiplier.to(sample) + + for up_block in self.up_blocks: + if isinstance(up_block, UNetMidBlock3D): + block_kwargs = { + "causal": self.causal, + "timestep": scaled_timestep if self.timestep_conditioning else None, + "generator": generator, + } + sample = up_block(sample, **block_kwargs) + elif isinstance(up_block, ResnetBlock3D): + sample = up_block(sample, causal=self.causal, generator=generator) + else: + sample = up_block(sample, causal=self.causal) + + sample = self.conv_norm_out(sample) + + if self.timestep_conditioning: + embedded_timestep = self.last_time_embedder( + timestep=scaled_timestep.flatten(), + hidden_dtype=sample.dtype, + ) + embedded_timestep = embedded_timestep.view(batch_size, embedded_timestep.shape[-1], 1, 1, 1) + ada_values = self.last_scale_shift_table[None, ..., None, None, None].to( + device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape( + batch_size, + 2, + -1, + embedded_timestep.shape[-3], + embedded_timestep.shape[-2], + embedded_timestep.shape[-1], + ) + shift, scale = ada_values.unbind(dim=1) + sample = sample * (1 + scale) + shift + + sample = self.conv_act(sample) + sample = self.conv_out(sample, causal=self.causal) + + # Final spatial expansion: reverse the initial patchify from encoder + # Moves pixels from channels back to spatial dimensions + # Example: (B, 48, F, 128, 128) -> (B, 3, F, 512, 512) with patch_size=4 + sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + + return sample + + def _prepare_tiles( + self, + latent: torch.Tensor, + tiling_config: TilingConfig | None = None, + ) -> List[Tile]: + splitters = [DEFAULT_SPLIT_OPERATION] * len(latent.shape) + mappers = [DEFAULT_MAPPING_OPERATION] * len(latent.shape) + if tiling_config is not None and tiling_config.spatial_config is not None: + cfg = tiling_config.spatial_config + long_side = max(latent.shape[3], latent.shape[4]) + + def enable_on_axis(axis_idx: int, factor: int) -> None: + size = cfg.tile_size_in_pixels // factor + overlap = cfg.tile_overlap_in_pixels // factor + axis_length = latent.shape[axis_idx] + lower_threshold = max(2, overlap + 1) + tile_size = max(lower_threshold, round(size * axis_length / long_side)) + splitters[axis_idx] = split_in_spatial(tile_size, overlap) + mappers[axis_idx] = to_mapping_operation(map_spatial_slice, factor) + + enable_on_axis(3, self.video_downscale_factors.height) + enable_on_axis(4, self.video_downscale_factors.width) + + if tiling_config is not None and tiling_config.temporal_config is not None: + cfg = tiling_config.temporal_config + tile_size = cfg.tile_size_in_frames // self.video_downscale_factors.time + overlap = cfg.tile_overlap_in_frames // self.video_downscale_factors.time + splitters[2] = split_in_temporal(tile_size, overlap) + mappers[2] = to_mapping_operation(map_temporal_slice, self.video_downscale_factors.time) + + return create_tiles(latent.shape, splitters, mappers) + + def tiled_decode( + self, + latent: torch.Tensor, + tiling_config: TilingConfig | None = None, + timestep: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ) -> Iterator[torch.Tensor]: + """ + Decode a latent tensor into video frames using tiled processing. + Splits the latent tensor into tiles, decodes each tile individually, + and yields video chunks as they become available. + Args: + latent: Input latent tensor (B, C, F', H', W'). + tiling_config: Tiling configuration for the latent tensor. + timestep: Optional timestep for decoder conditioning. + generator: Optional random generator for deterministic decoding. + Yields: + Video chunks (B, C, T, H, W) by temporal slices; + """ + + # Calculate full video shape from latent shape to get spatial dimensions + full_video_shape = VideoLatentShape.from_torch_shape(latent.shape).upscale(self.video_downscale_factors) + tiles = self._prepare_tiles(latent, tiling_config) + + temporal_groups = self._group_tiles_by_temporal_slice(tiles) + + # State for temporal overlap handling + previous_chunk = None + previous_weights = None + previous_temporal_slice = None + + for temporal_group_tiles in temporal_groups: + curr_temporal_slice = temporal_group_tiles[0].out_coords[2] + + # Calculate the shape of the temporal buffer for this group of tiles. + # The temporal length depends on whether this is the first tile (starts at 0) or not. + # - First tile: (frames - 1) * scale + 1 + # - Subsequent tiles: frames * scale + # This logic is handled by TemporalAxisMapping and reflected in out_coords. + temporal_tile_buffer_shape = full_video_shape._replace(frames=curr_temporal_slice.stop - + curr_temporal_slice.start,) + + buffer = torch.zeros( + temporal_tile_buffer_shape.to_torch_shape(), + device=latent.device, + dtype=latent.dtype, + ) + + curr_weights = self._accumulate_temporal_group_into_buffer( + group_tiles=temporal_group_tiles, + buffer=buffer, + latent=latent, + timestep=timestep, + generator=generator, + ) + + # Blend with previous temporal chunk if it exists + if previous_chunk is not None: + # Check if current temporal slice overlaps with previous temporal slice + if previous_temporal_slice.stop > curr_temporal_slice.start: + overlap_len = previous_temporal_slice.stop - curr_temporal_slice.start + temporal_overlap_slice = slice(curr_temporal_slice.start - previous_temporal_slice.start, None) + + # The overlap is already masked before it reaches this step. Each tile is accumulated into buffer + # with its trapezoidal mask, and curr_weights accumulates the same mask. In the overlap blend we add + # the masked values (buffer[...]) and the corresponding weights (curr_weights[...]) into the + # previous buffers, then later normalize by weights. + previous_chunk[:, :, temporal_overlap_slice, :, :] += buffer[:, :, slice(0, overlap_len), :, :] + previous_weights[:, :, temporal_overlap_slice, :, :] += curr_weights[:, :, + slice(0, overlap_len), :, :] + + buffer[:, :, slice(0, overlap_len), :, :] = previous_chunk[:, :, temporal_overlap_slice, :, :] + curr_weights[:, :, slice(0, overlap_len), :, :] = previous_weights[:, :, + temporal_overlap_slice, :, :] + + # Yield the non-overlapping part of the previous chunk + previous_weights = previous_weights.clamp(min=1e-8) + yield_len = curr_temporal_slice.start - previous_temporal_slice.start + yield (previous_chunk / previous_weights)[:, :, :yield_len, :, :] + + # Update state for next iteration + previous_chunk = buffer + previous_weights = curr_weights + previous_temporal_slice = curr_temporal_slice + + # Yield any remaining chunk + if previous_chunk is not None: + previous_weights = previous_weights.clamp(min=1e-8) + yield previous_chunk / previous_weights + + def _group_tiles_by_temporal_slice(self, tiles: List[Tile]) -> List[List[Tile]]: + """Group tiles by their temporal output slice.""" + if not tiles: + return [] + + groups = [] + current_slice = tiles[0].out_coords[2] + current_group = [] + + for tile in tiles: + tile_slice = tile.out_coords[2] + if tile_slice == current_slice: + current_group.append(tile) + else: + groups.append(current_group) + current_slice = tile_slice + current_group = [tile] + + # Add the final group + if current_group: + groups.append(current_group) + + return groups + + def _accumulate_temporal_group_into_buffer( + self, + group_tiles: List[Tile], + buffer: torch.Tensor, + latent: torch.Tensor, + timestep: torch.Tensor | None, + generator: torch.Generator | None, + ) -> torch.Tensor: + """ + Decode and accumulate all tiles of a temporal group into a local buffer. + The buffer is local to the group and always starts at time 0; temporal coordinates + are rebased by subtracting temporal_slice.start. + """ + temporal_slice = group_tiles[0].out_coords[2] + + weights = torch.zeros_like(buffer) + + for tile in group_tiles: + decoded_tile = self.forward(latent[tile.in_coords], timestep, generator) + mask = tile.blend_mask.to(device=buffer.device, dtype=buffer.dtype) + temporal_offset = tile.out_coords[2].start - temporal_slice.start + # Use the tile's output coordinate length, not the decoded tile's length, + # as the decoder may produce a different number of frames than expected + expected_temporal_len = tile.out_coords[2].stop - tile.out_coords[2].start + decoded_temporal_len = decoded_tile.shape[2] + + # Ensure we don't exceed the buffer or decoded tile bounds + actual_temporal_len = min(expected_temporal_len, decoded_temporal_len, buffer.shape[2] - temporal_offset) + + chunk_coords = ( + slice(None), # batch + slice(None), # channels + slice(temporal_offset, temporal_offset + actual_temporal_len), + tile.out_coords[3], # height + tile.out_coords[4], # width + ) + + # Slice decoded_tile and mask to match the actual length we're writing + decoded_slice = decoded_tile[:, :, :actual_temporal_len, :, :] + mask_slice = mask[:, :, :actual_temporal_len, :, :] if mask.shape[2] > 1 else mask + + buffer[chunk_coords] += decoded_slice * mask_slice + weights[chunk_coords] += mask_slice + + return weights + + def decode( + self, + latent: torch.Tensor, + tiled=False, + tile_size_in_pixels: Optional[int] = 512, + tile_overlap_in_pixels: Optional[int] = 128, + tile_size_in_frames: Optional[int] = 128, + tile_overlap_in_frames: Optional[int] = 24, + ) -> torch.Tensor: + if tiled: + tiling_config = TilingConfig( + spatial_config=SpatialTilingConfig( + tile_size_in_pixels=tile_size_in_pixels, + tile_overlap_in_pixels=tile_overlap_in_pixels, + ), + temporal_config=TemporalTilingConfig( + tile_size_in_frames=tile_size_in_frames, + tile_overlap_in_frames=tile_overlap_in_frames, + ), + ) + tiles = self.tiled_decode(latent, tiling_config) + return torch.cat(list(tiles), dim=2) + else: + return self.forward(latent) + +def decode_video( + latent: torch.Tensor, + video_decoder: LTX2VideoDecoder, + tiling_config: TilingConfig | None = None, + generator: torch.Generator | None = None, +) -> Iterator[torch.Tensor]: + """ + Decode a video latent tensor with the given decoder. + Args: + latent: Tensor [c, f, h, w] + video_decoder: Decoder module. + tiling_config: Optional tiling settings. + generator: Optional random generator for deterministic decoding. + Yields: + Decoded chunk [f, h, w, c], uint8 in [0, 255]. + """ + + def convert_to_uint8(frames: torch.Tensor) -> torch.Tensor: + frames = (((frames + 1.0) / 2.0).clamp(0.0, 1.0) * 255.0).to(torch.uint8) + frames = rearrange(frames[0], "c f h w -> f h w c") + return frames + + if tiling_config is not None: + for frames in video_decoder.tiled_decode(latent, tiling_config, generator=generator): + return convert_to_uint8(frames) + else: + decoded_video = video_decoder(latent, generator=generator) + return convert_to_uint8(decoded_video) + + +def get_video_chunks_number(num_frames: int, tiling_config: TilingConfig | None = None) -> int: + """ + Get the number of video chunks for a given number of frames and tiling configuration. + Args: + num_frames: Number of frames in the video. + tiling_config: Tiling configuration. + Returns: + Number of video chunks. + """ + if not tiling_config or not tiling_config.temporal_config: + return 1 + cfg = tiling_config.temporal_config + frame_stride = cfg.tile_size_in_frames - cfg.tile_overlap_in_frames + return (num_frames - 1 + frame_stride - 1) // frame_stride + + +def split_in_spatial(size: int, overlap: int) -> SplitOperation: + + def split(dimension_size: int) -> DimensionIntervals: + if dimension_size <= size: + return DEFAULT_SPLIT_OPERATION(dimension_size) + amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap) + starts = [i * (size - overlap) for i in range(amount)] + ends = [start + size for start in starts] + ends[-1] = dimension_size + left_ramps = [0] + [overlap] * (amount - 1) + right_ramps = [overlap] * (amount - 1) + [0] + return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps) + + return split + + +def split_in_temporal(size: int, overlap: int) -> SplitOperation: + non_causal_split = split_in_spatial(size, overlap) + + def split(dimension_size: int) -> DimensionIntervals: + if dimension_size <= size: + return DEFAULT_SPLIT_OPERATION(dimension_size) + intervals = non_causal_split(dimension_size) + starts = intervals.starts + starts[1:] = [s - 1 for s in starts[1:]] + left_ramps = intervals.left_ramps + left_ramps[1:] = [r + 1 for r in left_ramps[1:]] + return replace(intervals, starts=starts, left_ramps=left_ramps) + + return split + + +def to_mapping_operation( + map_func: Callable[[int, int, int, int, int], Tuple[slice, torch.Tensor]], + scale: int, +) -> MappingOperation: + + def map_op(intervals: DimensionIntervals) -> tuple[list[slice], list[torch.Tensor | None]]: + output_slices: list[slice] = [] + masks_1d: list[torch.Tensor | None] = [] + number_of_slices = len(intervals.starts) + for i in range(number_of_slices): + start = intervals.starts[i] + end = intervals.ends[i] + left_ramp = intervals.left_ramps[i] + right_ramp = intervals.right_ramps[i] + output_slice, mask_1d = map_func(start, end, left_ramp, right_ramp, scale) + output_slices.append(output_slice) + masks_1d.append(mask_1d) + return output_slices, masks_1d + + return map_op + + +def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, torch.Tensor]: + start = begin * scale + stop = 1 + (end - 1) * scale + left_ramp = 1 + (left_ramp - 1) * scale + right_ramp = right_ramp * scale + + return slice(start, stop), compute_trapezoidal_mask_1d(stop - start, left_ramp, right_ramp, True) + + +def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, torch.Tensor]: + start = begin * scale + stop = end * scale + left_ramp = left_ramp * scale + right_ramp = right_ramp * scale + + return slice(start, stop), compute_trapezoidal_mask_1d(stop - start, left_ramp, right_ramp, False) diff --git a/diffsynth/models/model_loader.py b/diffsynth/models/model_loader.py index 16d72dd..6a58c89 100644 --- a/diffsynth/models/model_loader.py +++ b/diffsynth/models/model_loader.py @@ -29,7 +29,7 @@ class ModelPool: module_map = None return module_map - def load_model_file(self, config, path, vram_config, vram_limit=None): + def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None): model_class = self.import_model_class(config["model_class"]) model_config = config.get("extra_kwargs", {}) if "state_dict_converter" in config: @@ -43,6 +43,7 @@ class ModelPool: state_dict_converter, use_disk_map=True, vram_config=vram_config, module_map=module_map, vram_limit=vram_limit, + state_dict=state_dict, ) return model @@ -59,7 +60,7 @@ class ModelPool: } return vram_config - def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False): + def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None): print(f"Loading models from: {json.dumps(path, indent=4)}") if vram_config is None: vram_config = self.default_vram_config() @@ -67,7 +68,7 @@ class ModelPool: loaded = False for config in MODEL_CONFIGS: if config["model_hash"] == model_hash: - model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit) + model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict) if clear_parameters: self.clear_parameters(model) self.model.append(model) model_name = config["model_name"] diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 7386223..d957717 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -5,6 +5,7 @@ import math from typing import Tuple, Optional from einops import rearrange from .wan_video_camera_controller import SimpleAdapter +from ..core.gradient import gradient_checkpoint_forward try: import flash_attn_interface @@ -93,7 +94,7 @@ def rope_apply(x, freqs, num_heads): x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) x_out = torch.view_as_complex(x.to(torch.float64).reshape( x.shape[0], x.shape[1], x.shape[2], -1, 2)) - freqs = freqs.to(torch.complex64) if freqs.device == "npu" else freqs + freqs = freqs.to(torch.complex64) if freqs.device.type == "npu" else freqs x_out = torch.view_as_real(x_out * freqs).flatten(2) return x_out.to(x.dtype) @@ -379,27 +380,15 @@ class WanModel(torch.nn.Module): self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - return custom_forward for block in self.blocks: - if self.training and use_gradient_checkpointing: - if use_gradient_checkpointing_offload: - with torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, freqs, - use_reentrant=False, - ) - else: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, freqs, - use_reentrant=False, - ) + if self.training: + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, context, t_mod, freqs + ) else: x = block(x, context, t_mod, freqs) diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index 8fbed8c..f4d1abe 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -4,6 +4,7 @@ import torch.nn as nn import torch.nn.functional as F from typing import Tuple from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d +from ..core.gradient import gradient_checkpoint_forward def torch_dfs(model: nn.Module, parent_name='root'): @@ -545,46 +546,19 @@ class WanS2VModel(torch.nn.Module): t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2) - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - return custom_forward - for block_id, block in enumerate(self.blocks): - if use_gradient_checkpointing_offload: - with torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, - context, - t_mod, - seq_len_x, - pre_compute_freqs[0], - use_reentrant=False, - ) - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), - x, - use_reentrant=False, - ) - elif use_gradient_checkpointing: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, - context, - t_mod, - seq_len_x, - pre_compute_freqs[0], - use_reentrant=False, - ) - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), - x, - use_reentrant=False, - ) - else: - x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) - x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x) + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, context, t_mod, seq_len_x, pre_compute_freqs[0] + ) + x = gradient_checkpoint_forward( + lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x), + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x + ) x = x[:, :seq_len_x] x = self.head(x, t[:-1]) diff --git a/diffsynth/models/wan_video_vace.py b/diffsynth/models/wan_video_vace.py index f3367f7..0e13183 100644 --- a/diffsynth/models/wan_video_vace.py +++ b/diffsynth/models/wan_video_vace.py @@ -1,6 +1,6 @@ import torch from .wan_video_dit import DiTBlock - +from ..core.gradient import gradient_checkpoint_forward class VaceWanAttentionBlock(DiTBlock): def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0): @@ -62,26 +62,13 @@ class VaceWanModel(torch.nn.Module): dim=1) for u in c ]) - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - return custom_forward - for block in self.vace_blocks: - if use_gradient_checkpointing_offload: - with torch.autograd.graph.save_on_cpu(): - c = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - c, x, context, t_mod, freqs, - use_reentrant=False, - ) - elif use_gradient_checkpointing: - c = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - c, x, context, t_mod, freqs, - use_reentrant=False, - ) - else: - c = block(c, x, context, t_mod, freqs) + c = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + c, x, context, t_mod, freqs + ) + hints = torch.unbind(c)[:-1] return hints diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index d24e29d..3c2181a 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -171,7 +171,7 @@ class Resample(nn.Module): torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) feat_cache[idx] = cache_x feat_idx[0] += 1 - return x + return x, feat_cache, feat_idx def init_weight(self, conv): conv_weight = conv.weight @@ -298,7 +298,7 @@ class ResidualBlock(nn.Module): feat_idx[0] += 1 else: x = layer(x) - return x + h + return x + h, feat_cache, feat_idx class AttentionBlock(nn.Module): @@ -471,7 +471,7 @@ class Down_ResidualBlock(nn.Module): for module in self.downsamples: x = module(x, feat_cache, feat_idx) - return x + self.avg_shortcut(x_copy) + return x + self.avg_shortcut(x_copy), feat_cache, feat_idx class Up_ResidualBlock(nn.Module): @@ -511,7 +511,7 @@ class Up_ResidualBlock(nn.Module): x_shortcut = self.avg_shortcut(x, first_chunk) return x_main + x_shortcut else: - return x_main + return x_main, feat_cache, feat_idx class Encoder3d(nn.Module): @@ -586,14 +586,14 @@ class Encoder3d(nn.Module): ## downsamples for layer in self.downsamples: if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) else: x = layer(x) ## middle for layer in self.middle: if check_is_instance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) else: x = layer(x) @@ -614,7 +614,7 @@ class Encoder3d(nn.Module): feat_idx[0] += 1 else: x = layer(x) - return x + return x, feat_cache, feat_idx class Encoder3d_38(nn.Module): @@ -698,14 +698,14 @@ class Encoder3d_38(nn.Module): ## downsamples for layer in self.downsamples: if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) else: x = layer(x) ## middle for layer in self.middle: if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) else: x = layer(x) @@ -730,7 +730,7 @@ class Encoder3d_38(nn.Module): else: x = layer(x) - return x + return x, feat_cache, feat_idx class Decoder3d(nn.Module): @@ -807,14 +807,14 @@ class Decoder3d(nn.Module): ## middle for layer in self.middle: if check_is_instance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) else: x = layer(x) ## upsamples for layer in self.upsamples: if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) else: x = layer(x) @@ -835,7 +835,7 @@ class Decoder3d(nn.Module): feat_idx[0] += 1 else: x = layer(x) - return x + return x, feat_cache, feat_idx @@ -906,14 +906,14 @@ class Decoder3d_38(nn.Module): for layer in self.middle: if check_is_instance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) else: x = layer(x) ## upsamples for layer in self.upsamples: if feat_cache is not None: - x = layer(x, feat_cache, feat_idx, first_chunk) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx, first_chunk) else: x = layer(x) @@ -937,7 +937,7 @@ class Decoder3d_38(nn.Module): feat_idx[0] += 1 else: x = layer(x) - return x + return x, feat_cache, feat_idx def count_conv3d(model): @@ -990,11 +990,11 @@ class VideoVAE_(nn.Module): for i in range(iter_): self._enc_conv_idx = [0] if i == 0: - out = self.encoder(x[:, :, :1, :, :], + out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) else: - out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) out = torch.cat([out, out_], 2) @@ -1023,11 +1023,11 @@ class VideoVAE_(nn.Module): for i in range(iter_): self._conv_idx = [0] if i == 0: - out = self.decoder(x[:, :, i:i + 1, :, :], + out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) else: - out_ = self.decoder(x[:, :, i:i + 1, :, :], + out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) out = torch.cat([out, out_], 2) # may add tensor offload @@ -1303,11 +1303,11 @@ class VideoVAE38_(VideoVAE_): for i in range(iter_): self._enc_conv_idx = [0] if i == 0: - out = self.encoder(x[:, :, :1, :, :], + out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) else: - out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) out = torch.cat([out, out_], 2) @@ -1337,12 +1337,12 @@ class VideoVAE38_(VideoVAE_): for i in range(iter_): self._conv_idx = [0] if i == 0: - out = self.decoder(x[:, :, i:i + 1, :, :], + out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True) else: - out_ = self.decoder(x[:, :, i:i + 1, :, :], + out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) out = torch.cat([out, out_], 2) diff --git a/diffsynth/models/z_image_text_encoder.py b/diffsynth/models/z_image_text_encoder.py index 4d6271d..6f3e6c0 100644 --- a/diffsynth/models/z_image_text_encoder.py +++ b/diffsynth/models/z_image_text_encoder.py @@ -6,6 +6,36 @@ class ZImageTextEncoder(torch.nn.Module): def __init__(self, model_size="4B"): super().__init__() config_dict = { + "0.6B": Qwen3Config(**{ + "architectures": [ + "Qwen3ForCausalLM" + ], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 3072, + "max_position_embeddings": 40960, + "max_window_layers": 28, + "model_type": "qwen3", + "num_attention_heads": 16, + "num_hidden_layers": 28, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 1000000, + "sliding_window": None, + "tie_word_embeddings": True, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 151936 + }), "4B": Qwen3Config(**{ "architectures": [ "Qwen3ForCausalLM" diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index d5dc35b..bea6b7c 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -348,13 +348,12 @@ class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit): attention_mask = torch.cat(all_attention_masks, dim=0).to(device) # Forward pass through the model - with torch.inference_mode(): - output = text_encoder( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=True, - use_cache=False, - ) + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) # Only use outputs from intermediate layers and stack them out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1) diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py new file mode 100644 index 0000000..9ed48aa --- /dev/null +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -0,0 +1,550 @@ +import torch, types +import numpy as np +from PIL import Image +from einops import repeat +from typing import Optional, Union +from einops import rearrange +import numpy as np +from PIL import Image +from tqdm import tqdm +from typing import Optional +from transformers import AutoImageProcessor, Gemma3Processor + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit + +from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer +from ..models.ltx2_dit import LTXModel +from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier +from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier +from ..models.ltx2_upsampler import LTX2LatentUpsampler +from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS +from ..utils.data.media_io_ltx2 import ltx2_preprocess + + +class LTX2AudioVideoPipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, + torch_dtype=torch_dtype, + height_division_factor=32, + width_division_factor=32, + time_division_factor=8, + time_division_remainder=1, + ) + self.scheduler = FlowMatchScheduler("LTX-2") + self.text_encoder: LTX2TextEncoder = None + self.tokenizer: LTXVGemmaTokenizer = None + self.processor: Gemma3Processor = None + self.text_encoder_post_modules: LTX2TextEncoderPostModules = None + self.dit: LTXModel = None + self.video_vae_encoder: LTX2VideoEncoder = None + self.video_vae_decoder: LTX2VideoDecoder = None + self.audio_vae_encoder: LTX2AudioEncoder = None + self.audio_vae_decoder: LTX2AudioDecoder = None + self.audio_vocoder: LTX2Vocoder = None + self.upsampler: LTX2LatentUpsampler = None + + self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1) + self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1) + + self.in_iteration_models = ("dit",) + self.units = [ + LTX2AudioVideoUnit_PipelineChecker(), + LTX2AudioVideoUnit_ShapeChecker(), + LTX2AudioVideoUnit_PromptEmbedder(), + LTX2AudioVideoUnit_NoiseInitializer(), + LTX2AudioVideoUnit_InputVideoEmbedder(), + LTX2AudioVideoUnit_InputImagesEmbedder(), + ] + self.model_fn = model_fn_ltx2 + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config: Optional[ModelConfig] = None, + vram_limit: float = None, + ): + # Initialize pipeline + pipe = LTX2AudioVideoPipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("ltx2_text_encoder") + tokenizer_config.download_if_necessary() + pipe.tokenizer = LTXVGemmaTokenizer(tokenizer_path=tokenizer_config.path) + image_processor = AutoImageProcessor.from_pretrained(tokenizer_config.path, local_files_only=True) + pipe.processor = Gemma3Processor(image_processor=image_processor, tokenizer=pipe.tokenizer.tokenizer) + + pipe.text_encoder_post_modules = model_pool.fetch_model("ltx2_text_encoder_post_modules") + pipe.dit = model_pool.fetch_model("ltx2_dit") + pipe.video_vae_encoder = model_pool.fetch_model("ltx2_video_vae_encoder") + pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder") + pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder") + pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder") + pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler") + + # Stage 2 + if stage2_lora_config is not None: + stage2_lora_config.download_if_necessary() + pipe.stage2_lora_path = stage2_lora_config.path + # Optional, currently not used + # pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder") + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm): + if inputs_shared["use_two_stage_pipeline"]: + latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"]) + self.load_models_to_device('upsampler',) + latent = self.upsampler(latent) + latent = self.video_vae_encoder.per_channel_statistics.normalize(latent) + self.scheduler.set_timesteps(special_case="stage2") + inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")}) + denoise_mask_video = 1.0 + if inputs_shared.get("input_images", None) is not None: + latent, denoise_mask_video, initial_latents = self.apply_input_images_to_latents( + latent, inputs_shared.pop("input_latents"), inputs_shared["input_images_indexes"], + inputs_shared["input_images_strength"], latent.clone()) + inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video}) + inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[ + "video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latent + inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + ( + 1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"] + + self.load_models_to_device(self.in_iteration_models) + if not inputs_shared["use_distilled_pipeline"]: + self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn( + self.model_fn, 1.0, inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, + noise_pred=noise_pred_video, inpaint_mask=inputs_shared.get("denoise_mask_video", None), + input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared) + inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, + noise_pred=noise_pred_audio, **inputs_shared) + return inputs_shared + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: Optional[str] = "", + # Image-to-video + denoising_strength: float = 1.0, + input_images: Optional[list[Image.Image]] = None, + input_images_indexes: Optional[list[int]] = None, + input_images_strength: Optional[float] = 1.0, + # Randomness + seed: Optional[int] = None, + rand_device: Optional[str] = "cpu", + # Shape + height: Optional[int] = 512, + width: Optional[int] = 768, + num_frames=121, + # Classifier-free guidance + cfg_scale: Optional[float] = 3.0, + cfg_merge: Optional[bool] = False, + # Scheduler + num_inference_steps: Optional[int] = 40, + # VAE tiling + tiled: Optional[bool] = True, + tile_size_in_pixels: Optional[int] = 512, + tile_overlap_in_pixels: Optional[int] = 128, + tile_size_in_frames: Optional[int] = 128, + tile_overlap_in_frames: Optional[int] = 24, + # Special Pipelines + use_two_stage_pipeline: Optional[bool] = False, + use_distilled_pipeline: Optional[bool] = False, + # progress_bar + progress_bar_cmd=tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, + special_case="ditilled_stage1" if use_distilled_pipeline else None) + # Inputs + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength, + "seed": seed, "rand_device": rand_device, + "height": height, "width": width, "num_frames": num_frames, + "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, + "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels, + "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames, + "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, + "video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise Stage 1 + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, + inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared) + inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, + noise_pred=noise_pred_audio, **inputs_shared) + + # Denoise Stage 2 + inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd) + + # Decode + self.load_models_to_device(['video_vae_decoder']) + video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels, + tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames) + video = self.vae_output_to_video(video) + self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder']) + decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"]) + decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float() + return video, decoded_audio + + def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None, num_frames=121): + b, _, f, h, w = latents.shape + denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) + initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents + for idx, input_latent in zip(input_indexes, input_latents): + idx = min(max(1 + (idx-1) // 8, 0), f - 1) + input_latent = input_latent.to(dtype=latents.dtype, device=latents.device) + initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent + denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength + latents = latents * denoise_mask + initial_latents * (1.0 - denoise_mask) + return latents, denoise_mask, initial_latents + + +class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("use_distilled_pipeline", "use_two_stage_pipeline"), + output_params=("use_two_stage_pipeline", "cfg_scale") + ) + + def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("use_distilled_pipeline", False): + inputs_shared["use_two_stage_pipeline"] = True + inputs_shared["cfg_scale"] = 1.0 + print(f"Distilled pipeline requested, setting use_two_stage_pipeline to True, disable CFG by setting cfg_scale to 1.0.") + if inputs_shared.get("use_two_stage_pipeline", False): + # distill pipeline also uses two-stage, but it does not needs lora + if not inputs_shared.get("use_distilled_pipeline", False): + if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None): + raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.") + if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None): + raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.") + return inputs_shared, inputs_posi, inputs_nega + + +class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit): + """ + For two-stage pipelines, the resolution must be divisible by 64. + For one-stage pipelines, the resolution must be divisible by 32. + """ + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames"), + output_params=("height", "width", "num_frames"), + ) + + def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False): + if use_two_stage_pipeline: + self.width_division_factor = 64 + self.height_division_factor = 64 + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + if use_two_stage_pipeline: + self.width_division_factor = 32 + self.height_division_factor = 32 + return {"height": height, "width": width, "num_frames": num_frames} + + +class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit): + + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("video_context", "audio_context"), + onload_model_names=("text_encoder", "text_encoder_post_modules"), + ) + + def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return (attention_mask - 1).to(dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max + + def _run_connectors(self, pipe, encoded_input: torch.Tensor, + attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype) + + encoded, encoded_connector_attention_mask = pipe.text_encoder_post_modules.embeddings_connector( + encoded_input, + connector_attention_mask, + ) + + # restore the mask values to int64 + attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64) + attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1]) + encoded = encoded * attention_mask + + encoded_for_audio, _ = pipe.text_encoder_post_modules.audio_embeddings_connector( + encoded_input, connector_attention_mask) + + return encoded, encoded_for_audio, attention_mask.squeeze(-1) + + def _norm_and_concat_padded_batch( + self, + encoded_text: torch.Tensor, + sequence_lengths: torch.Tensor, + padding_side: str = "right", + ) -> torch.Tensor: + """Normalize and flatten multi-layer hidden states, respecting padding. + Performs per-batch, per-layer normalization using masked mean and range, + then concatenates across the layer dimension. + Args: + encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers]. + sequence_lengths: Number of valid (non-padded) tokens per batch item. + padding_side: Whether padding is on "left" or "right". + Returns: + Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers], + with padded positions zeroed out. + """ + b, t, d, l = encoded_text.shape # noqa: E741 + device = encoded_text.device + # Build mask: [B, T, 1, 1] + token_indices = torch.arange(t, device=device)[None, :] # [1, T] + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [B, T] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = t - sequence_lengths[:, None] # [B, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = rearrange(mask, "b t -> b t 1 1") + eps = 1e-6 + # Compute masked mean: [B, 1, 1, L] + masked = encoded_text.masked_fill(~mask, 0.0) + denom = (sequence_lengths * d).view(b, 1, 1, 1) + mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps) + # Compute masked min/max: [B, 1, 1, L] + x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + range_ = x_max - x_min + # Normalize only the valid tokens + normed = 8 * (encoded_text - mean) / (range_ + eps) + # concat to be [Batch, T, D * L] - this preserves the original structure + normed = normed.reshape(b, t, -1) # [B, T, D * L] + # Apply mask to preserve original padding (set padded positions to 0) + mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l) + normed = normed.masked_fill(~mask_flattened, 0.0) + + return normed + + def _run_feature_extractor(self, + pipe, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + padding_side: str = "right") -> torch.Tensor: + encoded_text_features = torch.stack(hidden_states, dim=-1) + encoded_text_features_dtype = encoded_text_features.dtype + sequence_lengths = attention_mask.sum(dim=-1) + normed_concated_encoded_text_features = self._norm_and_concat_padded_batch(encoded_text_features, + sequence_lengths, + padding_side=padding_side) + + return pipe.text_encoder_post_modules.feature_extractor_linear( + normed_concated_encoded_text_features.to(encoded_text_features_dtype)) + + def _preprocess_text( + self, + pipe, + text: str, + padding_side: str = "left", + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + """ + Encode a given string into feature tensors suitable for downstream tasks. + Args: + text (str): Input string to encode. + Returns: + tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask. + """ + token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"] + input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device) + attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device) + outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) + projected = self._run_feature_extractor(pipe, + hidden_states=outputs.hidden_states, + attention_mask=attention_mask, + padding_side=padding_side) + return projected, attention_mask + + def encode_prompt(self, pipe, text, padding_side="left"): + encoded_inputs, attention_mask = self._preprocess_text(pipe, text, padding_side) + video_encoding, audio_encoding, attention_mask = self._run_connectors(pipe, encoded_inputs, attention_mask) + return video_encoding, audio_encoding, attention_mask + + def process(self, pipe: LTX2AudioVideoPipeline, prompt: str): + pipe.load_models_to_device(self.onload_model_names) + video_context, audio_context, _ = self.encode_prompt(pipe, prompt) + return {"video_context": video_context, "audio_context": audio_context} + + +class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "seed", "rand_device", "use_two_stage_pipeline"), + output_params=("video_noise", "audio_noise",), + ) + + def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0): + video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate) + video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels) + video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device) + + latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device) + video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float() + video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate + video_positions = video_positions.to(pipe.torch_dtype) + + audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape) + audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device) + audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device) + return { + "video_noise": video_noise, + "audio_noise": audio_noise, + "video_positions": video_positions, + "audio_positions": audio_positions, + "video_latent_shape": video_latent_shape, + "audio_latent_shape": audio_latent_shape + } + + def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0, use_two_stage_pipeline=False): + if use_two_stage_pipeline: + stage1_dict = self.process_stage(pipe, height // 2, width // 2, num_frames, seed, rand_device, frame_rate) + stage2_dict = self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate) + initial_dict = stage1_dict + initial_dict.update({"stage2_" + k: v for k, v in stage2_dict.items()}) + return initial_dict + else: + return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate) + +class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "video_noise", "audio_noise", "tiled", "tile_size", "tile_stride"), + output_params=("video_latents", "audio_latents"), + onload_model_names=("video_vae_encoder") + ) + + def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, audio_noise, tiled, tile_size, tile_stride): + if input_video is None: + return {"video_latents": video_noise, "audio_latents": audio_noise} + else: + # TODO: implement video-to-video + raise NotImplementedError("Video-to-video not implemented yet.") + +class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "num_frames", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"), + output_params=("video_latents"), + onload_model_names=("video_vae_encoder") + ) + + def get_image_latent(self, pipe, input_image, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels): + image = ltx2_preprocess(np.array(input_image.resize((width, height)))) + image = torch.Tensor(np.array(image, dtype=np.float32)).to(dtype=pipe.torch_dtype, device=pipe.device) + image = image / 127.5 - 1.0 + image = repeat(image, f"H W C -> B C F H W", B=1, F=1) + latent = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device) + return latent + + def process(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, video_latents, height, width, num_frames, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=False): + if input_images is None or len(input_images) == 0: + return {"video_latents": video_latents} + else: + pipe.load_models_to_device(self.onload_model_names) + output_dicts = {} + stage1_height = height // 2 if use_two_stage_pipeline else height + stage1_width = width // 2 if use_two_stage_pipeline else width + stage1_latents = [ + self.get_image_latent(pipe, img, stage1_height, stage1_width, tiled, tile_size_in_pixels, + tile_overlap_in_pixels) for img in input_images + ] + video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength, num_frames=num_frames) + output_dicts.update({"video_latents": video_latents, "denoise_mask_video": denoise_mask_video, "input_latents_video": initial_latents}) + if use_two_stage_pipeline: + stage2_latents = [ + self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, + tile_overlap_in_pixels) for img in input_images + ] + output_dicts.update({"stage2_input_latents": stage2_latents}) + return output_dicts + + +def model_fn_ltx2( + dit: LTXModel, + video_latents=None, + video_context=None, + video_positions=None, + video_patchifier=None, + audio_latents=None, + audio_context=None, + audio_positions=None, + audio_patchifier=None, + timestep=None, + denoise_mask_video=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + timestep = timestep.float() / 1000. + + # patchify + b, c_v, f, h, w = video_latents.shape + video_latents = video_patchifier.patchify(video_latents) + video_timesteps = timestep.repeat(1, video_latents.shape[1], 1) + if denoise_mask_video is not None: + video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps + _, c_a, _, mel_bins = audio_latents.shape + audio_latents = audio_patchifier.patchify(audio_latents) + audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1) + #TODO: support gradient checkpointing in training + vx, ax = dit( + video_latents=video_latents, + video_positions=video_positions, + video_context=video_context, + video_timesteps=video_timesteps, + audio_latents=audio_latents, + audio_positions=audio_positions, + audio_context=audio_context, + audio_timesteps=audio_timesteps, + ) + # unpatchify + vx = video_patchifier.unpatchify_video(vx, f, h, w) + ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins) + return vx, ax diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index edd6dff..bbc479e 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -1321,11 +1321,6 @@ def model_fn_wan_video( if tea_cache_update: x = tea_cache.update(x) else: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - return custom_forward - def create_custom_forward_vap(block, vap): def custom_forward(*inputs): return vap(block, *inputs) @@ -1339,32 +1334,24 @@ def model_fn_wan_video( x, x_vap = torch.utils.checkpoint.checkpoint( create_custom_forward_vap(block, vap), x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, - use_reentrant=False, + use_reentrant=False ) elif use_gradient_checkpointing: x, x_vap = torch.utils.checkpoint.checkpoint( create_custom_forward_vap(block, vap), x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, - use_reentrant=False, + use_reentrant=False ) else: x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id) else: - if use_gradient_checkpointing_offload: - with torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, freqs, - use_reentrant=False, - ) - elif use_gradient_checkpointing: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, freqs, - use_reentrant=False, - ) - else: - x = block(x, context, t_mod, freqs) + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, context, t_mod, freqs + ) + # VACE if vace_context is not None and block_id in vace.vace_layers_mapping: @@ -1487,32 +1474,18 @@ def model_fn_wans2v( return custom_forward for block_id, block in enumerate(dit.blocks): - if use_gradient_checkpointing_offload: - with torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, seq_len_x, pre_compute_freqs[0], - use_reentrant=False, - ) - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), - x, - use_reentrant=False, - ) - elif use_gradient_checkpointing: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, seq_len_x, pre_compute_freqs[0], - use_reentrant=False, + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, context, t_mod, seq_len_x, pre_compute_freqs[0] ) - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), - x, - use_reentrant=False, - ) - else: - x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) - x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel) + x = gradient_checkpoint_forward( + lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x), + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x + ) if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: x = get_sp_group().all_gather(x, dim=1) diff --git a/diffsynth/utils/data/media_io_ltx2.py b/diffsynth/utils/data/media_io_ltx2.py new file mode 100644 index 0000000..5526ca9 --- /dev/null +++ b/diffsynth/utils/data/media_io_ltx2.py @@ -0,0 +1,149 @@ + +from fractions import Fraction +import torch +import av +from tqdm import tqdm +from PIL import Image +import numpy as np +from io import BytesIO +from collections.abc import Generator, Iterator + + +def _resample_audio( + container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame +) -> None: + cc = audio_stream.codec_context + + # Use the encoder's format/layout/rate as the *target* + target_format = cc.format or "fltp" # AAC → usually fltp + target_layout = cc.layout or "stereo" + target_rate = cc.sample_rate or frame_in.sample_rate + + audio_resampler = av.audio.resampler.AudioResampler( + format=target_format, + layout=target_layout, + rate=target_rate, + ) + + audio_next_pts = 0 + for rframe in audio_resampler.resample(frame_in): + if rframe.pts is None: + rframe.pts = audio_next_pts + audio_next_pts += rframe.samples + rframe.sample_rate = frame_in.sample_rate + container.mux(audio_stream.encode(rframe)) + + # flush audio encoder + for packet in audio_stream.encode(): + container.mux(packet) + + +def _write_audio( + container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int +) -> None: + if samples.ndim == 1: + samples = samples[:, None] + + if samples.shape[1] != 2 and samples.shape[0] == 2: + samples = samples.T + + if samples.shape[1] != 2: + raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.") + + # Convert to int16 packed for ingestion; resampler converts to encoder fmt. + if samples.dtype != torch.int16: + samples = torch.clip(samples, -1.0, 1.0) + samples = (samples * 32767.0).to(torch.int16) + + frame_in = av.AudioFrame.from_ndarray( + samples.contiguous().reshape(1, -1).cpu().numpy(), + format="s16", + layout="stereo", + ) + frame_in.sample_rate = audio_sample_rate + + _resample_audio(container, audio_stream, frame_in) + + +def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream: + """ + Prepare the audio stream for writing. + """ + audio_stream = container.add_stream("aac", rate=audio_sample_rate) + audio_stream.codec_context.sample_rate = audio_sample_rate + audio_stream.codec_context.layout = "stereo" + audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate) + return audio_stream + +def write_video_audio_ltx2( + video: list[Image.Image], + audio: torch.Tensor | None, + output_path: str, + fps: int = 24, + audio_sample_rate: int | None = 24000, +) -> None: + + width, height = video[0].size + container = av.open(output_path, mode="w") + stream = container.add_stream("libx264", rate=int(fps)) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + + if audio is not None: + if audio_sample_rate is None: + raise ValueError("audio_sample_rate is required when audio is provided") + audio_stream = _prepare_audio_stream(container, audio_sample_rate) + + for frame in tqdm(video, total=len(video)): + frame = av.VideoFrame.from_image(frame) + for packet in stream.encode(frame): + container.mux(packet) + + # Flush encoder + for packet in stream.encode(): + container.mux(packet) + + if audio is not None: + _write_audio(container, audio_stream, audio, audio_sample_rate) + + container.close() + + +def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None: + container = av.open(output_file, "w", format="mp4") + try: + stream = container.add_stream("libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}) + # Round to nearest multiple of 2 for compatibility with video codecs + height = image_array.shape[0] // 2 * 2 + width = image_array.shape[1] // 2 * 2 + image_array = image_array[:height, :width] + stream.height = height + stream.width = width + av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(format="yuv420p") + container.mux(stream.encode(av_frame)) + container.mux(stream.encode()) + finally: + container.close() + + +def decode_single_frame(video_file: str) -> np.array: + container = av.open(video_file) + try: + stream = next(s for s in container.streams if s.type == "video") + frame = next(container.decode(stream)) + finally: + container.close() + return frame.to_ndarray(format="rgb24") + + +def ltx2_preprocess(image: np.array, crf: float = 33) -> np.array: + if crf == 0: + return image + + with BytesIO() as output_file: + encode_single_frame(output_file, image, crf) + video_bytes = output_file.getvalue() + with BytesIO(video_bytes) as video_file: + image_array = decode_single_frame(video_file) + return image_array diff --git a/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py b/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py new file mode 100644 index 0000000..c9bb66d --- /dev/null +++ b/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py @@ -0,0 +1,32 @@ +def LTX2AudioEncoderStateDictConverter(state_dict): + # Not used + state_dict_ = {} + for name in state_dict: + if name.startswith("audio_vae.encoder."): + new_name = name.replace("audio_vae.encoder.", "") + state_dict_[new_name] = state_dict[name] + elif name.startswith("audio_vae.per_channel_statistics."): + new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.") + state_dict_[new_name] = state_dict[name] + return state_dict_ + + +def LTX2AudioDecoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("audio_vae.decoder."): + new_name = name.replace("audio_vae.decoder.", "") + state_dict_[new_name] = state_dict[name] + elif name.startswith("audio_vae.per_channel_statistics."): + new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.") + state_dict_[new_name] = state_dict[name] + return state_dict_ + + +def LTX2VocoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("vocoder."): + new_name = name.replace("vocoder.", "") + state_dict_[new_name] = state_dict[name] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/ltx2_dit.py b/diffsynth/utils/state_dict_converters/ltx2_dit.py new file mode 100644 index 0000000..baffb9a --- /dev/null +++ b/diffsynth/utils/state_dict_converters/ltx2_dit.py @@ -0,0 +1,9 @@ +def LTXModelStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("model.diffusion_model."): + new_name = name.replace("model.diffusion_model.", "") + if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."): + continue + state_dict_[new_name] = state_dict[name] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/ltx2_text_encoder.py b/diffsynth/utils/state_dict_converters/ltx2_text_encoder.py new file mode 100644 index 0000000..b7e528f --- /dev/null +++ b/diffsynth/utils/state_dict_converters/ltx2_text_encoder.py @@ -0,0 +1,31 @@ +def LTX2TextEncoderStateDictConverter(state_dict): + state_dict_ = {} + for key in state_dict: + if key.startswith("language_model.model."): + new_key = key.replace("language_model.model.", "model.language_model.") + elif key.startswith("vision_tower."): + new_key = key.replace("vision_tower.", "model.vision_tower.") + elif key.startswith("multi_modal_projector."): + new_key = key.replace("multi_modal_projector.", "model.multi_modal_projector.") + elif key.startswith("language_model.lm_head."): + new_key = key.replace("language_model.lm_head.", "lm_head.") + else: + continue + state_dict_[new_key] = state_dict[key] + state_dict_["lm_head.weight"] = state_dict_.get("model.language_model.embed_tokens.weight") + return state_dict_ + + +def LTX2TextEncoderPostModulesStateDictConverter(state_dict): + state_dict_ = {} + for key in state_dict: + if key.startswith("text_embedding_projection."): + new_key = key.replace("text_embedding_projection.", "feature_extractor_linear.") + elif key.startswith("model.diffusion_model.video_embeddings_connector."): + new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "embeddings_connector.") + elif key.startswith("model.diffusion_model.audio_embeddings_connector."): + new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "audio_embeddings_connector.") + else: + continue + state_dict_[new_key] = state_dict[key] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/ltx2_video_vae.py b/diffsynth/utils/state_dict_converters/ltx2_video_vae.py new file mode 100644 index 0000000..132897d --- /dev/null +++ b/diffsynth/utils/state_dict_converters/ltx2_video_vae.py @@ -0,0 +1,22 @@ +def LTX2VideoEncoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("vae.encoder."): + new_name = name.replace("vae.encoder.", "") + state_dict_[new_name] = state_dict[name] + elif name.startswith("vae.per_channel_statistics."): + new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.") + state_dict_[new_name] = state_dict[name] + return state_dict_ + + +def LTX2VideoDecoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("vae.decoder."): + new_name = name.replace("vae.decoder.", "") + state_dict_[new_name] = state_dict[name] + elif name.startswith("vae.per_channel_statistics."): + new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.") + state_dict_[new_name] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index 21dc3b3..228e7b8 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -1,11 +1,15 @@ import torch from typing import Optional from einops import rearrange +from yunchang.kernels import AttnType from xfuser.core.distributed import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group) from xfuser.core.long_ctx_attention import xFuserLongContextAttention + +from ... import IS_NPU_AVAILABLE from ...core.device import parse_nccl_backend, parse_device_type +from ...core.gradient import gradient_checkpoint_forward def initialize_usp(device_type): @@ -30,13 +34,16 @@ def sinusoidal_embedding_1d(dim, position): def pad_freqs(original_tensor, target_len): seq_len, s1, s2 = original_tensor.shape pad_size = target_len - seq_len + original_tensor_device = original_tensor.device + if original_tensor.device == "npu": + original_tensor = original_tensor.cpu() padding_tensor = torch.ones( pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device) - padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) + padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0).to(device=original_tensor_device) return padded_tensor def rope_apply(x, freqs, num_heads): @@ -50,7 +57,7 @@ def rope_apply(x, freqs, num_heads): sp_rank = get_sequence_parallel_rank() freqs = pad_freqs(freqs, s_per_rank * sp_size) freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] - freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device == "npu" else freqs_rank + freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device.type == "npu" else freqs_rank x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) return x_out.to(x.dtype) @@ -81,11 +88,6 @@ def usp_dit_forward(self, self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - return custom_forward # Context Parallel chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) @@ -94,20 +96,13 @@ def usp_dit_forward(self, x = chunks[get_sequence_parallel_rank()] for block in self.blocks: - if self.training and use_gradient_checkpointing: - if use_gradient_checkpointing_offload: - with torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, freqs, - use_reentrant=False, - ) - else: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, freqs, - use_reentrant=False, - ) + if self.training: + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, context, t_mod, freqs + ) else: x = block(x, context, t_mod, freqs) @@ -133,7 +128,12 @@ def usp_attn_forward(self, x, freqs): k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads) v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads) - x = xFuserLongContextAttention()( + attn_type = AttnType.FA + ring_impl_type = "basic" + if IS_NPU_AVAILABLE: + attn_type = AttnType.NPU + ring_impl_type = "basic_npu" + x = xFuserLongContextAttention(attn_type=attn_type, ring_impl_type=ring_impl_type)( None, query=q, key=k, diff --git a/diffsynth/version.py b/diffsynth/version.py new file mode 100644 index 0000000..6fcae7a --- /dev/null +++ b/diffsynth/version.py @@ -0,0 +1,5 @@ +# Make sure to modify __release_datetime__ to release time when making official release. +__version__ = '2.0.0' +# default release datetime for branches under active development is set +# to be a time far-far-away-into-the-future +__release_datetime__ = '2099-10-13 08:56:12' \ No newline at end of file diff --git a/docs/en/.readthedocs.yaml b/docs/en/.readthedocs.yaml new file mode 100644 index 0000000..6197276 --- /dev/null +++ b/docs/en/.readthedocs.yaml @@ -0,0 +1,28 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the OS, Python version and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.10" + +# Build documentation in the "docs/" directory with Sphinx +sphinx: + configuration: docs/en/conf.py + +# Optionally build your docs in additional formats such as PDF and ePub +# formats: +# - pdf +# - epub + +# Optional but recommended, declare the Python requirements required +# to build your documentation +# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +python: + install: + - requirements: docs/requirements.txt diff --git a/docs/en/API_Reference/core/attention.md b/docs/en/API_Reference/core/attention.md index 9ec3123..a9c9a83 100644 --- a/docs/en/API_Reference/core/attention.md +++ b/docs/en/API_Reference/core/attention.md @@ -1,6 +1,6 @@ # `diffsynth.core.attention`: Attention Mechanism Implementation -`diffsynth.core.attention` provides routing mechanisms for attention mechanism implementations, automatically selecting efficient attention implementations based on available packages in the `Python` environment and [environment variables](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation). +`diffsynth.core.attention` provides routing mechanisms for attention mechanism implementations, automatically selecting efficient attention implementations based on available packages in the `Python` environment and [environment variables](../../Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation). ## Attention Mechanism @@ -46,7 +46,7 @@ Note that the dimension of the Attention Score in the attention mechanism ( $\te * xFormers: [GitHub](https://github.com/facebookresearch/xformers), [Documentation](https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops) * PyTorch: [GitHub](https://github.com/pytorch/pytorch), [Documentation](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) -To call attention implementations other than `PyTorch`, please follow the instructions on their GitHub pages to install the corresponding packages. `DiffSynth-Studio` will automatically route to the corresponding implementation based on available packages in the Python environment, or can be controlled through [environment variables](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation). +To call attention implementations other than `PyTorch`, please follow the instructions on their GitHub pages to install the corresponding packages. `DiffSynth-Studio` will automatically route to the corresponding implementation based on available packages in the Python environment, or can be controlled through [environment variables](../../Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation). ```python from diffsynth.core.attention import attention_forward diff --git a/docs/en/API_Reference/core/loader.md b/docs/en/API_Reference/core/loader.md index 1dccf5f..7f1018a 100644 --- a/docs/en/API_Reference/core/loader.md +++ b/docs/en/API_Reference/core/loader.md @@ -8,9 +8,9 @@ This document introduces the model download and loading functionalities in `diff ### Downloading and Loading Models from Remote Sources -Taking the model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) as an example, after filling in `model_id` and `origin_file_pattern` in `ModelConfig`, the model can be automatically downloaded. By default, it downloads to the `./models` path, which can be modified through the [environment variable DIFFSYNTH_MODEL_BASE_PATH](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path). +Taking the model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) as an example, after filling in `model_id` and `origin_file_pattern` in `ModelConfig`, the model can be automatically downloaded. By default, it downloads to the `./models` path, which can be modified through the [environment variable DIFFSYNTH_MODEL_BASE_PATH](../../Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path). -By default, even if the model has already been downloaded, the program will still query the remote for any missing files. To completely disable remote requests, set the [environment variable DIFFSYNTH_SKIP_DOWNLOAD](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`. +By default, even if the model has already been downloaded, the program will still query the remote for any missing files. To completely disable remote requests, set the [environment variable DIFFSYNTH_SKIP_DOWNLOAD](../../Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`. ```python from diffsynth.core import ModelConfig @@ -51,7 +51,7 @@ config = ModelConfig(path=[ ### VRAM Management Configuration -`ModelConfig` also contains VRAM management configuration information. See [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md#more-usage-methods) for details. +`ModelConfig` also contains VRAM management configuration information. See [VRAM Management](../../Pipeline_Usage/VRAM_management.md#more-usage-methods) for details. ## Model File Loading @@ -103,11 +103,11 @@ print(hash_model_file([ The model hash value is only related to the keys and tensor shapes in the state dict of the model file, and is unrelated to the numerical values of the model parameters, file saving time, and other information. When calculating the model hash value of `.safetensors` format files, `hash_model_file` is almost instantly completed without reading the model parameters. However, when calculating the model hash value of `.bin`, `.pth`, `.ckpt`, and other binary files, all model parameters need to be read, so **we do not recommend developers to continue using these formats of files.** -By [writing model Config](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config) and filling in model hash value and other information into `diffsynth/configs/model_configs.py`, developers can let `DiffSynth-Studio` automatically identify the model type and load it. +By [writing model Config](../../Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config) and filling in model hash value and other information into `diffsynth/configs/model_configs.py`, developers can let `DiffSynth-Studio` automatically identify the model type and load it. ## Model Loading -`load_model` is the external entry for loading models in `diffsynth.core.loader`. It will call [skip_model_initialization](/docs/en/API_Reference/core/vram.md#skipping-model-parameter-initialization) to skip model parameter initialization. If [Disk Offload](/docs/en/Pipeline_Usage/VRAM_management.md#disk-offload) is enabled, it calls [DiskMap](/docs/en/API_Reference/core/vram.md#state-dict-disk-mapping) for lazy loading. If Disk Offload is not enabled, it calls [load_state_dict](#model-file-loading) to load model parameters. If necessary, it will also call [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) for model format conversion. Finally, it calls `model.eval()` to switch to inference mode. +`load_model` is the external entry for loading models in `diffsynth.core.loader`. It will call [skip_model_initialization](../../API_Reference/core/vram.md#skipping-model-parameter-initialization) to skip model parameter initialization. If [Disk Offload](../../Pipeline_Usage/VRAM_management.md#disk-offload) is enabled, it calls [DiskMap](../../API_Reference/core/vram.md#state-dict-disk-mapping) for lazy loading. If Disk Offload is not enabled, it calls [load_state_dict](#model-file-loading) to load model parameters. If necessary, it will also call [state dict converter](../../Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) for model format conversion. Finally, it calls `model.eval()` to switch to inference mode. Here is a usage example with Disk Offload enabled: diff --git a/docs/en/API_Reference/core/vram.md b/docs/en/API_Reference/core/vram.md index 79e51fc..6b4878f 100644 --- a/docs/en/API_Reference/core/vram.md +++ b/docs/en/API_Reference/core/vram.md @@ -31,7 +31,7 @@ state_dict = load_state_dict(path, device="cpu") model.load_state_dict(state_dict, assign=True) ``` -In `DiffSynth-Studio`, all pretrained models follow this loading logic. After developers [integrate models](/docs/en/Developer_Guide/Integrating_Your_Model.md), they can directly load models quickly using this approach. +In `DiffSynth-Studio`, all pretrained models follow this loading logic. After developers [integrate models](../../Developer_Guide/Integrating_Your_Model.md), they can directly load models quickly using this approach. ## State Dict Disk Mapping @@ -57,10 +57,10 @@ state_dict = DiskMap(path, device="cpu") # Fast print(state_dict["img_in.weight"]) ``` -`DiskMap` is the basic component of Disk Offload in `DiffSynth-Studio`. After developers [configure fine-grained VRAM management schemes](/docs/en/Developer_Guide/Enabling_VRAM_management.md), they can directly enable Disk Offload. +`DiskMap` is the basic component of Disk Offload in `DiffSynth-Studio`. After developers [configure fine-grained VRAM management schemes](../../Developer_Guide/Enabling_VRAM_management.md), they can directly enable Disk Offload. `DiskMap` is a functionality implemented using the characteristics of `.safetensors` files. Therefore, when using `.bin`, `.pth`, `.ckpt`, and other binary files, model parameters are fully loaded, which causes Disk Offload to not support these formats of files. **We do not recommend developers to continue using these formats of files.** ## Replacable Modules for VRAM Management -When `DiffSynth-Studio`'s VRAM management is enabled, the modules inside the model will be replaced with replacable modules in `diffsynth.core.vram.layers`. For usage, see [Fine-grained VRAM Management Scheme](/docs/en/Developer_Guide/Enabling_VRAM_management.md#writing-fine-grained-vram-management-schemes). \ No newline at end of file +When `DiffSynth-Studio`'s VRAM management is enabled, the modules inside the model will be replaced with replacable modules in `diffsynth.core.vram.layers`. For usage, see [Fine-grained VRAM Management Scheme](../../Developer_Guide/Enabling_VRAM_management.md#writing-fine-grained-vram-management-schemes). \ No newline at end of file diff --git a/docs/en/Developer_Guide/Building_a_Pipeline.md b/docs/en/Developer_Guide/Building_a_Pipeline.md index 7d5e785..7827c80 100644 --- a/docs/en/Developer_Guide/Building_a_Pipeline.md +++ b/docs/en/Developer_Guide/Building_a_Pipeline.md @@ -1,6 +1,6 @@ # Building a Pipeline -After [integrating the required models for the Pipeline](/docs/en/Developer_Guide/Integrating_Your_Model.md), you also need to build a `Pipeline` for model inference. This document provides a standardized process for building a `Pipeline`. Developers can also refer to existing `Pipeline` implementations for construction. +After [integrating the required models for the Pipeline](../Developer_Guide/Integrating_Your_Model.md), you also need to build a `Pipeline` for model inference. This document provides a standardized process for building a `Pipeline`. Developers can also refer to existing `Pipeline` implementations for construction. The `Pipeline` implementation is located in `diffsynth/pipelines`. Each `Pipeline` contains the following essential key components: @@ -79,7 +79,7 @@ This includes the following parts: return pipe ``` -Developers need to implement the logic for fetching models. The corresponding model names are the `"model_name"` in the [model Config filled in during model integration](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config). +Developers need to implement the logic for fetching models. The corresponding model names are the `"model_name"` in the [model Config filled in during model integration](../Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config). Some models also need to load `tokenizer`. Extra `tokenizer_config` parameters can be added to `from_pretrained` as needed, and this part can be implemented after fetching the models. diff --git a/docs/en/Developer_Guide/Enabling_VRAM_management.md b/docs/en/Developer_Guide/Enabling_VRAM_management.md index 9bdd49f..ef4ee58 100644 --- a/docs/en/Developer_Guide/Enabling_VRAM_management.md +++ b/docs/en/Developer_Guide/Enabling_VRAM_management.md @@ -1,6 +1,6 @@ # Fine-Grained VRAM Management Scheme -This document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). +This document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](../Pipeline_Usage/VRAM_management.md). ## How Much VRAM Does a 20B Model Need? @@ -124,7 +124,7 @@ module_map={ } ``` -In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md#more-usage-methods). +In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](../Pipeline_Usage/VRAM_management.md#more-usage-methods). Call `enable_vram_management` to enable VRAM management. Note that the `device` when loading the model is `cpu`, consistent with `offload_device`: @@ -171,7 +171,7 @@ The above code only requires 2G VRAM to run the `forward` of a 20B model. ## Disk Offload -[Disk Offload](/docs/en/Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. Usually, when the above code can run smoothly, Disk Offload can be directly enabled: +[Disk Offload](../Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. Usually, when the above code can run smoothly, Disk Offload can be directly enabled: ```python from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule @@ -212,7 +212,7 @@ with torch.no_grad(): output = model(**inputs) ``` -Disk Offload is an extremely special VRAM management scheme. It only supports `.safetensors` format files, not binary files such as `.bin`, `.pth`, `.ckpt`, and does not support [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape. +Disk Offload is an extremely special VRAM management scheme. It only supports `.safetensors` format files, not binary files such as `.bin`, `.pth`, `.ckpt`, and does not support [state dict converter](../Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape. If there are situations where Disk Offload cannot run normally but non-Disk Offload can run normally, please submit an issue to us on GitHub. @@ -227,7 +227,7 @@ To make it easier for users to use the VRAM management function, we write the fi } ```# Fine-Grained VRAM Management Scheme -This document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). +This document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](../Pipeline_Usage/VRAM_management.md). ## How Much VRAM Does a 20B Model Need? @@ -351,7 +351,7 @@ module_map={ } ``` -In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md#more-usage-methods). +In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](../Pipeline_Usage/VRAM_management.md#more-usage-methods). Call `enable_vram_management` to enable VRAM management. Note that the `device` when loading the model is `cpu`, consistent with `offload_device`: @@ -398,7 +398,7 @@ The above code only requires 2G VRAM to run the `forward` of a 20B model. ## Disk Offload -[Disk Offload](/docs/en/Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. Usually, when the above code can run smoothly, Disk Offload can be directly enabled: +[Disk Offload](../Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. Usually, when the above code can run smoothly, Disk Offload can be directly enabled: ```python from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule @@ -439,7 +439,7 @@ with torch.no_grad(): output = model(**inputs) ``` -Disk Offload is an extremely special VRAM management scheme. It only supports `.safetensors` format files, not binary files such as `.bin`, `.pth`, `.ckpt`, and does not support [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape. +Disk Offload is an extremely special VRAM management scheme. It only supports `.safetensors` format files, not binary files such as `.bin`, `.pth`, `.ckpt`, and does not support [state dict converter](../Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape. If there are situations where Disk Offload cannot run normally but non-Disk Offload can run normally, please submit an issue to us on GitHub. diff --git a/docs/en/Developer_Guide/Integrating_Your_Model.md b/docs/en/Developer_Guide/Integrating_Your_Model.md index ae5e6f2..817c875 100644 --- a/docs/en/Developer_Guide/Integrating_Your_Model.md +++ b/docs/en/Developer_Guide/Integrating_Your_Model.md @@ -183,4 +183,4 @@ Loaded model: { ## Step 5: Writing Model VRAM Management Scheme -`DiffSynth-Studio` supports complex VRAM management. See [Enabling VRAM Management](/docs/en/Developer_Guide/Enabling_VRAM_management.md) for details. \ No newline at end of file +`DiffSynth-Studio` supports complex VRAM management. See [Enabling VRAM Management](../Developer_Guide/Enabling_VRAM_management.md) for details. \ No newline at end of file diff --git a/docs/en/Developer_Guide/Training_Diffusion_Models.md b/docs/en/Developer_Guide/Training_Diffusion_Models.md index 3fc92fc..6f2aefe 100644 --- a/docs/en/Developer_Guide/Training_Diffusion_Models.md +++ b/docs/en/Developer_Guide/Training_Diffusion_Models.md @@ -1,6 +1,6 @@ # Integrating Model Training -After [integrating models](/docs/en/Developer_Guide/Integrating_Your_Model.md) and [implementing Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md), the next step is to integrate model training functionality. +After [integrating models](../Developer_Guide/Integrating_Your_Model.md) and [implementing Pipeline](../Developer_Guide/Building_a_Pipeline.md), the next step is to integrate model training functionality. ## Training-Inference Consistent Pipeline Modification diff --git a/docs/en/Makefile b/docs/en/Makefile new file mode 100644 index 0000000..41c270b --- /dev/null +++ b/docs/en/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file diff --git a/docs/en/Model_Details/FLUX.md b/docs/en/Model_Details/FLUX.md index 1120a34..283f895 100644 --- a/docs/en/Model_Details/FLUX.md +++ b/docs/en/Model_Details/FLUX.md @@ -14,7 +14,7 @@ cd DiffSynth-Studio pip install -e . ``` -For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md). +For more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md). ## Quick Start @@ -81,31 +81,31 @@ graph LR; | Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | | - | - | - | - | - | - | - | - | -| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py) | -| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) | -| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) | -| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | -| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) | -| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) | -| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) | -| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) | -| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) | -| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - | -| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - | -| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](/examples/flux/model_inference/Step1X-Edit.py) | [code](/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](/examples/flux/model_training/full/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_lora/Step1X-Edit.py) | -| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](/examples/flux/model_inference/FLEX.2-preview.py) | [code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py) | -| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](/examples/flux/model_training/full/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_lora/Nexus-Gen.py) | +| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev.py) | +| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) | +| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) | +| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | +| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) | +| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) | +| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) | +| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) | +| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) | +| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - | +| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - | +| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Step1X-Edit.py) | +| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLEX.2-preview.py) | +| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Nexus-Gen.py) | Special Training Scripts: -* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/flux/model_training/special/differential_training/) -* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/flux/model_training/special/fp8_training/) -* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/flux/model_training/special/split_training/) -* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh) +* Differential LoRA Training: [doc](../Training/Differential_LoRA.md) +* FP8 Precision Training: [doc](../Training/FP8_Precision.md) +* Two-stage Split Training: [doc](../Training/Split_Training.md) +* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md) ## Model Inference -Models are loaded via `FluxImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models). +Models are loaded via `FluxImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models). Input parameters for `FluxImagePipeline` inference include: @@ -143,11 +143,11 @@ Input parameters for `FluxImagePipeline` inference include: * `flex_control_stop`: Flex model control stop timestep. * `nexus_gen_reference_image`: Nexus-Gen model reference image. -If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. +If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. ## Model Training -FLUX series models are uniformly trained through [`examples/flux/model_training/train.py`](/examples/flux/model_training/train.py), and the script parameters include: +FLUX series models are uniformly trained through [`examples/flux/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/train.py), and the script parameters include: * General Training Parameters * Dataset Basic Configuration @@ -198,4 +198,4 @@ We have built a sample image dataset for your testing. You can download this dat modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset ``` -We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/). \ No newline at end of file +We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/). diff --git a/docs/en/Model_Details/FLUX2.md b/docs/en/Model_Details/FLUX2.md index 89e3c92..f3bb020 100644 --- a/docs/en/Model_Details/FLUX2.md +++ b/docs/en/Model_Details/FLUX2.md @@ -21,7 +21,7 @@ cd DiffSynth-Studio pip install -e . ``` -For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md). +For more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md). ## Quick Start @@ -61,22 +61,22 @@ image.save("image.jpg") | Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | | - | - | - | - | - | - | - | -|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| -|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)| -|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)| -|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)| -|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)| +|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| +|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)| +|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)| +|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)| +|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)| Special Training Scripts: -* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md) -* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md) -* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md) -* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md) +* Differential LoRA Training: [doc](../Training/Differential_LoRA.md) +* FP8 Precision Training: [doc](../Training/FP8_Precision.md) +* Two-stage Split Training: [doc](../Training/Split_Training.md) +* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md) ## Model Inference -Models are loaded via `Flux2ImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models). +Models are loaded via `Flux2ImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models). Input parameters for `Flux2ImagePipeline` inference include: @@ -95,11 +95,11 @@ Input parameters for `Flux2ImagePipeline` inference include: * `tile_stride`: Tile stride during VAE encoding/decoding stages, default is 64, only effective when `tiled=True`, must be less than or equal to `tile_size`. * `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`. -If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. +If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. ## Model Training -FLUX.2 series models are uniformly trained through [`examples/flux2/model_training/train.py`](/examples/flux2/model_training/train.py), and the script parameters include: +FLUX.2 series models are uniformly trained through [`examples/flux2/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/train.py), and the script parameters include: * General Training Parameters * Dataset Basic Configuration @@ -148,4 +148,4 @@ We have built a sample image dataset for your testing. You can download this dat modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset ``` -We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/). +We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/). diff --git a/docs/en/Model_Details/LTX-2.md b/docs/en/Model_Details/LTX-2.md new file mode 100644 index 0000000..c285a7f --- /dev/null +++ b/docs/en/Model_Details/LTX-2.md @@ -0,0 +1,116 @@ +# LTX-2 + +LTX-2 is a series of audio-video generation models developed by Lightricks. + +## Installation + +Before using this project for model inference and training, please install DiffSynth-Studio first. + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +For more information about installation, please refer to [Installation Dependencies](../Pipeline_Usage/Setup.md). + +## Quick Start + +Run the following code to quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model and perform inference. VRAM management has been enabled, and the framework will automatically control model parameter loading based on remaining VRAM. It can run with a minimum of 8GB VRAM. + +```python +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\"" +negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +height, width, num_frames = 512, 768, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_onestage.mp4', + fps=24, + audio_sample_rate=24000, +) +``` + +## Model Overview +|Model ID|Additional Parameters|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training| +|-|-|-|-|-|-|-|-| +|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-| +|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-| +|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-| +|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-| + +## Model Inference + +Models are loaded through `LTX2AudioVideoPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details. + +Input parameters for `LTX2AudioVideoPipeline` inference include: + +* `prompt`: Prompt describing the content appearing in the video. +* `negative_prompt`: Negative prompt describing content that should not appear in the video, default value is `""`. +* `cfg_scale`: Classifier-free guidance parameter, default value is 3.0. +* `input_images`: List of input images for image-to-video generation. +* `input_images_indexes`: Frame index list of input images in the video. +* `input_images_strength`: Strength of input images, default value is 1.0. +* `denoising_strength`: Denoising strength, range is 0~1, default value is 1.0. +* `seed`: Random seed. Default is `None`, which means completely random. +* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different results will be generated on different GPUs. +* `height`: Video height, must be a multiple of 32 (single-stage) or 64 (two-stage). +* `width`: Video width, must be a multiple of 32 (single-stage) or 64 (two-stage). +* `num_frames`: Number of video frames, default value is 121, must be a multiple of 8 + 1. +* `num_inference_steps`: Number of inference steps, default value is 40. +* `tiled`: Whether to enable VAE tiling inference, default is `True`. When set to `True`, it can significantly reduce VRAM usage during VAE encoding/decoding stages, with slight errors and minor inference time extension. +* `tile_size_in_pixels`: Pixel tiling size during VAE encoding/decoding stages, default is 512. +* `tile_overlap_in_pixels`: Pixel tiling overlap size during VAE encoding/decoding stages, default is 128. +* `tile_size_in_frames`: Frame tiling size during VAE encoding/decoding stages, default is 128. +* `tile_overlap_in_frames`: Frame tiling overlap size during VAE encoding/decoding stages, default is 24. +* `use_two_stage_pipeline`: Whether to use two-stage pipeline, default is `False`. +* `use_distilled_pipeline`: Whether to use distilled pipeline, default is `False`. +* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be set to `lambda x:x` to hide the progress bar. + +If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the previous "Supported Inference Scripts" section. + +## Model Training + +The LTX-2 series models currently do not support training functionality. We will add related support as soon as possible. diff --git a/docs/en/Model_Details/Overview.md b/docs/en/Model_Details/Overview.md index 5df8593..286141e 100644 --- a/docs/en/Model_Details/Overview.md +++ b/docs/en/Model_Details/Overview.md @@ -2,7 +2,7 @@ ## Qwen-Image -Documentation: [./Qwen-Image.md](/docs/en/Model_Details/Qwen-Image.md) +Documentation: [./Qwen-Image.md](../Model_Details/Qwen-Image.md)
@@ -69,23 +69,23 @@ graph LR; | Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | | - | - | - | - | - | - | - | -| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) | -| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) | -| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) | -| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | -| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | -| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) | -| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) | -| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) | -| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) | -| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) | -| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | -| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) | -| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - | +| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) | +| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) | +| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) | +| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | +| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | +| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) | +| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) | +| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) | +| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) | +| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) | +| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | +| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) | +| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - | ## FLUX Series -Documentation: [./FLUX.md](/docs/en/Model_Details/FLUX.md) +Documentation: [./FLUX.md](../Model_Details/FLUX.md)
@@ -149,24 +149,24 @@ graph LR; | Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | | - | - | - | - | - | - | - | - | -| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py) | -| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) | -| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) | -| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | -| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) | -| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) | -| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) | -| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) | -| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) | -| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - | -| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - | -| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](/examples/flux/model_inference/Step1X-Edit.py) | [code](/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](/examples/flux/model_training/full/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_lora/Step1X-Edit.py) | -| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](/examples/flux/model_inference/FLEX.2-preview.py) | [code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py) | -| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](/examples/flux/model_training/full/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_lora/Nexus-Gen.py) | +| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev.py) | +| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) | +| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) | +| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | +| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) | +| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) | +| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) | +| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) | +| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) | +| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - | +| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - | +| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Step1X-Edit.py) | +| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLEX.2-preview.py) | +| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Nexus-Gen.py) | ## Wan Series -Documentation: [./Wan.md](/docs/en/Model_Details/Wan.md) +Documentation: [./Wan.md](../Model_Details/Wan.md)
@@ -254,38 +254,38 @@ graph LR; | Model ID | Extra Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | | - | - | - | - | - | - | - | -| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) | -| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) | -| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) | -| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) | -| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) | -| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) | -| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) | -| [Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) | -| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) | -| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) | -| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) | -| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) | -| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) | -| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) | -| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) | -| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) | -| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | -| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | -| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) | -| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) | -| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](/examples/wanvideo/model_inference/LongCat-Video.py) | [code](/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) | -| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) | -| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) | -| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) | -| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) | -| [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) | -| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) | -| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) | -| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) | -| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) | -| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) | +| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) | +| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) | +| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) | +| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) | +| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) | +| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) | +| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) | +| [Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) | +| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) | +| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) | +| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) | +| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) | +| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) | +| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) | +| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) | +| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) | +| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | +| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | +| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) | +| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) | +| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) | +| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) | +| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) | +| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) | +| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) | +| [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) | +| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) | +| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) | +| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) | +| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) | +| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) | -* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/wanvideo/model_training/special/fp8_training/) -* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/wanvideo/model_training/special/split_training/) -* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/wanvideo/model_training/special/direct_distill/) \ No newline at end of file +* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/) +* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/) +* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/direct_distill/) diff --git a/docs/en/Model_Details/Qwen-Image.md b/docs/en/Model_Details/Qwen-Image.md index 08b8a35..043ca63 100644 --- a/docs/en/Model_Details/Qwen-Image.md +++ b/docs/en/Model_Details/Qwen-Image.md @@ -14,7 +14,7 @@ cd DiffSynth-Studio pip install -e . ``` -For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md). +For more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md). ## Quick Start @@ -80,35 +80,41 @@ graph LR; | Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | | - | - | - | - | - | - | - | -| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) | -|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)| -| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) | -| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) | -|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)| -|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)| -|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)| -| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | -| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | -| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) | -| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) | -| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) | -| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) | -| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) | -| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | -| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) | -| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - | -|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-| +| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) | +|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)| +| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) | +| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) | +|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)| +|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-| +|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)| +|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)| +| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | +| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | +| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) | +| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) | +| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) | +| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) | +| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) | +| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | +| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) | +| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - | +|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-| Special Training Scripts: -* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/qwen_image/model_training/special/differential_training/) -* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/qwen_image/model_training/special/fp8_training/) -* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/qwen_image/model_training/special/split_training/) -* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) +* Differential LoRA Training: [doc](../Training/Differential_LoRA.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/differential_training/) +* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/fp8_training/) +* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/split_training/) +* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) + +DeepSpeed ZeRO Stage 3 Training: The Qwen-Image series models support DeepSpeed ZeRO Stage 3 training, which partitions the model across multiple GPUs. Taking full parameter training of the Qwen-Image model as an example, the following modifications are required: + +* `--config_file examples/qwen_image/model_training/full/accelerate_config_zero3.yaml` +* `--initialize_model_on_cpu` ## Model Inference -Models are loaded via `QwenImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models). +Models are loaded via `QwenImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models). Input parameters for `QwenImagePipeline` inference include: @@ -139,11 +145,11 @@ Input parameters for `QwenImagePipeline` inference include: * `tile_stride`: Tile stride during VAE encoding/decoding stages, default is 64, only effective when `tiled=True`, must be less than or equal to `tile_size`. * `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`. -If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. +If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. ## Model Training -Qwen-Image series models are uniformly trained through [`examples/qwen_image/model_training/train.py`](/examples/qwen_image/model_training/train.py), and the script parameters include: +Qwen-Image series models are uniformly trained through [`examples/qwen_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/train.py), and the script parameters include: * General Training Parameters * Dataset Basic Configuration @@ -193,4 +199,4 @@ We have built a sample image dataset for your testing. You can download this dat modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset ``` -We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/). \ No newline at end of file +We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/). diff --git a/docs/en/Model_Details/Wan.md b/docs/en/Model_Details/Wan.md index 83141bf..20e1282 100644 --- a/docs/en/Model_Details/Wan.md +++ b/docs/en/Model_Details/Wan.md @@ -14,7 +14,7 @@ cd DiffSynth-Studio pip install -e . ``` -For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md). +For more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md). ## Quick Start @@ -106,45 +106,50 @@ graph LR; | Model ID | Extra Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | | - | - | - | - | - | - | - | -| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) | -| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) | -| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) | -| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) | -| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) | -| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) | -| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) | -| [Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) | -| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) | -| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) | -| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) | -| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) | -| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) | -| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) | -| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) | -| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) | -| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | -| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | -| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) | -| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) | -| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](/examples/wanvideo/model_inference/LongCat-Video.py) | [code](/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) | -| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) | -| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) | -| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) | -| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) | -| [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) | -| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) | -| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) | -| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) | -| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) | -| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) | +| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) | +| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) | +| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) | +| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) | +| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) | +| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) | +| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) | +| [Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) | +| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) | +| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) | +| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) | +| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) | +| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) | +| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) | +| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) | +| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) | +| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | +| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | +| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) | +| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) | +| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) | +| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) | +| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) | +| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) | +| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) | +| [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) | +| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) | +| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) | +| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) | +| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) | +| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) | -* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/wanvideo/model_training/special/fp8_training/) -* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/wanvideo/model_training/special/split_training/) -* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/wanvideo/model_training/special/direct_distill/) +* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/) +* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/) +* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/direct_distill/) + +DeepSpeed ZeRO Stage 3 Training: The Wan series models support DeepSpeed ZeRO Stage 3 training, which partitions the model across multiple GPUs. Taking full parameter training of the Wan2.1-T2V-14B model as an example, the following modifications are required: + +* `--config_file examples/wanvideo/model_training/full/accelerate_config_zero3.yaml` +* `--initialize_model_on_cpu` ## Model Inference -Models are loaded via `WanVideoPipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models). +Models are loaded via `WanVideoPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models). Input parameters for `WanVideoPipeline` inference include: @@ -194,11 +199,11 @@ Input parameters for `WanVideoPipeline` inference include: * `tea_cache_model_id`: Model ID used by TeaCache. * `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`. -If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. +If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. ## Model Training -Wan series models are uniformly trained through [`examples/wanvideo/model_training/train.py`](/examples/wanvideo/model_training/train.py), and the script parameters include: +Wan series models are uniformly trained through [`examples/wanvideo/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/train.py), and the script parameters include: * General Training Parameters * Dataset Basic Configuration @@ -249,4 +254,4 @@ We have built a sample video dataset for your testing. You can download this dat modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset ``` -We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/). \ No newline at end of file +We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/). diff --git a/docs/en/Model_Details/Z-Image.md b/docs/en/Model_Details/Z-Image.md index 3673a52..38075cc 100644 --- a/docs/en/Model_Details/Z-Image.md +++ b/docs/en/Model_Details/Z-Image.md @@ -12,7 +12,7 @@ cd DiffSynth-Studio pip install -e . ``` -For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md). +For more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md). ## Quick Start @@ -50,18 +50,23 @@ image.save("image.jpg") ## Model Overview -| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | -| - | - | - | - | - | - | - | -| [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) | [code](/examples/z_image/model_inference/Z-Image-Turbo.py) | [code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py) | [code](/examples/z_image/model_training/full/Z-Image-Turbo.sh) | [code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py) | [code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh) | [code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py) | +|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training| +|-|-|-|-|-|-|-| +|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image.py)| +|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-i2L.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-| +|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)| +|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)| +|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)| +|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)| Special Training Scripts: -* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/z_image/model_training/special/differential_training/) -* Trajectory Imitation Distillation Training (Experimental Feature): [code](/examples/z_image/model_training/special/trajectory_imitation/) +* Differential LoRA Training: [doc](../Training/Differential_LoRA.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/differential_training/) +* Trajectory Imitation Distillation Training (Experimental Feature): [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/trajectory_imitation/) ## Model Inference -Models are loaded via `ZImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models). +Models are loaded via `ZImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models). Input parameters for `ZImagePipeline` inference include: @@ -75,12 +80,15 @@ Input parameters for `ZImagePipeline` inference include: * `seed`: Random seed. Default is `None`, meaning completely random. * `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results. * `num_inference_steps`: Number of inference steps, default value is 8. +* `controlnet_inputs`: Inputs for ControlNet models. +* `edit_image`: Edit images for image editing models, supporting multiple images. +* `positive_only_lora`: LoRA weights used only in positive prompts. -If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. +If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. ## Model Training -Z-Image series models are uniformly trained through [`examples/z_image/model_training/train.py`](/examples/z_image/model_training/train.py), and the script parameters include: +Z-Image series models are uniformly trained through [`examples/z_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/train.py), and the script parameters include: * General Training Parameters * Dataset Basic Configuration @@ -129,13 +137,13 @@ We have built a sample image dataset for your testing. You can download this dat modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset ``` -We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/). +We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/). Training Tips: * [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) is a distilled acceleration model. Therefore, direct training will quickly cause the model to lose its acceleration capability. The effect of inference with "acceleration configuration" (`num_inference_steps=8`, `cfg_scale=1`) becomes worse, while the effect of inference with "no acceleration configuration" (`num_inference_steps=30`, `cfg_scale=2`) becomes better. The following training and inference schemes can be adopted: - * Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + No Acceleration Configuration Inference - * Differential LoRA Training ([code](/examples/z_image/model_training/special/differential_training/)) + Acceleration Configuration Inference + * Standard SFT Training ([code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + No Acceleration Configuration Inference + * Differential LoRA Training ([code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/differential_training/)) + Acceleration Configuration Inference * An additional LoRA needs to be loaded in differential LoRA training, e.g., [ostris/zimage_turbo_training_adapter](https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter) - * Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Trajectory Imitation Distillation Training ([code](/examples/z_image/model_training/special/trajectory_imitation/)) + Acceleration Configuration Inference - * Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Load Distillation Acceleration LoRA During Inference ([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillPatch)) + Acceleration Configuration Inference + * Standard SFT Training ([code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Trajectory Imitation Distillation Training ([code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/trajectory_imitation/)) + Acceleration Configuration Inference + * Standard SFT Training ([code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Load Distillation Acceleration LoRA During Inference ([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillPatch)) + Acceleration Configuration Inference diff --git a/docs/en/Pipeline_Usage/Environment_Variables.md b/docs/en/Pipeline_Usage/Environment_Variables.md index 91d016f..281018b 100644 --- a/docs/en/Pipeline_Usage/Environment_Variables.md +++ b/docs/en/Pipeline_Usage/Environment_Variables.md @@ -28,7 +28,7 @@ Model download root directory. Can be set to any local path. If `local_model_pat ## `DIFFSYNTH_ATTENTION_IMPLEMENTATION` -Attention mechanism implementation method. Can be set to `flash_attention_3`, `flash_attention_2`, `sage_attention`, `xformers`, or `torch`. See [`./core/attention.md`](/docs/en/API_Reference/core/attention.md) for details. +Attention mechanism implementation method. Can be set to `flash_attention_3`, `flash_attention_2`, `sage_attention`, `xformers`, or `torch`. See [`./core/attention.md`](../API_Reference/core/attention.md) for details. ## `DIFFSYNTH_DISK_MAP_BUFFER_SIZE` diff --git a/docs/en/Pipeline_Usage/GPU_support.md b/docs/en/Pipeline_Usage/GPU_support.md index d1e77ef..e27d23a 100644 --- a/docs/en/Pipeline_Usage/GPU_support.md +++ b/docs/en/Pipeline_Usage/GPU_support.md @@ -2,7 +2,7 @@ `DiffSynth-Studio` supports various GPUs and NPUs. This document explains how to run model inference and training on these devices. -Before you begin, please follow the [Installation Guide](/docs/en/Pipeline_Usage/Setup.md) to install the required GPU/NPU dependencies. +Before you begin, please follow the [Installation Guide](../Pipeline_Usage/Setup.md) to install the required GPU/NPU dependencies. ## NVIDIA GPU @@ -58,6 +58,14 @@ video = pipe( save_video(video, "video.mp4", fps=15, quality=5) ``` +#### USP(Unified Sequence Parallel) +If you want to use this feature on NPU, please install additional third-party libraries as follows: +```shell +pip install git+https://github.com/feifeibear/long-context-attention.git +pip install git+https://github.com/xdit-project/xDiT.git +``` + + ### Training NPU startup script samples have been added for each type of model,the scripts are stored in the `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`. @@ -82,4 +90,5 @@ Set 0 or not set: indicates not enabling the binding function | Model | Parameter | Note | |----------------|---------------------------|-------------------| | Wan 14B series | --initialize_model_on_cpu | The 14B model needs to be initialized on the CPU | +| Qwen-Image series | --initialize_model_on_cpu | The model needs to be initialized on the CPU | | Z-Image series | --enable_npu_patch | Using NPU fusion operator to replace the corresponding operator in Z-image model to improve the performance of the model on NPU | \ No newline at end of file diff --git a/docs/en/Pipeline_Usage/Model_Inference.md b/docs/en/Pipeline_Usage/Model_Inference.md index e5a85a0..8cd3edf 100644 --- a/docs/en/Pipeline_Usage/Model_Inference.md +++ b/docs/en/Pipeline_Usage/Model_Inference.md @@ -22,7 +22,7 @@ pipe = QwenImagePipeline.from_pretrained( ) ``` -Where `torch_dtype` and `device` are computation precision and computation device (not model precision and device). `model_configs` can be configured in multiple ways for model paths. For how models are loaded internally in this project, please refer to [`diffsynth.core.loader`](/docs/en/API_Reference/core/loader.md). +Where `torch_dtype` and `device` are computation precision and computation device (not model precision and device). `model_configs` can be configured in multiple ways for model paths. For how models are loaded internally in this project, please refer to [`diffsynth.core.loader`](../API_Reference/core/loader.md).
@@ -34,7 +34,7 @@ Where `torch_dtype` and `device` are computation precision and computation devic > ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), > ``` > -> Model files are downloaded to the `./models` path by default, which can be modified through [environment variable DIFFSYNTH_MODEL_BASE_PATH](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path). +> Model files are downloaded to the `./models` path by default, which can be modified through [environment variable DIFFSYNTH_MODEL_BASE_PATH](../Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path).
@@ -61,7 +61,7 @@ Where `torch_dtype` and `device` are computation precision and computation devic
-By default, even after models have been downloaded, the program will still query remotely for missing files. To completely disable remote requests, set [environment variable DIFFSYNTH_SKIP_DOWNLOAD](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`. +By default, even after models have been downloaded, the program will still query remotely for missing files. To completely disable remote requests, set [environment variable DIFFSYNTH_SKIP_DOWNLOAD](../Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`. ```shell import os @@ -69,7 +69,7 @@ os.environ["DIFFSYNTH_SKIP_DOWNLOAD"] = "True" import diffsynth ``` -To download models from [HuggingFace](https://huggingface.co/), set [environment variable DIFFSYNTH_DOWNLOAD_SOURCE](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source) to `huggingface`. +To download models from [HuggingFace](https://huggingface.co/), set [environment variable DIFFSYNTH_DOWNLOAD_SOURCE](../Pipeline_Usage/Environment_Variables.md#diffsynth_download_source) to `huggingface`. ```shell import os @@ -102,4 +102,65 @@ image.save("image.jpg") Each model `Pipeline` has different input parameters. Please refer to the documentation for each model. -If the model parameters are too large, causing insufficient VRAM, please enable [VRAM management](/docs/en/Pipeline_Usage/VRAM_management.md). \ No newline at end of file +If the model parameters are too large, causing insufficient VRAM, please enable [VRAM management](../Pipeline_Usage/VRAM_management.md). + +## Loading LoRA + +LoRA is a lightweight model training method that produces a small number of parameters to extend model capabilities. DiffSynth-Studio supports two ways to load LoRA: cold loading and hot loading. + +* Cold loading: When the base model does not have [VRAM management](../Pipeline_Usage/VRAM_management.md) enabled, LoRA will be fused into the base model weights. In this case, inference speed remains unchanged, but LoRA cannot be unloaded after loading. + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +lora = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1", origin_file_pattern="model.safetensors") +pipe.load_lora(pipe.dit, lora, alpha=1) +prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal." +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +* Hot loading: When the base model has [VRAM management](../Pipeline_Usage/VRAM_management.md) enabled, LoRA will not be fused into the base model weights. In this case, inference speed will be slower, but LoRA can be unloaded through `pipe.clear_lora()` after loading. + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cuda", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +lora = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1", origin_file_pattern="model.safetensors") +pipe.load_lora(pipe.dit, lora, alpha=1) +prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal." +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +pipe.clear_lora() +``` diff --git a/docs/en/Pipeline_Usage/Model_Training.md b/docs/en/Pipeline_Usage/Model_Training.md index 3c5bffd..bf797bc 100644 --- a/docs/en/Pipeline_Usage/Model_Training.md +++ b/docs/en/Pipeline_Usage/Model_Training.md @@ -65,7 +65,7 @@ image_1.jpg,"a dog" image_2.jpg,"a cat" ``` -We have built sample datasets for your testing. To understand how the universal dataset architecture is implemented, please refer to [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md). +We have built sample datasets for your testing. To understand how the universal dataset architecture is implemented, please refer to [`diffsynth.core.data`](../API_Reference/core/data.md).
@@ -93,7 +93,7 @@ We have built sample datasets for your testing. To understand how the universal ## Loading Models -Similar to [model loading during inference](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models), we support multiple ways to configure model paths, and the two methods can be mixed. +Similar to [model loading during inference](../Pipeline_Usage/Model_Inference.md#loading-models), we support multiple ways to configure model paths, and the two methods can be mixed.
@@ -115,9 +115,9 @@ Similar to [model loading during inference](/docs/en/Pipeline_Usage/Model_Infere > --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" > ``` > -> Model files are downloaded to the `./models` path by default, which can be modified through [environment variable DIFFSYNTH_MODEL_BASE_PATH](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path). +> Model files are downloaded to the `./models` path by default, which can be modified through [environment variable DIFFSYNTH_MODEL_BASE_PATH](../Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path). > -> By default, even after models have been downloaded, the program will still query remotely for missing files. To completely disable remote requests, set [environment variable DIFFSYNTH_SKIP_DOWNLOAD](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`. +> By default, even after models have been downloaded, the program will still query remotely for missing files. To completely disable remote requests, set [environment variable DIFFSYNTH_SKIP_DOWNLOAD](../Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`.
@@ -237,11 +237,11 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera ## Training Considerations -* In addition to the `csv` format, dataset metadata also supports `json` and `jsonl` formats. For how to choose the best metadata format, please refer to [/docs/en/API_Reference/core/data.md#metadata](/docs/en/API_Reference/core/data.md#metadata) +* In addition to the `csv` format, dataset metadata also supports `json` and `jsonl` formats. For how to choose the best metadata format, please refer to [../API_Reference/core/data.md#metadata](../API_Reference/core/data.md#metadata) * Training effectiveness is usually strongly correlated with training steps and weakly correlated with epoch count. Therefore, we recommend using the `--save_steps` parameter to save model files at training step intervals. * When data volume * `dataset_repeat` exceeds $10^9$, we observed that the dataset speed becomes significantly slower, which seems to be a `PyTorch` bug. We are not sure if newer versions of `PyTorch` have fixed this issue. * For learning rate `--learning_rate`, it is recommended to set to `1e-4` in LoRA training and `1e-5` in full training. -* The training framework does not support batch size > 1. The reasons are complex. See [Q&A: Why doesn't the training framework support batch size > 1?](/docs/en/QA.md#why-doesnt-the-training-framework-support-batch-size--1) +* The training framework does not support batch size > 1. The reasons are complex. See [Q&A: Why doesn't the training framework support batch size > 1?](../QA.md#why-doesnt-the-training-framework-support-batch-size--1) * Some models contain redundant parameters. For example, the text encoding part of the last layer of Qwen-Image's DiT part. When training these models, `--find_unused_parameters` needs to be set to avoid errors in multi-GPU training. For compatibility with community models, we do not intend to remove these redundant parameters. * The loss function value of Diffusion models has little relationship with actual effects. Therefore, we do not record loss function values during training. We recommend setting `--num_epochs` to a sufficiently large value, testing while training, and manually closing the training program after the effect converges. -* `--use_gradient_checkpointing` is usually enabled unless GPU VRAM is sufficient; `--use_gradient_checkpointing_offload` is enabled as needed. See [`diffsynth.core.gradient`](/docs/en/API_Reference/core/gradient.md) for details. \ No newline at end of file +* `--use_gradient_checkpointing` is usually enabled unless GPU VRAM is sufficient; `--use_gradient_checkpointing_offload` is enabled as needed. See [`diffsynth.core.gradient`](../API_Reference/core/gradient.md) for details. \ No newline at end of file diff --git a/docs/en/Pipeline_Usage/Setup.md b/docs/en/Pipeline_Usage/Setup.md index dc06364..2e45668 100644 --- a/docs/en/Pipeline_Usage/Setup.md +++ b/docs/en/Pipeline_Usage/Setup.md @@ -41,7 +41,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6 # x86 pip install -e .[npu] --extra-index-url "https://download.pytorch.org/whl/cpu" -When using Ascend NPU, please replace `"cuda"` with `"npu"` in your Python code. For details, see [NPU Support](/docs/en/Pipeline_Usage/GPU_support.md#ascend-npu). +When using Ascend NPU, please replace `"cuda"` with `"npu"` in your Python code. For details, see [NPU Support](../Pipeline_Usage/GPU_support.md#ascend-npu). ## Other Installation Issues diff --git a/docs/en/Pipeline_Usage/VRAM_management.md b/docs/en/Pipeline_Usage/VRAM_management.md index ecf5379..9a30000 100644 --- a/docs/en/Pipeline_Usage/VRAM_management.md +++ b/docs/en/Pipeline_Usage/VRAM_management.md @@ -140,7 +140,7 @@ image.save("image.jpg") In more extreme cases, when memory is also insufficient to store the entire model, the Disk Offload feature allows lazy loading of model parameters, meaning each Layer of the model only reads the corresponding parameters from disk when the forward function is called. When enabling this feature, we recommend using high-speed SSD drives. -Disk Offload is a very special VRAM management solution that only supports `.safetensors` format files, not `.bin`, `.pth`, `.ckpt`, or other binary files, and does not support [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape. +Disk Offload is a very special VRAM management solution that only supports `.safetensors` format files, not `.bin`, `.pth`, `.ckpt`, or other binary files, and does not support [state dict converter](../Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape. ```python from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig @@ -196,7 +196,7 @@ Specifically, the VRAM management module divides model Layers into the following * Preparing: Intermediate state between Onload and Computation. A temporary storage state when VRAM allows. This state is controlled by the VRAM management mechanism and enters this state if and only if [vram_limit is set to unlimited] or [vram_limit is set and there is spare VRAM] * Computation: The model is being computed. This state is controlled by the VRAM management mechanism and is temporarily entered only during `forward` -If you are a model developer and want to control the VRAM management granularity of a specific model, please refer to [../Developer_Guide/Enabling_VRAM_management.md](/docs/en/Developer_Guide/Enabling_VRAM_management.md). +If you are a model developer and want to control the VRAM management granularity of a specific model, please refer to [../Developer_Guide/Enabling_VRAM_management.md](../Developer_Guide/Enabling_VRAM_management.md). ## Best Practices diff --git a/docs/en/QA.md b/docs/en/QA.md index fe75460..ae97d1b 100644 --- a/docs/en/QA.md +++ b/docs/en/QA.md @@ -25,4 +25,11 @@ Even with suitable hardware conditions, we currently have no plans to support na * The main challenge of native FP8 precision training is precision overflow caused by gradient explosion. To ensure training stability, the model structure needs to be redesigned accordingly. However, no model developers are willing to do so at present. * Additionally, models trained with native FP8 precision can only be computed with BF16 precision during inference without Hopper architecture GPUs, theoretically resulting in generation quality inferior to FP8. -Therefore, native FP8 precision training technology is extremely immature. We will observe the technological developments in the open-source community. \ No newline at end of file +Therefore, native FP8 precision training technology is extremely immature. We will observe the technological developments in the open-source community. + +## How to dynamically load LoRA models during inference? + +We support two loading methods for LoRA models. See [LoRA Loading](./Pipeline_Usage/Model_Inference.md#loading-lora) for details: + +* Cold Loading: When [VRAM Management](./Pipeline_Usage/VRAM_management.md) is not enabled for the base model, LoRA will be fused into the base model weights. In this case, inference speed remains unchanged, and LoRA cannot be unloaded after loading. +* Hot Loading: When [VRAM Management](./Pipeline_Usage/VRAM_management.md) is enabled for the base model, LoRA will not be fused into the base model weights. In this case, inference speed will slow down, and LoRA can be unloaded after loading via `pipe.clear_lora()`. diff --git a/docs/en/README.md b/docs/en/README.md index 39ae439..aac6000 100644 --- a/docs/en/README.md +++ b/docs/en/README.md @@ -26,58 +26,58 @@ graph LR; This section introduces the basic usage of `DiffSynth-Studio`, including how to enable VRAM management for inference on GPUs with extremely low VRAM, and how to train various base models, LoRAs, ControlNets, and other models. -* [Installation Dependencies](/docs/en/Pipeline_Usage/Setup.md) -* [Model Inference](/docs/en/Pipeline_Usage/Model_Inference.md) -* [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md) -* [Model Training](/docs/en/Pipeline_Usage/Model_Training.md) -* [Environment Variables](/docs/en/Pipeline_Usage/Environment_Variables.md) -* [GPU/NPU Support](/docs/en/Pipeline_Usage/GPU_support.md) +* [Installation Dependencies](./Pipeline_Usage/Setup.md) +* [Model Inference](./Pipeline_Usage/Model_Inference.md) +* [VRAM Management](./Pipeline_Usage/VRAM_management.md) +* [Model Training](./Pipeline_Usage/Model_Training.md) +* [Environment Variables](./Pipeline_Usage/Environment_Variables.md) +* [GPU/NPU Support](./Pipeline_Usage/GPU_support.md) ## Section 2: Model Details This section introduces the Diffusion models supported by `DiffSynth-Studio`. Some model pipelines feature special functionalities such as controllable generation and parallel acceleration. -* [FLUX.1](/docs/en/Model_Details/FLUX.md) -* [Wan](/docs/en/Model_Details/Wan.md) -* [Qwen-Image](/docs/en/Model_Details/Qwen-Image.md) -* [FLUX.2](/docs/en/Model_Details/FLUX2.md) -* [Z-Image](/docs/en/Model_Details/Z-Image.md) +* [FLUX.1](./Model_Details/FLUX.md) +* [Wan](./Model_Details/Wan.md) +* [Qwen-Image](./Model_Details/Qwen-Image.md) +* [FLUX.2](./Model_Details/FLUX2.md) +* [Z-Image](./Model_Details/Z-Image.md) ## Section 3: Training Framework This section introduces the design philosophy of the training framework in `DiffSynth-Studio`, helping developers understand the principles of Diffusion model training algorithms. -* [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md) -* [Standard Supervised Training](/docs/en/Training/Supervised_Fine_Tuning.md) -* [Enabling FP8 Precision in Training](/docs/en/Training/FP8_Precision.md) -* [End-to-End Distillation Accelerated Training](/docs/en/Training/Direct_Distill.md) -* [Two-Stage Split Training](/docs/en/Training/Split_Training.md) -* [Differential LoRA Training](/docs/en/Training/Differential_LoRA.md) +* [Basic Principles of Diffusion Models](./Training/Understanding_Diffusion_models.md) +* [Standard Supervised Training](./Training/Supervised_Fine_Tuning.md) +* [Enabling FP8 Precision in Training](./Training/FP8_Precision.md) +* [End-to-End Distillation Accelerated Training](./Training/Direct_Distill.md) +* [Two-Stage Split Training](./Training/Split_Training.md) +* [Differential LoRA Training](./Training/Differential_LoRA.md) ## Section 4: Model Integration This section introduces how to integrate models into `DiffSynth-Studio` to utilize the framework's basic functions, helping developers provide support for new models in this project or perform inference and training of private models. -* [Integrating Model Architecture](/docs/en/Developer_Guide/Integrating_Your_Model.md) -* [Building a Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md) -* [Enabling Fine-Grained VRAM Management](/docs/en/Developer_Guide/Enabling_VRAM_management.md) -* [Model Training Integration](/docs/en/Developer_Guide/Training_Diffusion_Models.md) +* [Integrating Model Architecture](./Developer_Guide/Integrating_Your_Model.md) +* [Building a Pipeline](./Developer_Guide/Building_a_Pipeline.md) +* [Enabling Fine-Grained VRAM Management](./Developer_Guide/Enabling_VRAM_management.md) +* [Model Training Integration](./Developer_Guide/Training_Diffusion_Models.md) ## Section 5: API Reference This section introduces the independent core module `diffsynth.core` in `DiffSynth-Studio`, explaining how internal functions are designed and operate. Developers can use these functional modules in other codebase developments if needed. -* [`diffsynth.core.attention`](/docs/en/API_Reference/core/attention.md): Attention mechanism implementation -* [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md): Data processing operators and general datasets -* [`diffsynth.core.gradient`](/docs/en/API_Reference/core/gradient.md): Gradient checkpointing -* [`diffsynth.core.loader`](/docs/en/API_Reference/core/loader.md): Model download and loading -* [`diffsynth.core.vram`](/docs/en/API_Reference/core/vram.md): VRAM management +* [`diffsynth.core.attention`](./API_Reference/core/attention.md): Attention mechanism implementation +* [`diffsynth.core.data`](./API_Reference/core/data.md): Data processing operators and general datasets +* [`diffsynth.core.gradient`](./API_Reference/core/gradient.md): Gradient checkpointing +* [`diffsynth.core.loader`](./API_Reference/core/loader.md): Model download and loading +* [`diffsynth.core.vram`](./API_Reference/core/vram.md): VRAM management ## Section 6: Academic Guide This section introduces how to use `DiffSynth-Studio` to train new models, helping researchers explore new model technologies. -* Training models from scratch 【coming soon】 +* [Training models from scratch](./Research_Tutorial/train_from_scratch.md) * Inference improvement techniques 【coming soon】 * Designing controllable generation models 【coming soon】 * Creating new training paradigms 【coming soon】 @@ -86,4 +86,4 @@ This section introduces how to use `DiffSynth-Studio` to train new models, helpi This section summarizes common developer questions. If you encounter issues during usage or development, please refer to this section. If you still cannot resolve the problem, please submit an issue on GitHub. -* [Frequently Asked Questions](/docs/en/QA.md) \ No newline at end of file +* [Frequently Asked Questions](./QA.md) \ No newline at end of file diff --git a/docs/en/Research_Tutorial/train_from_scratch.md b/docs/en/Research_Tutorial/train_from_scratch.md new file mode 100644 index 0000000..527664c --- /dev/null +++ b/docs/en/Research_Tutorial/train_from_scratch.md @@ -0,0 +1,476 @@ +# Training Models from Scratch + +DiffSynth-Studio's training engine supports training foundation models from scratch. This article introduces how to train a small text-to-image model with only 0.1B parameters from scratch. + +## 1. Building Model Architecture + +### 1.1 Diffusion Model + +From UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) to DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206), the mainstream model architectures of Diffusion have undergone multiple evolutions. Typically, a Diffusion model's inputs include: + +* Image tensor (`latents`): The encoding of images, generated by the VAE model, containing partial noise +* Text tensor (`prompt_embeds`): The encoding of text, generated by the text encoder +* Timestep (`timestep`): A scalar used to mark which stage of the Diffusion process we are currently at + +The model's output is a tensor with the same shape as the image tensor, representing the denoising direction predicted by the model. For details about Diffusion model theory, please refer to [Basic Principles of Diffusion Models](../Training/Understanding_Diffusion_models.md). In this article, we build a DiT model with only 0.1B parameters: `AAADiT`. + +
+Model Architecture Code + +```python +import torch, accelerate +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange, repeat + +from transformers import AutoProcessor, AutoTokenizer +from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model +from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task +from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit +from diffsynth.models.general_modules import TimestepEmbeddings +from diffsynth.models.z_image_text_encoder import ZImageTextEncoder +from diffsynth.models.flux2_vae import Flux2VAE + + +class AAAPositionalEmbedding(torch.nn.Module): + def __init__(self, height=16, width=16, dim=1024): + super().__init__() + self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width))) + self.text_emb = torch.nn.Parameter(torch.randn((dim,))) + + def forward(self, image, text): + height, width = image.shape[-2:] + image_emb = self.image_emb.to(device=image.device, dtype=image.dtype) + image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear") + image_emb = rearrange(image_emb, "B C H W -> B (H W) C") + text_emb = self.text_emb.to(device=text.device, dtype=text.dtype) + text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1]) + emb = torch.concat([image_emb, text_emb], dim=1) + return emb + + +class AAABlock(torch.nn.Module): + def __init__(self, dim=1024, num_heads=32): + super().__init__() + self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False) + self.to_q = torch.nn.Linear(dim, dim) + self.to_k = torch.nn.Linear(dim, dim) + self.to_v = torch.nn.Linear(dim, dim) + self.to_out = torch.nn.Linear(dim, dim) + self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False) + self.ff = torch.nn.Sequential( + torch.nn.Linear(dim, dim*3), + torch.nn.SiLU(), + torch.nn.Linear(dim*3, dim), + ) + self.to_gate = torch.nn.Linear(dim, dim * 2) + self.num_heads = num_heads + + def attention(self, emb, pos_emb): + emb = self.norm_attn(emb + pos_emb) + q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb) + emb = attention_forward( + q, k, v, + q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)", + dims={"n": self.num_heads}, + ) + emb = self.to_out(emb) + return emb + + def feed_forward(self, emb, pos_emb): + emb = self.norm_mlp(emb + pos_emb) + emb = self.ff(emb) + return emb + + def forward(self, emb, pos_emb, t_emb): + gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1) + emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn) + emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp) + return emb + + +class AAADiT(torch.nn.Module): + def __init__(self, dim=1024): + super().__init__() + self.pos_embedder = AAAPositionalEmbedding(dim=dim) + self.timestep_embedder = TimestepEmbeddings(256, dim) + self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim)) + self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim)) + self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)]) + self.proj_out = torch.nn.Linear(dim, 128) + + def forward( + self, + latents, + prompt_embeds, + timestep, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + pos_emb = self.pos_embedder(latents, prompt_embeds) + t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1) + image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C")) + text = self.text_embedder(prompt_embeds) + emb = torch.concat([image, text], dim=1) + for block_id, block in enumerate(self.blocks): + emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + emb=emb, + pos_emb=pos_emb, + t_emb=t_emb, + ) + emb = emb[:, :latents.shape[-1] * latents.shape[-2]] + emb = self.proj_out(emb) + emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1]) + return emb +``` + +
+ +### 1.2 Encoder-Decoder Models + +Besides the Diffusion model used for denoising, we also need two other models: + +* Text Encoder: Used to encode text into tensors. We adopt the [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) model. +* VAE Encoder-Decoder: The encoder part is used to encode images into tensors, and the decoder part is used to decode image tensors into images. We adopt the VAE model from [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B). + +The architectures of these two models are already integrated in DiffSynth-Studio, located at [/diffsynth/models/z_image_text_encoder.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/models/z_image_text_encoder.py) and [/diffsynth/models/flux2_vae.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/models/flux2_vae.py), so we don't need to modify any code. + +## 2. Building Pipeline + +We introduced how to build a model Pipeline in the document [Integrating Pipeline](../Developer_Guide/Building_a_Pipeline.md). For the model in this article, we also need to build a Pipeline to connect the text encoder, Diffusion model, and VAE encoder-decoder. + +
+Pipeline Code + +```python +class AAAImagePipeline(BasePipeline): + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("FLUX.2") + self.text_encoder: ZImageTextEncoder = None + self.dit: AAADiT = None + self.vae: Flux2VAE = None + self.tokenizer: AutoProcessor = None + self.in_iteration_models = ("dit",) + self.units = [ + AAAUnit_PromptEmbedder(), + AAAUnit_NoiseInitializer(), + AAAUnit_InputImageEmbedder(), + ] + self.model_fn = model_fn_aaa + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = None, + vram_limit: float = None, + ): + # Initialize pipeline + pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder") + pipe.dit = model_pool.fetch_model("aaa_dit") + pipe.vae = model_pool.fetch_model("flux2_vae") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 30, + # Progress bar + progress_bar_cmd = tqdm, + ): + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16) + + # Parameters + inputs_posi = {"prompt": prompt} + inputs_nega = {"negative_prompt": negative_prompt} + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + image = self.vae.decode(inputs_shared["latents"]) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class AAAUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_embeds",), + onload_model_names=("text_encoder",) + ) + self.hidden_states_layers = (-1,) + + def process(self, pipe: AAAImagePipeline, prompt): + pipe.load_models_to_device(self.onload_model_names) + text = pipe.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device) + output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False) + prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1) + return {"prompt_embeds": prompt_embeds} + + +class AAAUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} + + +class AAAUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: AAAImagePipeline, input_image, noise): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(input_image) + input_latents = pipe.vae.encode(image) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +def model_fn_aaa( + dit: AAADiT, + latents=None, + prompt_embeds=None, + timestep=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + model_output = dit( + latents, + prompt_embeds, + timestep, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + return model_output +``` + +
+ +## 3. Preparing Dataset + +To quickly verify training effectiveness, we use the dataset [Pokemon-First Generation](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1), which is reproduced from the open-source project [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh), containing 151 first-generation Pokemon from Bulbasaur to Mew. If you want to use other datasets, please refer to the document [Preparing Datasets](../Pipeline_Usage/Model_Training.md#preparing-datasets) and [`diffsynth.core.data`](../API_Reference/core/data.md). + +```shell +modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data +``` + +### 4. Start Training + +The training process can be quickly implemented using Pipeline. We have placed the complete code at [../Research_Tutorial/train_from_scratch.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/docs/en/Research_Tutorial/train_from_scratch.py), which can be directly started with `python docs/en/Research_Tutorial/train_from_scratch.py` for single GPU training. + +To enable multi-GPU parallel training, please run `accelerate config` to set relevant parameters, then use the command `accelerate launch docs/en/Research_Tutorial/train_from_scratch.py` to start training. + +This training script has no stopping condition, please manually close it when needed. The model converges after training approximately 60,000 steps, requiring 10-20 hours for single GPU training. + +
+Training Code + +```python +class AAATrainingModule(DiffusionTrainingModule): + def __init__(self, device): + super().__init__() + self.pipe = AAAImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=device, + model_configs=[ + ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"), + ) + self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device) + self.pipe.freeze_except(["dit"]) + self.pipe.scheduler.set_timesteps(1000, training=True) + + def forward(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + "cfg_scale": 1, + "use_gradient_checkpointing": False, + "use_gradient_checkpointing_offload": False, + } + for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi) + return loss + + +if __name__ == "__main__": + accelerator = accelerate.Accelerator(gradient_accumulation_steps=1) + dataset = UnifiedDataset( + base_path="data/images", + metadata_path="data/metadata_merged.csv", + max_data_items=10000000, + data_file_keys=("image",), + main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256) + ) + model = AAATrainingModule(device=accelerator.device) + model_logger = ModelLogger( + "models/AAA/v1", + remove_prefix_in_ckpt="pipe.dit.", + ) + launch_training_task( + accelerator, dataset, model, model_logger, + learning_rate=2e-4, + num_workers=4, + save_steps=50000, + num_epochs=999999, + ) +``` + +
+ +## 5. Verifying Training Results + +If you don't want to wait for the model training to complete, you can directly download [our pre-trained model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). + +```shell +modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel +``` + +Loading the model + +```python +from diffsynth import load_model + +pipe = AAAImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"), +) +pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda") +``` + +Model inference, generating the first-generation Pokemon "starter trio". At this point, the images generated by the model basically match the training data. + +```python +for seed, prompt in enumerate([ + "green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws", + "orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws", + "blue, beige, brown, turtle, water type, shell, big eyes, short limbs, curled tail", +]): + image = pipe( + prompt=prompt, + negative_prompt=" ", + num_inference_steps=30, + cfg_scale=10, + seed=seed, + height=256, width=256, + ) + image.save(f"image_{seed}.jpg") +``` + +|![Image](https://github.com/user-attachments/assets/3c620fbf-5d28-4a1a-b887-519d85ac7d1c)|![Image](https://github.com/user-attachments/assets/909efd4c-9e61-4b33-9321-39da0e499b00)|![Image](https://github.com/user-attachments/assets/f3474bcd-b474-4a90-a1ea-579f67e161e3)| +|-|-|-| + +Model inference, generating Pokemon with "sharp claws". At this point, different random seeds can produce different image results. + +```python +for seed, prompt in enumerate([ + "sharp claws", + "sharp claws", + "sharp claws", +]): + image = pipe( + prompt=prompt, + negative_prompt=" ", + num_inference_steps=30, + cfg_scale=10, + seed=seed+4, + height=256, width=256, + ) + image.save(f"image_sharp_claws_{seed}.jpg") +``` + +|![Image](https://github.com/user-attachments/assets/94862edd-96ae-4276-a38f-795249f11a13)|![Image](https://github.com/user-attachments/assets/b2291f23-20ba-42de-8bfd-76cb4afc6eea)|![Image](https://github.com/user-attachments/assets/f2aab9a4-85ec-498e-8039-648b1289796e)| +|-|-|-| + +Now, we have obtained a 0.1B small text-to-image model. This model can already generate 151 Pokemon, but cannot generate other image content. If you increase the amount of data, model parameters, and number of GPUs based on this, you can train a more powerful text-to-image model! \ No newline at end of file diff --git a/docs/en/Research_Tutorial/train_from_scratch.py b/docs/en/Research_Tutorial/train_from_scratch.py new file mode 100644 index 0000000..328c24d --- /dev/null +++ b/docs/en/Research_Tutorial/train_from_scratch.py @@ -0,0 +1,341 @@ +import torch, accelerate +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange, repeat + +from transformers import AutoProcessor, AutoTokenizer +from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model +from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task +from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit +from diffsynth.models.general_modules import TimestepEmbeddings +from diffsynth.models.z_image_text_encoder import ZImageTextEncoder +from diffsynth.models.flux2_vae import Flux2VAE + + +class AAAPositionalEmbedding(torch.nn.Module): + def __init__(self, height=16, width=16, dim=1024): + super().__init__() + self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width))) + self.text_emb = torch.nn.Parameter(torch.randn((dim,))) + + def forward(self, image, text): + height, width = image.shape[-2:] + image_emb = self.image_emb.to(device=image.device, dtype=image.dtype) + image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear") + image_emb = rearrange(image_emb, "B C H W -> B (H W) C") + text_emb = self.text_emb.to(device=text.device, dtype=text.dtype) + text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1]) + emb = torch.concat([image_emb, text_emb], dim=1) + return emb + + +class AAABlock(torch.nn.Module): + def __init__(self, dim=1024, num_heads=32): + super().__init__() + self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False) + self.to_q = torch.nn.Linear(dim, dim) + self.to_k = torch.nn.Linear(dim, dim) + self.to_v = torch.nn.Linear(dim, dim) + self.to_out = torch.nn.Linear(dim, dim) + self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False) + self.ff = torch.nn.Sequential( + torch.nn.Linear(dim, dim*3), + torch.nn.SiLU(), + torch.nn.Linear(dim*3, dim), + ) + self.to_gate = torch.nn.Linear(dim, dim * 2) + self.num_heads = num_heads + + def attention(self, emb, pos_emb): + emb = self.norm_attn(emb + pos_emb) + q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb) + emb = attention_forward( + q, k, v, + q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)", + dims={"n": self.num_heads}, + ) + emb = self.to_out(emb) + return emb + + def feed_forward(self, emb, pos_emb): + emb = self.norm_mlp(emb + pos_emb) + emb = self.ff(emb) + return emb + + def forward(self, emb, pos_emb, t_emb): + gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1) + emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn) + emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp) + return emb + + +class AAADiT(torch.nn.Module): + def __init__(self, dim=1024): + super().__init__() + self.pos_embedder = AAAPositionalEmbedding(dim=dim) + self.timestep_embedder = TimestepEmbeddings(256, dim) + self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim)) + self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim)) + self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)]) + self.proj_out = torch.nn.Linear(dim, 128) + + def forward( + self, + latents, + prompt_embeds, + timestep, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + pos_emb = self.pos_embedder(latents, prompt_embeds) + t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1) + image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C")) + text = self.text_embedder(prompt_embeds) + emb = torch.concat([image, text], dim=1) + for block_id, block in enumerate(self.blocks): + emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + emb=emb, + pos_emb=pos_emb, + t_emb=t_emb, + ) + emb = emb[:, :latents.shape[-1] * latents.shape[-2]] + emb = self.proj_out(emb) + emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1]) + return emb + + +class AAAImagePipeline(BasePipeline): + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("FLUX.2") + self.text_encoder: ZImageTextEncoder = None + self.dit: AAADiT = None + self.vae: Flux2VAE = None + self.tokenizer: AutoProcessor = None + self.in_iteration_models = ("dit",) + self.units = [ + AAAUnit_PromptEmbedder(), + AAAUnit_NoiseInitializer(), + AAAUnit_InputImageEmbedder(), + ] + self.model_fn = model_fn_aaa + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = None, + vram_limit: float = None, + ): + # Initialize pipeline + pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder") + pipe.dit = model_pool.fetch_model("aaa_dit") + pipe.vae = model_pool.fetch_model("flux2_vae") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 30, + # Progress bar + progress_bar_cmd = tqdm, + ): + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16) + + # Parameters + inputs_posi = {"prompt": prompt} + inputs_nega = {"negative_prompt": negative_prompt} + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + image = self.vae.decode(inputs_shared["latents"]) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class AAAUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_embeds",), + onload_model_names=("text_encoder",) + ) + self.hidden_states_layers = (-1,) + + def process(self, pipe: AAAImagePipeline, prompt): + pipe.load_models_to_device(self.onload_model_names) + text = pipe.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device) + output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False) + prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1) + return {"prompt_embeds": prompt_embeds} + + +class AAAUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} + + +class AAAUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: AAAImagePipeline, input_image, noise): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(input_image) + input_latents = pipe.vae.encode(image) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +def model_fn_aaa( + dit: AAADiT, + latents=None, + prompt_embeds=None, + timestep=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + model_output = dit( + latents, + prompt_embeds, + timestep, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + return model_output + + +class AAATrainingModule(DiffusionTrainingModule): + def __init__(self, device): + super().__init__() + self.pipe = AAAImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=device, + model_configs=[ + ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"), + ) + self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device) + self.pipe.freeze_except(["dit"]) + self.pipe.scheduler.set_timesteps(1000, training=True) + + def forward(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + "cfg_scale": 1, + "use_gradient_checkpointing": False, + "use_gradient_checkpointing_offload": False, + } + for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi) + return loss + + +if __name__ == "__main__": + accelerator = accelerate.Accelerator(gradient_accumulation_steps=1) + dataset = UnifiedDataset( + base_path="data/images", + metadata_path="data/metadata_merged.csv", + max_data_items=10000000, + data_file_keys=("image",), + main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256) + ) + model = AAATrainingModule(device=accelerator.device) + model_logger = ModelLogger( + "models/AAA/v1", + remove_prefix_in_ckpt="pipe.dit.", + ) + launch_training_task( + accelerator, dataset, model, model_logger, + learning_rate=2e-4, + num_workers=4, + save_steps=50000, + num_epochs=999999, + ) \ No newline at end of file diff --git a/docs/en/Training/Differential_LoRA.md b/docs/en/Training/Differential_LoRA.md index febe507..67bea80 100644 --- a/docs/en/Training/Differential_LoRA.md +++ b/docs/en/Training/Differential_LoRA.md @@ -8,8 +8,8 @@ We were unable to identify the original proposer of differential LoRA training, Assume we have two similar-content images: Image 1 and Image 2. For example, both images contain a car, but Image 1 has fewer details while Image 2 has more details. In differential LoRA training, we perform two-step training: -* Train LoRA 1 using Image 1 as training data with [standard supervised training](/docs/en/Training/Supervised_Fine_Tuning.md) -* Train LoRA 2 using Image 2 as training data, after integrating LoRA 1 into the base model, with [standard supervised training](/docs/en/Training/Supervised_Fine_Tuning.md) +* Train LoRA 1 using Image 1 as training data with [standard supervised training](../Training/Supervised_Fine_Tuning.md) +* Train LoRA 2 using Image 2 as training data, after integrating LoRA 1 into the base model, with [standard supervised training](../Training/Supervised_Fine_Tuning.md) In the first training step, since there is only one training image, the LoRA model easily overfits. Therefore, after training, LoRA 1 will cause the model to generate Image 1 without hesitation, regardless of the random seed. In the second training step, the LoRA model overfits again. Thus, after training, with the combined effect of LoRA 1 and LoRA 2, the model will generate Image 2 without hesitation. In short: diff --git a/docs/en/Training/Direct_Distill.md b/docs/en/Training/Direct_Distill.md index 4cbeb59..34cfabb 100644 --- a/docs/en/Training/Direct_Distill.md +++ b/docs/en/Training/Direct_Distill.md @@ -44,7 +44,7 @@ Click on the model links to go to the model pages and view the model effects. ## Using Distillation Accelerated Training in the Training Framework -First, you need to generate training data. Please refer to the [Model Inference](/docs/en/Pipeline_Usage/Model_Inference.md) section to write inference code and generate training data with a sufficient number of inference steps. +First, you need to generate training data. Please refer to the [Model Inference](../Pipeline_Usage/Model_Inference.md) section to write inference code and generate training data with a sufficient number of inference steps. Taking Qwen-Image as an example, the following code can generate an image: @@ -67,7 +67,7 @@ image = pipe(prompt, seed=0, num_inference_steps=40) image.save("image.jpg") ``` -Then, we compile the necessary information into [metadata files](/docs/en/API_Reference/core/data.md#metadata): +Then, we compile the necessary information into [metadata files](../API_Reference/core/data.md#metadata): ```csv image,prompt,seed,rand_device,num_inference_steps,cfg_scale @@ -86,11 +86,11 @@ Then start LoRA distillation accelerated training: bash examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh ``` -Please note that in the [training script parameters](/docs/en/Pipeline_Usage/Model_Training.md#script-parameters), the image resolution setting for the dataset should avoid triggering scaling processing. When setting `--height` and `--width` to enable fixed resolution, all training data must be generated with exactly the same width and height. When setting `--max_pixels` to enable dynamic resolution, the value of `--max_pixels` must be greater than or equal to the pixel area of any training image. +Please note that in the [training script parameters](../Pipeline_Usage/Model_Training.md#script-parameters), the image resolution setting for the dataset should avoid triggering scaling processing. When setting `--height` and `--width` to enable fixed resolution, all training data must be generated with exactly the same width and height. When setting `--max_pixels` to enable dynamic resolution, the value of `--max_pixels` must be greater than or equal to the pixel area of any training image. ## Framework Design Concept -Compared to [Standard Supervised Training](/docs/en/Training/Supervised_Fine_Tuning.md), Direct Distillation only differs in the training loss function. The loss function for Direct Distillation is `DirectDistillLoss` in `diffsynth.diffusion.loss`. +Compared to [Standard Supervised Training](../Training/Supervised_Fine_Tuning.md), Direct Distillation only differs in the training loss function. The loss function for Direct Distillation is `DirectDistillLoss` in `diffsynth.diffusion.loss`. ## Future Work diff --git a/docs/en/Training/FP8_Precision.md b/docs/en/Training/FP8_Precision.md index 5f23abb..b7913b7 100644 --- a/docs/en/Training/FP8_Precision.md +++ b/docs/en/Training/FP8_Precision.md @@ -1,12 +1,12 @@ # Enabling FP8 Precision in Training -Although `DiffSynth-Studio` supports [VRAM management](/docs/en/Pipeline_Usage/VRAM_management.md) in model inference, most of the techniques for reducing VRAM usage are not suitable for training. Offloading would cause extremely slow training processes. +Although `DiffSynth-Studio` supports [VRAM management](../Pipeline_Usage/VRAM_management.md) in model inference, most of the techniques for reducing VRAM usage are not suitable for training. Offloading would cause extremely slow training processes. -FP8 precision is the only VRAM management strategy that can be enabled during training. However, this framework currently does not support native FP8 precision training. For reasons, see [Q&A: Why doesn't the training framework support native FP8 precision training?](/docs/en/QA.md#why-doesnt-the-training-framework-support-native-fp8-precision-training). It only supports storing models whose parameters are not updated by gradients (models that do not require gradient backpropagation, or whose gradients only update their LoRA) in FP8 precision. +FP8 precision is the only VRAM management strategy that can be enabled during training. However, this framework currently does not support native FP8 precision training. For reasons, see [Q&A: Why doesn't the training framework support native FP8 precision training?](../QA.md#why-doesnt-the-training-framework-support-native-fp8-precision-training). It only supports storing models whose parameters are not updated by gradients (models that do not require gradient backpropagation, or whose gradients only update their LoRA) in FP8 precision. ## Enabling FP8 -In our provided training scripts, you can quickly set models to be stored in FP8 precision through the `--fp8_models` parameter. Taking Qwen-Image LoRA training as an example, we provide a script for enabling FP8 training located at [`/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh`](/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh). After training is completed, you can verify the training results with the script [`/examples/qwen_image/model_training/special/fp8_training/validate.py`](/examples/qwen_image/model_training/special/fp8_training/validate.py). +In our provided training scripts, you can quickly set models to be stored in FP8 precision through the `--fp8_models` parameter. Taking Qwen-Image LoRA training as an example, we provide a script for enabling FP8 training located at [`/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh). After training is completed, you can verify the training results with the script [`/examples/qwen_image/model_training/special/fp8_training/validate.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/special/fp8_training/validate.py). Please note that this FP8 VRAM management strategy does not support gradient updates. When a model is set to be trainable, FP8 precision cannot be enabled for that model. Models that support FP8 include two types: diff --git a/docs/en/Training/Split_Training.md b/docs/en/Training/Split_Training.md index 07068d2..16a6b4f 100644 --- a/docs/en/Training/Split_Training.md +++ b/docs/en/Training/Split_Training.md @@ -8,7 +8,7 @@ This document introduces split training, which can automatically divide the trai In the training process of most models, a large amount of computation occurs in "preprocessing," i.e., "computations unrelated to the denoising model," including VAE encoding, text encoding, etc. When the corresponding model parameters are fixed, the results of these computations are repetitive. For each data sample, the computational results are identical across multiple epochs. Therefore, we provide a "split training" feature that can automatically analyze and split the training process. -For standard supervised training of ordinary text-to-image models, the splitting process is straightforward. It only requires splitting the computation of all [`Pipeline Units`](/docs/en/Developer_Guide/Building_a_Pipeline.md#units) into the first stage, storing the computational results to disk, and then reading these results from disk in the second stage for subsequent computations. However, if gradient backpropagation is required during preprocessing, the situation becomes extremely complex. To address this, we introduced a computational graph splitting algorithm to analyze how to split the computation. +For standard supervised training of ordinary text-to-image models, the splitting process is straightforward. It only requires splitting the computation of all [`Pipeline Units`](../Developer_Guide/Building_a_Pipeline.md#units) into the first stage, storing the computational results to disk, and then reading these results from disk in the second stage for subsequent computations. However, if gradient backpropagation is required during preprocessing, the situation becomes extremely complex. To address this, we introduced a computational graph splitting algorithm to analyze how to split the computation. ## Computational Graph Splitting Algorithm @@ -16,7 +16,7 @@ For standard supervised training of ordinary text-to-image models, the splitting ## Using Split Training -Split training already supports [Standard Supervised Training](/docs/en/Training/Supervised_Fine_Tuning.md) and [Direct Distillation Training](/docs/en/Training/Direct_Distill.md). The `--task` parameter in the training command controls this. Taking LoRA training of the Qwen-Image model as an example, the pre-split training command is: +Split training already supports [Standard Supervised Training](../Training/Supervised_Fine_Tuning.md) and [Direct Distillation Training](../Training/Direct_Distill.md). The `--task` parameter in the training command controls this. Taking LoRA training of the Qwen-Image model as an example, the pre-split training command is: ```shell accelerate launch examples/qwen_image/model_training/train.py \ diff --git a/docs/en/Training/Supervised_Fine_Tuning.md b/docs/en/Training/Supervised_Fine_Tuning.md index fd29c10..e534c07 100644 --- a/docs/en/Training/Supervised_Fine_Tuning.md +++ b/docs/en/Training/Supervised_Fine_Tuning.md @@ -1,10 +1,10 @@ # Standard Supervised Training -After understanding the [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md), this document introduces how the framework implements Diffusion model training. This document explains the framework's principles to help developers write new training code. If you want to use our provided default training functions, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md). +After understanding the [Basic Principles of Diffusion Models](../Training/Understanding_Diffusion_models.md), this document introduces how the framework implements Diffusion model training. This document explains the framework's principles to help developers write new training code. If you want to use our provided default training functions, please refer to [Model Training](../Pipeline_Usage/Model_Training.md). Recalling the model training pseudocode from earlier, when we actually write code, the situation becomes extremely complex. Some models require additional guidance conditions and preprocessing, such as ControlNet; some models require cross-computation with the denoising model, such as VACE; some models require Gradient Checkpointing due to excessive VRAM demands, such as Qwen-Image's DiT. -To achieve strict consistency between inference and training, we abstractly encapsulate components like `Pipeline`, reusing inference code extensively during training. Please refer to [Integrating Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md) to understand the design of `Pipeline` components. Next, we'll introduce how the training framework utilizes `Pipeline` components to build training algorithms. +To achieve strict consistency between inference and training, we abstractly encapsulate components like `Pipeline`, reusing inference code extensively during training. Please refer to [Integrating Pipeline](../Developer_Guide/Building_a_Pipeline.md) to understand the design of `Pipeline` components. Next, we'll introduce how the training framework utilizes `Pipeline` components to build training algorithms. ## Framework Design Concept @@ -48,13 +48,13 @@ In `__init__`, model initialization is required. First load the model, then swit ) ``` -The logic for loading models is basically consistent with inference, supporting loading models from remote and local paths. See [Model Inference](/docs/en/Pipeline_Usage/Model_Inference.md) for details, but please note not to enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). +The logic for loading models is basically consistent with inference, supporting loading models from remote and local paths. See [Model Inference](../Pipeline_Usage/Model_Inference.md) for details, but please note not to enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). `switch_pipe_to_training_mode` can switch the model to training mode. See `switch_pipe_to_training_mode` for details. ### `forward` -In `forward`, the loss function value needs to be calculated. First perform preprocessing, then compute the loss function through the `Pipeline`'s [`model_fn`](/docs/en/Developer_Guide/Building_a_Pipeline.md#model_fn). +In `forward`, the loss function value needs to be calculated. First perform preprocessing, then compute the loss function through the `Pipeline`'s [`model_fn`](../Developer_Guide/Building_a_Pipeline.md#model_fn). ```python def forward(self, data): @@ -90,7 +90,7 @@ The loss function calculation reuses `FlowMatchSFTLoss` from `diffsynth.diffusio The training framework requires other modules, including: * accelerator: Training launcher provided by `accelerate`, see [`accelerate`](https://huggingface.co/docs/accelerate/index) for details -* dataset: Generic dataset, see [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md) for details +* dataset: Generic dataset, see [`diffsynth.core.data`](../API_Reference/core/data.md) for details * model_logger: Model logger, see `diffsynth.diffusion.logger` for details ```python diff --git a/docs/en/Training/Understanding_Diffusion_models.md b/docs/en/Training/Understanding_Diffusion_models.md index 5c81b6a..7704dd2 100644 --- a/docs/en/Training/Understanding_Diffusion_models.md +++ b/docs/en/Training/Understanding_Diffusion_models.md @@ -6,7 +6,7 @@ This document introduces the basic principles of Diffusion models to help you un Diffusion models generate clear images or video content through iterative denoising. We start by explaining the generation process of a data sample $x_0$. Intuitively, in a complete round of denoising, we start from random Gaussian noise $x_T$ and iteratively obtain $x_{T-1}$, $x_{T-2}$, $x_{T-3}$, $\cdots$, gradually reducing the noise content at each step until we finally obtain the noise-free data sample $x_0$. -(Figure) +![Image](https://github.com/user-attachments/assets/6471ae4c-a635-4924-8b36-b0bd4d42043d) This process is intuitive, but to understand the details, we need to answer several questions: @@ -28,7 +28,7 @@ As for the intermediate values $\sigma_{T-1}$, $\sigma_{T-2}$, $\cdots$, $\sigma At an intermediate step, we can directly synthesize noisy data samples $x_t=(1-\sigma_t)x_0+\sigma_t x_T$. -(Figure) +![Image](https://github.com/user-attachments/assets/e25a2f71-123c-4e18-8b34-3a066af15667) ## How is the iterative denoising computation performed? @@ -40,8 +40,6 @@ Before understanding the iterative denoising computation, we need to clarify wha Among these, the guidance condition $c$ is a newly introduced parameter that is input by the user. It can be text describing the image content or a sketch outlining the image structure. -(Figure) - The model's output $\hat \epsilon(x_t,c,t)$ approximately equals $x_T-x_0$, which is the direction of the entire diffusion process (the reverse process of denoising). Next, we analyze the computation occurring in one iteration. At time step $t$, after the model computes an approximation of $x_T-x_0$, we calculate the next $x_{t-1}$: @@ -91,8 +89,6 @@ After understanding the iterative denoising process, we next consider how to tra The training process differs from the generation process. If we retain multi-step iterations during training, the gradient would need to backpropagate through multiple steps, bringing catastrophic time and space complexity. To improve computational efficiency, we randomly select a time step $t$ for training. -(Figure) - The following is pseudocode for the training process: > Obtain data sample $x_0$ and guidance condition $c$ from the dataset @@ -113,7 +109,7 @@ The following is pseudocode for the training process: From theory to practice, more details need to be filled in. Modern Diffusion model architectures have matured, with mainstream architectures following the "three-stage" architecture proposed by Latent Diffusion, including data encoder-decoder, guidance condition encoder, and denoising model. -(Figure) +![Image](https://github.com/user-attachments/assets/43855430-6427-4aca-83a0-f684e01438b1) ### Data Encoder-Decoder @@ -142,4 +138,4 @@ The denoising model is the true essence of Diffusion models, with diverse model ## How does this project encapsulate and implement model training? -Please read the next document: [Standard Supervised Training](/docs/en/Training/Supervised_Fine_Tuning.md) \ No newline at end of file +Please read the next document: [Standard Supervised Training](../Training/Supervised_Fine_Tuning.md) \ No newline at end of file diff --git a/docs/en/conf.py b/docs/en/conf.py new file mode 100644 index 0000000..2005124 --- /dev/null +++ b/docs/en/conf.py @@ -0,0 +1,124 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys + +# import sphinx_book_theme + +sys.path.insert(0, os.path.abspath('../../')) +# -- Project information ----------------------------------------------------- + +project = 'diffsynth' +copyright = '2022-2025, Alibaba ModelScope' +author = 'ModelScope Authors' +version_file = '../../diffsynth/version.py' +html_theme = 'sphinx_rtd_theme' +language = 'en' + + +def get_version(): + with open(version_file, 'r', encoding='utf-8') as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +# The full version, including alpha/beta/rc tags +version = get_version() +release = version + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.napoleon', + 'sphinx.ext.autosummary', + 'sphinx.ext.autodoc', + 'sphinx.ext.viewcode', + 'sphinx_markdown_tables', + 'sphinx_copybutton', + "sphinx_rtd_theme", + 'sphinx.ext.mathjax', + 'myst_parser', +] +# build the templated autosummary files +autosummary_generate = True +numpydoc_show_class_members = False + +# Enable overriding of function signatures in the first line of the docstring. +autodoc_docstring_signature = True + +# Disable docstring inheritance +autodoc_inherit_docstrings = False + +# Show type hints in the description +autodoc_typehints = 'description' + +# Add parameter types if the parameter is documented in the docstring +autodoc_typehints_description_target = 'documented_params' + +autodoc_default_options = { + 'member-order': 'bysource', +} + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +source_suffix = ['.rst', '.md'] + +# The master toctree document. +root_doc = 'index' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ['build'] +# A list of glob-style patterns [1] that are used to find source files. +# They are matched against the source file names relative to the source directory, +# using slashes as directory separators on all platforms. +# The default is **, meaning that all files are recursively included from the source directory. +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +# html_theme = 'sphinx_book_theme' +# html_theme_path = [sphinx_book_theme.get_html_theme_path()] +# html_theme_options = {} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] +# html_css_files = ['css/readthedocs.css'] + +# -- Options for HTMLHelp output --------------------------------------------- +# Output file base name for HTML help builder. + +# -- Extension configuration ------------------------------------------------- +# Ignore >>> when copying code +copybutton_prompt_text = r'>>> |\.\.\. ' +copybutton_prompt_is_regexp = True + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = {'https://docs.python.org/': None} + +myst_enable_extensions = [ + 'amsmath', + 'dollarmath', + 'colon_fence', +] diff --git a/docs/en/index.rst b/docs/en/index.rst new file mode 100644 index 0000000..ab195ef --- /dev/null +++ b/docs/en/index.rst @@ -0,0 +1,77 @@ +Welcome to DiffSynth-Studio's Documentation +========================================== + +.. toctree:: + :maxdepth: 2 + :caption: Documentation Introduction + + README + +.. toctree:: + :maxdepth: 2 + :caption: Getting Started + + Pipeline_Usage/Setup + Pipeline_Usage/Model_Inference + Pipeline_Usage/VRAM_management + Pipeline_Usage/Model_Training + Pipeline_Usage/Environment_Variables + Pipeline_Usage/GPU_support + +.. toctree:: + :maxdepth: 2 + :caption: Model Details + + Model_Details/FLUX + Model_Details/Wan + Model_Details/Qwen-Image + Model_Details/FLUX2 + Model_Details/Z-Image + +.. toctree:: + :maxdepth: 2 + :caption: Training Framework + + Training/Understanding_Diffusion_models + Training/Supervised_Fine_Tuning + Training/FP8_Precision + Training/Direct_Distill + Training/Split_Training + Training/Differential_LoRA + +.. toctree:: + :maxdepth: 2 + :caption: Model Integration + + Developer_Guide/Integrating_Your_Model + Developer_Guide/Building_a_Pipeline + Developer_Guide/Enabling_VRAM_management + Developer_Guide/Training_Diffusion_Models + +.. toctree:: + :maxdepth: 2 + :caption: API Reference + + API_Reference/core/attention + API_Reference/core/data + API_Reference/core/gradient + API_Reference/core/loader + API_Reference/core/vram + +.. toctree:: + :maxdepth: 2 + :caption: Research Guide + + Research_Tutorial/train_from_scratch + +.. toctree:: + :maxdepth: 2 + :caption: FAQ + + QA + +Indices and tables +================== +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..e002209 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,11 @@ +docutils>=0.16.0 +myst_parser +recommonmark +sphinx>=5.3.0 +sphinx-book-theme +sphinx-copybutton +sphinx-autobuild +sphinx-rtd-theme +sphinx_markdown_tables +sphinxcontrib-mermaid +pymdown-extensions \ No newline at end of file diff --git a/docs/zh/.readthedocs.yaml b/docs/zh/.readthedocs.yaml new file mode 100644 index 0000000..0b1ab71 --- /dev/null +++ b/docs/zh/.readthedocs.yaml @@ -0,0 +1,28 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the OS, Python version and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.10" + +# Build documentation in the "docs/" directory with Sphinx +sphinx: + configuration: docs/zh/conf.py + +# Optionally build your docs in additional formats such as PDF and ePub +# formats: +# - pdf +# - epub + +# Optional but recommended, declare the Python requirements required +# to build your documentation +# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +python: + install: + - requirements: docs/requirements.txt diff --git a/docs/zh/API_Reference/core/attention.md b/docs/zh/API_Reference/core/attention.md index c30e180..1b6ce83 100644 --- a/docs/zh/API_Reference/core/attention.md +++ b/docs/zh/API_Reference/core/attention.md @@ -1,6 +1,6 @@ # `diffsynth.core.attention`: 注意力机制实现 -`diffsynth.core.attention` 提供了注意力机制实现的路由机制,根据 `Python` 环境中的可用包和[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation)自动选择高效的注意力机制实现。 +`diffsynth.core.attention` 提供了注意力机制实现的路由机制,根据 `Python` 环境中的可用包和[环境变量](../../Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation)自动选择高效的注意力机制实现。 ## 注意力机制 @@ -46,7 +46,7 @@ output_1 = attention(query, key, value) * xFormers:[GitHub](https://github.com/facebookresearch/xformers)、[文档](https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops) * PyTorch:[GitHub](https://github.com/pytorch/pytorch)、[文档](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) -如需调用除 `PyTorch` 外的其他注意力实现,请按照其 GitHub 页面的指引安装对应的包。`DiffSynth-Studio` 会自动根据 Python 环境中的可用包路由到对应的实现上,也可通过[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation)控制。 +如需调用除 `PyTorch` 外的其他注意力实现,请按照其 GitHub 页面的指引安装对应的包。`DiffSynth-Studio` 会自动根据 Python 环境中的可用包路由到对应的实现上,也可通过[环境变量](../../Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation)控制。 ```python from diffsynth.core.attention import attention_forward diff --git a/docs/zh/API_Reference/core/loader.md b/docs/zh/API_Reference/core/loader.md index ad2d245..e30ef9c 100644 --- a/docs/zh/API_Reference/core/loader.md +++ b/docs/zh/API_Reference/core/loader.md @@ -8,9 +8,9 @@ ### 从远程下载并加载模型 -以模型[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) 为例,在 `ModelConfig` 中填写 `model_id` 和 `origin_file_pattern` 后即可自动下载模型。默认下载到 `./models` 路径,该路径可通过[环境变量 DIFFSYNTH_MODEL_BASE_PATH](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path) 修改。 +以模型[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) 为例,在 `ModelConfig` 中填写 `model_id` 和 `origin_file_pattern` 后即可自动下载模型。默认下载到 `./models` 路径,该路径可通过[环境变量 DIFFSYNTH_MODEL_BASE_PATH](../../Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path) 修改。 -默认情况下,即使模型已经下载完毕,程序仍会向远程查询是否有遗漏文件,如果要完全关闭远程请求,请将[环境变量 DIFFSYNTH_SKIP_DOWNLOAD](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) 设置为 `True`。 +默认情况下,即使模型已经下载完毕,程序仍会向远程查询是否有遗漏文件,如果要完全关闭远程请求,请将[环境变量 DIFFSYNTH_SKIP_DOWNLOAD](../../Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) 设置为 `True`。 ```python from diffsynth.core import ModelConfig @@ -51,7 +51,7 @@ config = ModelConfig(path=[ ### 显存管理配置 -`ModelConfig` 也包含了显存管理配置信息,详见[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md#更多使用方式)。 +`ModelConfig` 也包含了显存管理配置信息,详见[显存管理](../../Pipeline_Usage/VRAM_management.md#更多使用方式)。 ## 模型文件加载 @@ -103,11 +103,11 @@ print(hash_model_file([ 模型哈希值只与模型文件中 state dict 的 keys 和 tensor shape 有关,与模型参数的数值、文件保存时间等信息无关。在计算 `.safetensors` 格式文件的模型哈希值时,`hash_model_file` 是几乎瞬间完成的,无需读取模型的参数;但在计算 `.bin`、`.pth`、`.ckpt` 等二进制文件的模型哈希值时,则需要读取全部模型参数,因此**我们不建议开发者继续使用这些格式的文件。** -通过[编写模型 Config](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-3-编写模型-config)并将模型哈希值等信息填入 `diffsynth/configs/model_configs.py`,开发者可以让 `DiffSynth-Studio` 自动识别模型类型并加载。 +通过[编写模型 Config](../../Developer_Guide/Integrating_Your_Model.md#step-3-编写模型-config)并将模型哈希值等信息填入 `diffsynth/configs/model_configs.py`,开发者可以让 `DiffSynth-Studio` 自动识别模型类型并加载。 ## 模型加载 -`load_model` 是 `diffsynth.core.loader` 中加载模型的外部入口,它会调用 [skip_model_initialization](/docs/zh/API_Reference/core/vram.md#跳过模型参数初始化) 跳过模型参数初始化。如果启用了 [Disk Offload](/docs/zh/Pipeline_Usage/VRAM_management.md#disk-offload),则调用 [DiskMap](/docs/zh/API_Reference/core/vram.md#state-dict-硬盘映射) 进行惰性加载;如果没有启用 Disk Offload,则调用 [load_state_dict](#模型文件加载) 加载模型参数。如果需要的话,还会调用 [state dict converter](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换) 进行模型格式转换。最后调用 `model.eval()` 将其切换到推理模式。 +`load_model` 是 `diffsynth.core.loader` 中加载模型的外部入口,它会调用 [skip_model_initialization](../../API_Reference/core/vram.md#跳过模型参数初始化) 跳过模型参数初始化。如果启用了 [Disk Offload](../../Pipeline_Usage/VRAM_management.md#disk-offload),则调用 [DiskMap](../../API_Reference/core/vram.md#state-dict-硬盘映射) 进行惰性加载;如果没有启用 Disk Offload,则调用 [load_state_dict](#模型文件加载) 加载模型参数。如果需要的话,还会调用 [state dict converter](../../Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换) 进行模型格式转换。最后调用 `model.eval()` 将其切换到推理模式。 以下是一个启用了 Disk Offload 的使用案例: diff --git a/docs/zh/API_Reference/core/vram.md b/docs/zh/API_Reference/core/vram.md index f79b9da..d97a516 100644 --- a/docs/zh/API_Reference/core/vram.md +++ b/docs/zh/API_Reference/core/vram.md @@ -31,7 +31,7 @@ state_dict = load_state_dict(path, device="cpu") model.load_state_dict(state_dict, assign=True) ``` -在 `DiffSynth-Studio` 中,所有预训练模型都遵循这一加载逻辑。开发者在[接入模型](/docs/zh/Developer_Guide/Integrating_Your_Model.md)完毕后即可直接以这种方式快速加载模型。 +在 `DiffSynth-Studio` 中,所有预训练模型都遵循这一加载逻辑。开发者在[接入模型](../../Developer_Guide/Integrating_Your_Model.md)完毕后即可直接以这种方式快速加载模型。 ## State Dict 硬盘映射 @@ -57,10 +57,10 @@ state_dict = DiskMap(path, device="cpu") # Fast print(state_dict["img_in.weight"]) ``` -`DiskMap` 是 `DiffSynth-Studio` 中 Disk Offload 的基本组件,开发者在[配置细粒度显存管理方案](/docs/zh/Developer_Guide/Enabling_VRAM_management.md)后即可直接启用 Disk Offload。 +`DiskMap` 是 `DiffSynth-Studio` 中 Disk Offload 的基本组件,开发者在[配置细粒度显存管理方案](../../Developer_Guide/Enabling_VRAM_management.md)后即可直接启用 Disk Offload。 `DiskMap` 是利用 `.safetensors` 文件的特性实现的功能,因此在使用 `.bin`、`.pth`、`.ckpt` 等二进制文件时,模型的参数是全量加载的,这也导致 Disk Offload 不支持这些格式的文件。**我们不建议开发者继续使用这些格式的文件。** ## 显存管理可替换模块 -在启用 `DiffSynth-Studio` 的显存管理后,模型内部的模块会被替换为 `diffsynth.core.vram.layers` 中的可替换模块,其使用方式详见[细粒度显存管理方案](/docs/zh/Developer_Guide/Enabling_VRAM_management.md#编写细粒度显存管理方案)。 +在启用 `DiffSynth-Studio` 的显存管理后,模型内部的模块会被替换为 `diffsynth.core.vram.layers` 中的可替换模块,其使用方式详见[细粒度显存管理方案](../../Developer_Guide/Enabling_VRAM_management.md#编写细粒度显存管理方案)。 diff --git a/docs/zh/Developer_Guide/Building_a_Pipeline.md b/docs/zh/Developer_Guide/Building_a_Pipeline.md index cac5b62..4b16b74 100644 --- a/docs/zh/Developer_Guide/Building_a_Pipeline.md +++ b/docs/zh/Developer_Guide/Building_a_Pipeline.md @@ -1,6 +1,6 @@ # 接入 Pipeline -在[将 Pipeline 所需的模型接入](/docs/zh/Developer_Guide/Integrating_Your_Model.md)之后,还需构建 `Pipeline` 用于模型推理,本文档提供 `Pipeline` 构建的标准化流程,开发者也可参考现有的 `Pipeline` 进行构建。 +在[将 Pipeline 所需的模型接入](../Developer_Guide/Integrating_Your_Model.md)之后,还需构建 `Pipeline` 用于模型推理,本文档提供 `Pipeline` 构建的标准化流程,开发者也可参考现有的 `Pipeline` 进行构建。 `Pipeline` 的实现位于 `diffsynth/pipelines`,每个 `Pipeline` 包含以下必要的关键组件: @@ -79,7 +79,7 @@ class NewDiffSynthPipeline(BasePipeline): return pipe ``` -开发者需要实现其中获取模型的逻辑,对应的模型名称即为[模型接入时填写的模型 Config](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-3-编写模型-config) 中的 `"model_name"`。 +开发者需要实现其中获取模型的逻辑,对应的模型名称即为[模型接入时填写的模型 Config](../Developer_Guide/Integrating_Your_Model.md#step-3-编写模型-config) 中的 `"model_name"`。 部分模型还需要加载 `tokenizer`,可根据需要在 `from_pretrained` 上添加额外的 `tokenizer_config` 参数并在获取模型后实现这部分。 diff --git a/docs/zh/Developer_Guide/Enabling_VRAM_management.md b/docs/zh/Developer_Guide/Enabling_VRAM_management.md index a067f8d..22ef752 100644 --- a/docs/zh/Developer_Guide/Enabling_VRAM_management.md +++ b/docs/zh/Developer_Guide/Enabling_VRAM_management.md @@ -1,6 +1,6 @@ # 细粒度显存管理方案 -本文档介绍如何为模型编写合理的细粒度显存管理方案,以及如何将 `DiffSynth-Studio` 中的显存管理功能用于外部的其他代码库,在阅读本文档前,请先阅读文档[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)。 +本文档介绍如何为模型编写合理的细粒度显存管理方案,以及如何将 `DiffSynth-Studio` 中的显存管理功能用于外部的其他代码库,在阅读本文档前,请先阅读文档[显存管理](../Pipeline_Usage/VRAM_management.md)。 ## 20B 模型需要多少显存? @@ -124,7 +124,7 @@ module_map={ } ``` -此外,还需要提供 `vram_config` 与 `vram_limit`,这两个参数在[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md#更多使用方式)中已有介绍。 +此外,还需要提供 `vram_config` 与 `vram_limit`,这两个参数在[显存管理](../Pipeline_Usage/VRAM_management.md#更多使用方式)中已有介绍。 调用 `enable_vram_management` 即可启用显存管理,注意此时模型加载时的 `device` 为 `cpu`,与 `offload_device` 一致: @@ -171,7 +171,7 @@ with torch.no_grad(): ## Disk Offload -[Disk Offload](/docs/zh/Pipeline_Usage/VRAM_management.md#disk-offload) 是特殊的显存管理方案,需在模型加载过程中启用,而非模型加载完毕后。通常,在以上代码能够顺利运行的前提下,Disk Offload 可以直接启用: +[Disk Offload](../Pipeline_Usage/VRAM_management.md#disk-offload) 是特殊的显存管理方案,需在模型加载过程中启用,而非模型加载完毕后。通常,在以上代码能够顺利运行的前提下,Disk Offload 可以直接启用: ```python from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule @@ -212,7 +212,7 @@ with torch.no_grad(): output = model(**inputs) ``` -Disk Offload 是极为特殊的显存管理方案,只支持 `.safetensors` 格式文件,不支持 `.bin`、`.pth`、`.ckpt` 等二进制文件,不支持带 Tensor reshape 的 [state dict converter](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换)。 +Disk Offload 是极为特殊的显存管理方案,只支持 `.safetensors` 格式文件,不支持 `.bin`、`.pth`、`.ckpt` 等二进制文件,不支持带 Tensor reshape 的 [state dict converter](../Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换)。 如果出现非 Disk Offload 能正常运行但 Disk Offload 不能正常运行的情况,请在 GitHub 上给我们提 issue。 diff --git a/docs/zh/Developer_Guide/Integrating_Your_Model.md b/docs/zh/Developer_Guide/Integrating_Your_Model.md index cd58cfc..81c0975 100644 --- a/docs/zh/Developer_Guide/Integrating_Your_Model.md +++ b/docs/zh/Developer_Guide/Integrating_Your_Model.md @@ -183,4 +183,4 @@ Loaded model: { ## Step 5: 编写模型显存管理方案 -`DiffSynth-Studio` 支持复杂的显存管理,详见[启用显存管理](/docs/zh/Developer_Guide/Enabling_VRAM_management.md)。 +`DiffSynth-Studio` 支持复杂的显存管理,详见[启用显存管理](../Developer_Guide/Enabling_VRAM_management.md)。 diff --git a/docs/zh/Developer_Guide/Training_Diffusion_Models.md b/docs/zh/Developer_Guide/Training_Diffusion_Models.md index 4313fa1..38399a9 100644 --- a/docs/zh/Developer_Guide/Training_Diffusion_Models.md +++ b/docs/zh/Developer_Guide/Training_Diffusion_Models.md @@ -1,6 +1,6 @@ # 接入模型训练 -在[接入模型](/docs/zh/Developer_Guide/Integrating_Your_Model.md)并[实现 Pipeline](/docs/zh/Developer_Guide/Building_a_Pipeline.md)后,接下来接入模型训练功能。 +在[接入模型](../Developer_Guide/Integrating_Your_Model.md)并[实现 Pipeline](../Developer_Guide/Building_a_Pipeline.md)后,接下来接入模型训练功能。 ## 训推一致的 Pipeline 改造 diff --git a/docs/zh/Makefile b/docs/zh/Makefile new file mode 100644 index 0000000..41c270b --- /dev/null +++ b/docs/zh/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file diff --git a/docs/zh/Model_Details/FLUX.md b/docs/zh/Model_Details/FLUX.md index 71576dc..77828ef 100644 --- a/docs/zh/Model_Details/FLUX.md +++ b/docs/zh/Model_Details/FLUX.md @@ -14,7 +14,7 @@ cd DiffSynth-Studio pip install -e . ``` -更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。 +更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。 ## 快速开始 @@ -81,31 +81,31 @@ graph LR; |模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-|-| -|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)| -|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)| -|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)| -|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)| -|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)| -|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)| -|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)| -|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)| -|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)| -|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-| -|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| -|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)| -|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)| -|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)| +|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev.py)| +|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)| +|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)| +|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)| +|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)| +|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)| +|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)| +|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)| +|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)| +|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-| +|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| +|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Step1X-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Step1X-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Step1X-Edit.py)| +|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLEX.2-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLEX.2-preview.py)| +|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Nexus-Gen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Nexus-Gen.py)| 特殊训练脚本: -* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)、[code](/examples/flux/model_training/special/differential_training/) -* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)、[code](/examples/flux/model_training/special/fp8_training/) -* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/flux/model_training/special/split_training/) -* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh) +* 差分 LoRA 训练:[doc](../Training/Differential_LoRA.md) +* FP8 精度训练:[doc](../Training/FP8_Precision.md) +* 两阶段拆分训练:[doc](../Training/Split_Training.md) +* 端到端直接蒸馏:[doc](../Training/Direct_Distill.md) ## 模型推理 -模型通过 `FluxImagePipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。 +模型通过 `FluxImagePipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。 `FluxImagePipeline` 推理的输入参数包括: @@ -143,11 +143,11 @@ graph LR; * `flex_control_stop`: Flex 模型的控制停止时间步。 * `nexus_gen_reference_image`: Nexus-Gen 模型的参考图像。 -如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。 +如果显存不足,请开启[显存管理](../Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。 ## 模型训练 -FLUX 系列模型统一通过 [`examples/flux/model_training/train.py`](/examples/flux/model_training/train.py) 进行训练,脚本的参数包括: +FLUX 系列模型统一通过 [`examples/flux/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/train.py) 进行训练,脚本的参数包括: * 通用训练参数 * 数据集基础配置 @@ -198,4 +198,4 @@ FLUX 系列模型统一通过 [`examples/flux/model_training/train.py`](/example modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset ``` -我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。 \ No newline at end of file +我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。 diff --git a/docs/zh/Model_Details/FLUX2.md b/docs/zh/Model_Details/FLUX2.md index 896ad9f..66725e6 100644 --- a/docs/zh/Model_Details/FLUX2.md +++ b/docs/zh/Model_Details/FLUX2.md @@ -21,7 +21,7 @@ cd DiffSynth-Studio pip install -e . ``` -更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。 +更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。 ## 快速开始 @@ -61,22 +61,22 @@ image.save("image.jpg") |模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| -|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| -|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)| -|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)| -|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)| -|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)| +|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| +|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)| +|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)| +|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)| +|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)| 特殊训练脚本: -* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md) -* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md) -* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md) -* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md) +* 差分 LoRA 训练:[doc](../Training/Differential_LoRA.md) +* FP8 精度训练:[doc](../Training/FP8_Precision.md) +* 两阶段拆分训练:[doc](../Training/Split_Training.md) +* 端到端直接蒸馏:[doc](../Training/Direct_Distill.md) ## 模型推理 -模型通过 `Flux2ImagePipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。 +模型通过 `Flux2ImagePipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。 `Flux2ImagePipeline` 推理的输入参数包括: @@ -95,11 +95,11 @@ image.save("image.jpg") * `tile_stride`: VAE 编解码阶段的分块步长,默认为 64,仅在 `tiled=True` 时生效,需保证其数值小于或等于 `tile_size`。 * `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。 -如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。 +如果显存不足,请开启[显存管理](../Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。 ## 模型训练 -FLUX.2 系列模型统一通过 [`examples/flux2/model_training/train.py`](/examples/flux2/model_training/train.py) 进行训练,脚本的参数包括: +FLUX.2 系列模型统一通过 [`examples/flux2/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/train.py) 进行训练,脚本的参数包括: * 通用训练参数 * 数据集基础配置 @@ -148,4 +148,4 @@ FLUX.2 系列模型统一通过 [`examples/flux2/model_training/train.py`](/exam modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset ``` -我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。 +我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。 diff --git a/docs/zh/Model_Details/LTX-2.md b/docs/zh/Model_Details/LTX-2.md new file mode 100644 index 0000000..6961931 --- /dev/null +++ b/docs/zh/Model_Details/LTX-2.md @@ -0,0 +1,116 @@ +# LTX-2 + +LTX-2 是由 Lightricks 开发的音视频生成模型系列。 + +## 安装 + +在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。 + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。 + +## 快速开始 + +运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。 + +```python +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" +negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +height, width, num_frames = 512, 768, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_onestage.mp4', + fps=24, + audio_sample_rate=24000, +) +``` + +## 模型总览 +|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-|-| +|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-| +|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-| +|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-| +|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-| +|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-| + +## 模型推理 + +模型通过 `LTX2AudioVideoPipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。 + +`LTX2AudioVideoPipeline` 推理的输入参数包括: + +* `prompt`: 提示词,描述视频中出现的内容。 +* `negative_prompt`: 负向提示词,描述视频中不应该出现的内容,默认值为 `""`。 +* `cfg_scale`: Classifier-free guidance 的参数,默认值为 3.0。 +* `input_images`: 输入图像列表,用于图生视频。 +* `input_images_indexes`: 输入图像在视频中的帧索引列表。 +* `input_images_strength`: 输入图像的强度,默认值为 1.0。 +* `denoising_strength`: 去噪强度,范围是 0~1,默认值为 1.0。 +* `seed`: 随机种子。默认为 `None`,即完全随机。 +* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。 +* `height`: 视频高度,需保证高度为 32 的倍数(单阶段)或 64 的倍数(两阶段)。 +* `width`: 视频宽度,需保证宽度为 32 的倍数(单阶段)或 64 的倍数(两阶段)。 +* `num_frames`: 视频帧数,默认值为 121,需保证为 8 的倍数 + 1。 +* `num_inference_steps`: 推理次数,默认值为 40。 +* `tiled`: 是否启用 VAE 分块推理,默认为 `True`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。 +* `tile_size_in_pixels`: VAE 编解码阶段的像素分块大小,默认为 512。 +* `tile_overlap_in_pixels`: VAE 编解码阶段的像素分块重叠大小,默认为 128。 +* `tile_size_in_frames`: VAE 编解码阶段的帧分块大小,默认为 128。 +* `tile_overlap_in_frames`: VAE 编解码阶段的帧分块重叠大小,默认为 24。 +* `use_two_stage_pipeline`: 是否使用两阶段管道,默认为 `False`。 +* `use_distilled_pipeline`: 是否使用蒸馏管道,默认为 `False`。 +* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。 + +如果显存不足,请开启[显存管理](../Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"支持的推理脚本"中的表格。 + +## 模型训练 + +LTX-2 系列模型目前暂不支持训练功能。我们将尽快添加相关支持。 diff --git a/docs/zh/Model_Details/Overview.md b/docs/zh/Model_Details/Overview.md index 9c0e679..cdfdce9 100644 --- a/docs/zh/Model_Details/Overview.md +++ b/docs/zh/Model_Details/Overview.md @@ -2,7 +2,7 @@ ## Qwen-Image -文档:[./Qwen-Image.md](/docs/zh/Model_Details/Qwen-Image.md) +文档:[./Qwen-Image.md](../Model_Details/Qwen-Image.md)
@@ -69,23 +69,23 @@ graph LR; |模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| -|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)| -|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)| -|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| -|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| -|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| -|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)| -|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)| -|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)| -|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)| -|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)| -|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)| -|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)| -|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-| +|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)| +|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)| +|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)| +|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)| +|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)| +|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)| +|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-| ## FLUX 系列 -文档:[./FLUX.md](/docs/zh/Model_Details/FLUX.md) +文档:[./FLUX.md](../Model_Details/FLUX.md)
@@ -149,24 +149,24 @@ graph LR; |模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-|-| -|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)| -|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)| -|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)| -|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)| -|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)| -|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)| -|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)| -|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)| -|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)| -|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-| -|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| -|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)| -|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)| -|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)| +|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev.py)| +|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)| +|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)| +|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)| +|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)| +|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)| +|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)| +|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)| +|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)| +|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-| +|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| +|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Step1X-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Step1X-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Step1X-Edit.py)| +|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLEX.2-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLEX.2-preview.py)| +|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Nexus-Gen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Nexus-Gen.py)| ## Wan 系列 -文档:[./Wan.md](/docs/zh/Model_Details/Wan.md) +文档:[./Wan.md](../Model_Details/Wan.md)
@@ -255,34 +255,34 @@ graph LR; |模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| -|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)| -|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)| -|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)| -|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)| -|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)| -|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)| -|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)| -|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)| -|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)| -|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)| -|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)| -|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)| -|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)| -|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/examples/wanmodel_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)| -|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)| -|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)| -|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)| -|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)| -|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| -|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)| -|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)| -|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)| -|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| -|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| -|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| -|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)| -|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)| -|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)| -|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)| -|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)| -|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)| +|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)| +|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)| +|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)| +|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)| +|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)| +|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)| +|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)| +|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)| +|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)| +|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)| +|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)| +|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)| +|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)| +|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)| +|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| +|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)| +|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)| +|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)| +|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)| +|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)| diff --git a/docs/zh/Model_Details/Qwen-Image.md b/docs/zh/Model_Details/Qwen-Image.md index 697438f..3415ff2 100644 --- a/docs/zh/Model_Details/Qwen-Image.md +++ b/docs/zh/Model_Details/Qwen-Image.md @@ -14,7 +14,7 @@ cd DiffSynth-Studio pip install -e . ``` -更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。 +更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。 ## 快速开始 @@ -80,35 +80,41 @@ graph LR; |模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| -|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)| -|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)| -|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)| -|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| -|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)| -|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)| -|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)| -|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| -|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| -|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)| -|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)| -|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)| -|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)| -|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)| -|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)| -|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)| -|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-| -|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-| +|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)| +|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)| +|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)| +|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| +|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)| +|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-| +|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)| +|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)| +|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)| +|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)| +|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)| +|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-| +|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-| 特殊训练脚本: -* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)、[code](/examples/qwen_image/model_training/special/differential_training/) -* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)、[code](/examples/qwen_image/model_training/special/fp8_training/) -* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/qwen_image/model_training/special/split_training/) -* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) +* 差分 LoRA 训练:[doc](../Training/Differential_LoRA.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/differential_training/) +* FP8 精度训练:[doc](../Training/FP8_Precision.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/fp8_training/) +* 两阶段拆分训练:[doc](../Training/Split_Training.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/split_training/) +* 端到端直接蒸馏:[doc](../Training/Direct_Distill.md)、[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) + +DeepSpeed ZeRO 3 训练:Qwen-Image 系列模型支持 DeepSpeed ZeRO 3 训练,将模型拆分到多个 GPU 上,以 Qwen-Image 模型的全量训练为例,需修改: + +* `--config_file examples/qwen_image/model_training/full/accelerate_config_zero3.yaml` +* `--initialize_model_on_cpu` ## 模型推理 -模型通过 `QwenImagePipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。 +模型通过 `QwenImagePipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。 `QwenImagePipeline` 推理的输入参数包括: @@ -139,11 +145,11 @@ graph LR; * `tile_stride`: VAE 编解码阶段的分块步长,默认为 64,仅在 `tiled=True` 时生效,需保证其数值小于或等于 `tile_size`。 * `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。 -如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文“模型总览”中的表格。 +如果显存不足,请开启[显存管理](../Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文“模型总览”中的表格。 ## 模型训练 -Qwen-Image 系列模型统一通过 [`examples/qwen_image/model_training/train.py`](/examples/qwen_image/model_training/train.py) 进行训练,脚本的参数包括: +Qwen-Image 系列模型统一通过 [`examples/qwen_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/train.py) 进行训练,脚本的参数包括: * 通用训练参数 * 数据集基础配置 @@ -193,4 +199,4 @@ Qwen-Image 系列模型统一通过 [`examples/qwen_image/model_training/train.p modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset ``` -我们为每个模型编写了推荐的训练脚本,请参考前文“模型总览”中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。 +我们为每个模型编写了推荐的训练脚本,请参考前文“模型总览”中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。 diff --git a/docs/zh/Model_Details/Wan.md b/docs/zh/Model_Details/Wan.md index b8c3032..0144bd2 100644 --- a/docs/zh/Model_Details/Wan.md +++ b/docs/zh/Model_Details/Wan.md @@ -14,7 +14,7 @@ cd DiffSynth-Studio pip install -e . ``` -更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。 +更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。 ## 快速开始 @@ -107,45 +107,50 @@ graph LR; |模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| -|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)| -|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)| -|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)| -|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)| -|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)| -|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)| -|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)| -|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)| -|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)| -|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)| -|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)| -|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)| -|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)| -|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)| -|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)| -|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)| -|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)| -|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)| -|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| -|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)| -|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)| -|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)| -|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| -|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| -|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| -|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)| -|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)| -|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)| -|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)| -|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)| -|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)| +|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)| +|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)| +|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)| +|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)| +|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)| +|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)| +|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)| +|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)| +|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)| +|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)| +|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)| +|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)| +|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)| +|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)| +|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| +|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)| +|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)| +|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)| +|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)| +|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)| -* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)、[code](/examples/wanvideo/model_training/special/fp8_training/) -* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/wanvideo/model_training/special/split_training/) -* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/wanvideo/model_training/special/direct_distill/) +* FP8 精度训练:[doc](../Training/FP8_Precision.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/) +* 两阶段拆分训练:[doc](../Training/Split_Training.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/) +* 端到端直接蒸馏:[doc](../Training/Direct_Distill.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/direct_distill/) + +DeepSpeed ZeRO 3 训练:Wan 系列模型支持 DeepSpeed ZeRO 3 训练,将模型拆分到多个 GPU 上,以 Wan2.1-T2V-14B 模型的全量训练为例,需修改: + +* `--config_file examples/wanvideo/model_training/full/accelerate_config_zero3.yaml` +* `--initialize_model_on_cpu` ## 模型推理 -模型通过 `WanVideoPipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。 +模型通过 `WanVideoPipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。 `WanVideoPipeline` 推理的输入参数包括: @@ -195,11 +200,11 @@ graph LR; * `tea_cache_model_id`: TeaCache 使用的模型 ID。 * `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。 -如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。 +如果显存不足,请开启[显存管理](../Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。 ## 模型训练 -Wan 系列模型统一通过 [`examples/wanvideo/model_training/train.py`](/examples/wanvideo/model_training/train.py) 进行训练,脚本的参数包括: +Wan 系列模型统一通过 [`examples/wanvideo/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/train.py) 进行训练,脚本的参数包括: * 通用训练参数 * 数据集基础配置 @@ -250,4 +255,4 @@ Wan 系列模型统一通过 [`examples/wanvideo/model_training/train.py`](/exam modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset ``` -我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。 \ No newline at end of file +我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。 diff --git a/docs/zh/Model_Details/Z-Image.md b/docs/zh/Model_Details/Z-Image.md index c51083a..4e77360 100644 --- a/docs/zh/Model_Details/Z-Image.md +++ b/docs/zh/Model_Details/Z-Image.md @@ -12,7 +12,7 @@ cd DiffSynth-Studio pip install -e . ``` -更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。 +更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。 ## 快速开始 @@ -52,16 +52,21 @@ image.save("image.jpg") |模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| -|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)| +|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image.py)| +|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-i2L.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-| +|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)| +|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)| +|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)| +|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)| 特殊训练脚本: -* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)、[code](/examples/z_image/model_training/special/differential_training/) -* 轨迹模仿蒸馏训练(实验性功能):[code](/examples/z_image/model_training/special/trajectory_imitation/) +* 差分 LoRA 训练:[doc](../Training/Differential_LoRA.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/differential_training/) +* 轨迹模仿蒸馏训练(实验性功能):[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/trajectory_imitation/) ## 模型推理 -模型通过 `ZImagePipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。 +模型通过 `ZImagePipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。 `ZImagePipeline` 推理的输入参数包括: @@ -75,12 +80,15 @@ image.save("image.jpg") * `seed`: 随机种子。默认为 `None`,即完全随机。 * `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。 * `num_inference_steps`: 推理次数,默认值为 8。 +* `controlnet_inputs`: ControlNet 模型的输入。 +* `edit_image`: 编辑模型的待编辑图像,支持多张图像。 +* `positive_only_lora`: 仅在正向提示词中使用的 LoRA 权重。 -如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。 +如果显存不足,请开启[显存管理](../Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。 ## 模型训练 -Z-Image 系列模型统一通过 [`examples/z_image/model_training/train.py`](/examples/z_image/model_training/train.py) 进行训练,脚本的参数包括: +Z-Image 系列模型统一通过 [`examples/z_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/train.py) 进行训练,脚本的参数包括: * 通用训练参数 * 数据集基础配置 @@ -129,13 +137,13 @@ Z-Image 系列模型统一通过 [`examples/z_image/model_training/train.py`](/e modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset ``` -我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。 +我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。 训练技巧: * [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) 是一个蒸馏加速的模型,因此直接训练将会迅速让模型失去加速能力,以“加速配置”(`num_inference_steps=8`,`cfg_scale=1`)推理的效果变差,以“无加速配置”(`num_inference_steps=30`,`cfg_scale=2`)推理的效果变好。可采用以下方案训练和推理: - * 标准 SFT 训练([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + 无加速配置推理 - * 差分 LoRA 训练([code](/examples/z_image/model_training/special/differential_training/)) + 加速配置推理 + * 标准 SFT 训练([code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + 无加速配置推理 + * 差分 LoRA 训练([code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/differential_training/)) + 加速配置推理 * 差分 LoRA 训练中需加载一个额外的 LoRA,例如 [ostris/zimage_turbo_training_adapter](https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter) - * 标准 SFT 训练([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh))+ 轨迹模仿蒸馏训练([code](/examples/z_image/model_training/special/trajectory_imitation/))+ 加速配置推理 - * 标准 SFT 训练([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh))+ 推理时加载蒸馏加速 LoRA([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillPatch)) + 加速配置推理 + * 标准 SFT 训练([code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh))+ 轨迹模仿蒸馏训练([code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/trajectory_imitation/))+ 加速配置推理 + * 标准 SFT 训练([code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh))+ 推理时加载蒸馏加速 LoRA([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillPatch)) + 加速配置推理 diff --git a/docs/zh/Pipeline_Usage/Environment_Variables.md b/docs/zh/Pipeline_Usage/Environment_Variables.md index 9c96fcc..a4ad45b 100644 --- a/docs/zh/Pipeline_Usage/Environment_Variables.md +++ b/docs/zh/Pipeline_Usage/Environment_Variables.md @@ -28,7 +28,7 @@ DIFFSYNTH_MODEL_BASE_PATH="./path_to_my_models" python xxx.py ## `DIFFSYNTH_ATTENTION_IMPLEMENTATION` -注意力机制实现的方式,可以设置为 `flash_attention_3`、`flash_attention_2`、`sage_attention`、`xformers`、`torch`。详见 [`./core/attention.md`](/docs/zh/API_Reference/core/attention.md). +注意力机制实现的方式,可以设置为 `flash_attention_3`、`flash_attention_2`、`sage_attention`、`xformers`、`torch`。详见 [`./core/attention.md`](../API_Reference/core/attention.md). ## `DIFFSYNTH_DISK_MAP_BUFFER_SIZE` diff --git a/docs/zh/Pipeline_Usage/GPU_support.md b/docs/zh/Pipeline_Usage/GPU_support.md index 7a66923..9760aea 100644 --- a/docs/zh/Pipeline_Usage/GPU_support.md +++ b/docs/zh/Pipeline_Usage/GPU_support.md @@ -2,7 +2,7 @@ `DiffSynth-Studio` 支持多种 GPU/NPU,本文介绍如何在这些设备上运行模型推理和训练。 -在开始前,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)安装好 GPU/NPU 相关的依赖包。 +在开始前,请参考[安装依赖](../Pipeline_Usage/Setup.md)安装好 GPU/NPU 相关的依赖包。 ## NVIDIA GPU @@ -58,6 +58,13 @@ video = pipe( save_video(video, "video.mp4", fps=15, quality=5) ``` +#### USP(Unified Sequence Parallel) +如果想要在NPU上使用该特性,请通过如下方式安装额外的第三方库: +```shell +pip install git+https://github.com/feifeibear/long-context-attention.git +pip install git+https://github.com/xdit-project/xDiT.git +``` + ### 训练 当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_training`目录下,例如 `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`。 @@ -82,4 +89,5 @@ export CPU_AFFINITY_CONF=1 | 模型 | 参数 | 备注 | |-----------|------|-------------------| | Wan 14B系列 | --initialize_model_on_cpu | 14B模型需要在cpu上进行初始化 | +| Qwen-Image系列 | --initialize_model_on_cpu | 模型需要在cpu上进行初始化 | | Z-Image 系列 | --enable_npu_patch | 使用NPU融合算子来替换Z-image模型中的对应算子以提升模型在NPU上的性能 | \ No newline at end of file diff --git a/docs/zh/Pipeline_Usage/Model_Inference.md b/docs/zh/Pipeline_Usage/Model_Inference.md index 75a1ed8..f66f629 100644 --- a/docs/zh/Pipeline_Usage/Model_Inference.md +++ b/docs/zh/Pipeline_Usage/Model_Inference.md @@ -22,7 +22,7 @@ pipe = QwenImagePipeline.from_pretrained( ) ``` -其中 `torch_dtype` 和 `device` 是计算精度和计算设备(不是模型的精度和设备)。`model_configs` 可通过多种方式配置模型路径,关于本项目内部是如何加载模型的,请参考 [`diffsynth.core.loader`](/docs/zh/API_Reference/core/loader.md)。 +其中 `torch_dtype` 和 `device` 是计算精度和计算设备(不是模型的精度和设备)。`model_configs` 可通过多种方式配置模型路径,关于本项目内部是如何加载模型的,请参考 [`diffsynth.core.loader`](../API_Reference/core/loader.md)。
@@ -34,7 +34,7 @@ pipe = QwenImagePipeline.from_pretrained( > ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), > ``` > -> 模型文件默认下载到 `./models` 路径,该路径可通过[环境变量 DIFFSYNTH_MODEL_BASE_PATH](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path) 修改。 +> 模型文件默认下载到 `./models` 路径,该路径可通过[环境变量 DIFFSYNTH_MODEL_BASE_PATH](../Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path) 修改。
@@ -61,7 +61,7 @@ pipe = QwenImagePipeline.from_pretrained(
-默认情况下,即使模型已经下载完毕,程序仍会向远程查询是否有遗漏文件,如果要完全关闭远程请求,请将[环境变量 DIFFSYNTH_SKIP_DOWNLOAD](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) 设置为 `True`。 +默认情况下,即使模型已经下载完毕,程序仍会向远程查询是否有遗漏文件,如果要完全关闭远程请求,请将[环境变量 DIFFSYNTH_SKIP_DOWNLOAD](../Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) 设置为 `True`。 ```shell import os @@ -69,7 +69,7 @@ os.environ["DIFFSYNTH_SKIP_DOWNLOAD"] = "True" import diffsynth ``` -如需从 [HuggingFace](https://huggingface.co/) 下载模型,请将[环境变量 DIFFSYNTH_DOWNLOAD_SOURCE](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source) 设置为 `huggingface`。 +如需从 [HuggingFace](https://huggingface.co/) 下载模型,请将[环境变量 DIFFSYNTH_DOWNLOAD_SOURCE](../Pipeline_Usage/Environment_Variables.md#diffsynth_download_source) 设置为 `huggingface`。 ```shell import os @@ -102,4 +102,65 @@ image.save("image.jpg") 每个模型 `Pipeline` 的输入参数不同,请参考各模型的文档。 -如果模型参数量太大,导致显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)。 +如果模型参数量太大,导致显存不足,请开启[显存管理](../Pipeline_Usage/VRAM_management.md)。 + +## 加载 LoRA + +LoRA 是一种轻量化的模型训练方式,产生少量参数,扩展模型的能力。DiffSynth-Studio 的 LoRA 加载有两种方式:冷加载和热加载。 + +* 冷加载:当基础模型未开启[显存管理](../Pipeline_Usage/VRAM_management.md)时,LoRA 会融合进基础模型权重,此时推理速度没有变化,LoRA 加载后无法卸载。 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +lora = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1", origin_file_pattern="model.safetensors") +pipe.load_lora(pipe.dit, lora, alpha=1) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +* 热加载:当基础模型开启[显存管理](../Pipeline_Usage/VRAM_management.md)时,LoRA 不会融合进基础模型权重,此时推理速度会变慢,LoRA 加载后可通过 `pipe.clear_lora()` 卸载。 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cuda", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +lora = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1", origin_file_pattern="model.safetensors") +pipe.load_lora(pipe.dit, lora, alpha=1) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +pipe.clear_lora() +``` diff --git a/docs/zh/Pipeline_Usage/Model_Training.md b/docs/zh/Pipeline_Usage/Model_Training.md index 208246e..c863f28 100644 --- a/docs/zh/Pipeline_Usage/Model_Training.md +++ b/docs/zh/Pipeline_Usage/Model_Training.md @@ -65,7 +65,7 @@ image_1.jpg,"a dog" image_2.jpg,"a cat" ``` -我们构建了样例数据集,以方便您进行测试。了解通用数据集架构是如何实现的,请参考 [`diffsynth.core.data`](/docs/zh/API_Reference/core/data.md)。 +我们构建了样例数据集,以方便您进行测试。了解通用数据集架构是如何实现的,请参考 [`diffsynth.core.data`](../API_Reference/core/data.md)。
@@ -93,7 +93,7 @@ image_2.jpg,"a cat" ## 加载模型 -类似于[推理时的模型加载](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型),我们支持多种方式配置模型路径,两种方式是可以混用的。 +类似于[推理时的模型加载](../Pipeline_Usage/Model_Inference.md#加载模型),我们支持多种方式配置模型路径,两种方式是可以混用的。
@@ -115,9 +115,9 @@ image_2.jpg,"a cat" > --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" > ``` > -> 模型文件默认下载到 `./models` 路径,该路径可通过[环境变量 DIFFSYNTH_MODEL_BASE_PATH](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path) 修改。 +> 模型文件默认下载到 `./models` 路径,该路径可通过[环境变量 DIFFSYNTH_MODEL_BASE_PATH](../Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path) 修改。 > -> 默认情况下,即使模型已经下载完毕,程序仍会向远程查询是否有遗漏文件,如果要完全关闭远程请求,请将[环境变量 DIFFSYNTH_SKIP_DOWNLOAD](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) 设置为 `True`。 +> 默认情况下,即使模型已经下载完毕,程序仍会向远程查询是否有遗漏文件,如果要完全关闭远程请求,请将[环境变量 DIFFSYNTH_SKIP_DOWNLOAD](../Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) 设置为 `True`。
@@ -235,11 +235,11 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera ## 训练注意事项 -* 数据集的元数据除 `csv` 格式外,还支持 `json`、`jsonl` 格式,关于如何选择最佳的元数据格式,请参考[](/docs/zh/API_Reference/core/data.md#元数据) +* 数据集的元数据除 `csv` 格式外,还支持 `json`、`jsonl` 格式,关于如何选择最佳的元数据格式,请参考[](../API_Reference/core/data.md#元数据) * 通常训练效果与训练步数强相关,与 epoch 数量弱相关,因此我们更推荐使用参数 `--save_steps` 按训练步数间隔来保存模型文件。 * 当数据量 * `dataset_repeat` 超过 $10^9$ 时,我们观测到数据集的速度明显变慢,这似乎是 `PyTorch` 的 bug,我们尚不确定新版本的 `PyTorch` 是否已经修复了这一问题。 * 学习率 `--learning_rate` 在 LoRA 训练中建议设置为 `1e-4`,在全量训练中建议设置为 `1e-5`。 -* 训练框架不支持 batch size > 1,原因是复杂的,详见 [Q&A: 为什么训练框架不支持 batch size > 1?](/docs/zh/QA.md#为什么训练框架不支持-batch-size--1) +* 训练框架不支持 batch size > 1,原因是复杂的,详见 [Q&A: 为什么训练框架不支持 batch size > 1?](../QA.md#为什么训练框架不支持-batch-size--1) * 少数模型包含冗余参数,例如 Qwen-Image 的 DiT 部分最后一层的文本编码部分,在训练这些模型时,需设置 `--find_unused_parameters` 避免在多 GPU 训练中报错。出于对开源社区模型兼容性的考虑,我们不打算删除这些冗余参数。 * Diffusion 模型的损失函数值与实际效果的关系不大,因此我们在训练过程中不会记录损失函数值。我们建议把 `--num_epochs` 设置为足够大的数值,边训边测,直至效果收敛后手动关闭训练程序。 -* `--use_gradient_checkpointing` 通常是开启的,除非 GPU 显存足够;`--use_gradient_checkpointing_offload` 则按需开启,详见 [`diffsynth.core.gradient`](/docs/zh/API_Reference/core/gradient.md)。 +* `--use_gradient_checkpointing` 通常是开启的,除非 GPU 显存足够;`--use_gradient_checkpointing_offload` 则按需开启,详见 [`diffsynth.core.gradient`](../API_Reference/core/gradient.md)。 diff --git a/docs/zh/Pipeline_Usage/Setup.md b/docs/zh/Pipeline_Usage/Setup.md index 9823593..13a2cae 100644 --- a/docs/zh/Pipeline_Usage/Setup.md +++ b/docs/zh/Pipeline_Usage/Setup.md @@ -41,7 +41,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6 # x86 pip install -e .[npu] --extra-index-url "https://download.pytorch.org/whl/cpu" -使用 Ascend NPU 时,请将 Python 代码中的 `"cuda"` 改为 `"npu"`,详见[NPU 支持](/docs/zh/Pipeline_Usage/GPU_support.md#ascend-npu)。 +使用 Ascend NPU 时,请将 Python 代码中的 `"cuda"` 改为 `"npu"`,详见[NPU 支持](../Pipeline_Usage/GPU_support.md#ascend-npu)。 ## 其他安装问题 diff --git a/docs/zh/Pipeline_Usage/VRAM_management.md b/docs/zh/Pipeline_Usage/VRAM_management.md index 2235c12..cc7ddf9 100644 --- a/docs/zh/Pipeline_Usage/VRAM_management.md +++ b/docs/zh/Pipeline_Usage/VRAM_management.md @@ -140,7 +140,7 @@ image.save("image.jpg") 在更为极端的情况下,当内存也不足以存储整个模型时,Disk Offload 功能可以让模型参数惰性加载,即,模型中的每个 Layer 仅在调用 forward 时才会从硬盘中读取相应的参数。启用这一功能时,我们建议使用高速的 SSD 硬盘。 -Disk Offload 是极为特殊的显存管理方案,只支持 `.safetensors` 格式文件,不支持 `.bin`、`.pth`、`.ckpt` 等二进制文件,不支持带 Tensor reshape 的 [state dict converter](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换)。 +Disk Offload 是极为特殊的显存管理方案,只支持 `.safetensors` 格式文件,不支持 `.bin`、`.pth`、`.ckpt` 等二进制文件,不支持带 Tensor reshape 的 [state dict converter](../Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换)。 ```python from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig @@ -196,7 +196,7 @@ vram_config = { * Preparing:Onload 和 Computation 的中间状态,在显存允许的前提下的暂存状态,这个状态由显存管理机制控制切换,当且仅当【vram_limit 设置为无限制】或【vram_limit 已设置且有空余显存】时会进入这一状态 * Computation:模型正在计算过程中,这个状态由显存管理机制控制切换,仅在 `forward` 中临时进入 -如果你是模型开发者,希望自行控制某个模型的显存管理粒度,请参考[../Developer_Guide/Enabling_VRAM_management.md](/docs/zh/Developer_Guide/Enabling_VRAM_management.md)。 +如果你是模型开发者,希望自行控制某个模型的显存管理粒度,请参考[../Developer_Guide/Enabling_VRAM_management.md](../Developer_Guide/Enabling_VRAM_management.md)。 ## 最佳实践 diff --git a/docs/zh/QA.md b/docs/zh/QA.md index b1d55df..a072b6a 100644 --- a/docs/zh/QA.md +++ b/docs/zh/QA.md @@ -26,3 +26,10 @@ * 此外,使用原生 FP8 精度训练的模型,在推理时若没有 Hopper 架构 GPU,则只能以 BF16 精度进行计算,理论上其生成效果反而不如 FP8。 因此,原生 FP8 精度训练技术是极不成熟的,我们静观开源社区的技术发展。 + +## 如何在推理时动态加载 LoRA 模型? + +我们支持 LoRA 模型的两种加载方式,详见[LoRA 加载](./Pipeline_Usage/Model_Inference.md#加载-lora): + +* 冷加载:当基础模型未开启[显存管理](./Pipeline_Usage/VRAM_management.md)时,LoRA 会融合进基础模型权重,此时推理速度没有变化,LoRA 加载后无法卸载。 +* 热加载:当基础模型开启[显存管理](./Pipeline_Usage/VRAM_management.md)时,LoRA 不会融合进基础模型权重,此时推理速度会变慢,LoRA 加载后可通过 `pipe.clear_lora()` 卸载。 diff --git a/docs/zh/README.md b/docs/zh/README.md index edcef50..825415e 100644 --- a/docs/zh/README.md +++ b/docs/zh/README.md @@ -26,58 +26,58 @@ graph LR; 本节介绍 `DiffSynth-Studio` 的基本使用方式,包括如何启用显存管理从而在极低显存的 GPU 上进行推理,以及如何训练任意基础模型、LoRA、ControlNet 等模型。 -* [安装依赖](/docs/zh/Pipeline_Usage/Setup.md) -* [模型推理](/docs/zh/Pipeline_Usage/Model_Inference.md) -* [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md) -* [模型训练](/docs/zh/Pipeline_Usage/Model_Training.md) -* [环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md) -* [GPU/NPU 支持](/docs/zh/Pipeline_Usage/GPU_support.md) +* [安装依赖](./Pipeline_Usage/Setup.md) +* [模型推理](./Pipeline_Usage/Model_Inference.md) +* [显存管理](./Pipeline_Usage/VRAM_management.md) +* [模型训练](./Pipeline_Usage/Model_Training.md) +* [环境变量](./Pipeline_Usage/Environment_Variables.md) +* [GPU/NPU 支持](./Pipeline_Usage/GPU_support.md) ## Section 2: 模型详解 本节介绍 `DiffSynth-Studio` 所支持的 Diffusion 模型,部分模型 Pipeline 具备可控生成、并行加速等特色功能。 -* [FLUX.1](/docs/zh/Model_Details/FLUX.md) -* [Wan](/docs/zh/Model_Details/Wan.md) -* [Qwen-Image](/docs/zh/Model_Details/Qwen-Image.md) -* [FLUX.2](/docs/zh/Model_Details/FLUX2.md) -* [Z-Image](/docs/zh/Model_Details/Z-Image.md) +* [FLUX.1](./Model_Details/FLUX.md) +* [Wan](./Model_Details/Wan.md) +* [Qwen-Image](./Model_Details/Qwen-Image.md) +* [FLUX.2](./Model_Details/FLUX2.md) +* [Z-Image](./Model_Details/Z-Image.md) ## Section 3: 训练框架 本节介绍 `DiffSynth-Studio` 中训练框架的设计思路,帮助开发者理解 Diffusion 模型训练算法的原理。 -* [Diffusion 模型基本原理](/docs/zh/Training/Understanding_Diffusion_models.md) -* [标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md) -* [在训练中启用 FP8 精度](/docs/zh/Training/FP8_Precision.md) -* [端到端的蒸馏加速训练](/docs/zh/Training/Direct_Distill.md) -* [两阶段拆分训练](/docs/zh/Training/Split_Training.md) -* [差分 LoRA 训练](/docs/zh/Training/Differential_LoRA.md) +* [Diffusion 模型基本原理](./Training/Understanding_Diffusion_models.md) +* [标准监督训练](./Training/Supervised_Fine_Tuning.md) +* [在训练中启用 FP8 精度](./Training/FP8_Precision.md) +* [端到端的蒸馏加速训练](./Training/Direct_Distill.md) +* [两阶段拆分训练](./Training/Split_Training.md) +* [差分 LoRA 训练](./Training/Differential_LoRA.md) ## Section 4: 模型接入 本节介绍如何将模型接入 `DiffSynth-Studio` 从而使用框架基础功能,帮助开发者为本项目提供新模型的支持,或进行私有化模型的推理和训练。 -* [接入模型结构](/docs/zh/Developer_Guide/Integrating_Your_Model.md) -* [接入 Pipeline](/docs/zh/Developer_Guide/Building_a_Pipeline.md) -* [接入细粒度显存管理](/docs/zh/Developer_Guide/Enabling_VRAM_management.md) -* [接入模型训练](/docs/zh/Developer_Guide/Training_Diffusion_Models.md) +* [接入模型结构](./Developer_Guide/Integrating_Your_Model.md) +* [接入 Pipeline](./Developer_Guide/Building_a_Pipeline.md) +* [接入细粒度显存管理](./Developer_Guide/Enabling_VRAM_management.md) +* [接入模型训练](./Developer_Guide/Training_Diffusion_Models.md) ## Section 5: API 参考 本节介绍 `DiffSynth-Studio` 中的独立核心模块 `diffsynth.core`,介绍内部的功能是如何设计和运作的,开发者如有需要,可将其中的功能模块用于其他代码库的开发中。 -* [`diffsynth.core.attention`](/docs/zh/API_Reference/core/attention.md): 注意力机制实现 -* [`diffsynth.core.data`](/docs/zh/API_Reference/core/data.md): 数据处理算子与通用数据集 -* [`diffsynth.core.gradient`](/docs/zh/API_Reference/core/gradient.md): 梯度检查点 -* [`diffsynth.core.loader`](/docs/zh/API_Reference/core/loader.md): 模型下载与加载 -* [`diffsynth.core.vram`](/docs/zh/API_Reference/core/vram.md): 显存管理 +* [`diffsynth.core.attention`](./API_Reference/core/attention.md): 注意力机制实现 +* [`diffsynth.core.data`](./API_Reference/core/data.md): 数据处理算子与通用数据集 +* [`diffsynth.core.gradient`](./API_Reference/core/gradient.md): 梯度检查点 +* [`diffsynth.core.loader`](./API_Reference/core/loader.md): 模型下载与加载 +* [`diffsynth.core.vram`](./API_Reference/core/vram.md): 显存管理 ## Section 6: 学术导引 本节介绍如何利用 `DiffSynth-Studio` 训练新的模型,帮助科研工作者探索新的模型技术。 -* 从零开始训练模型【coming soon】 +* [从零开始训练模型](./Research_Tutorial/train_from_scratch.md) * 推理改进优化技术【coming soon】 * 设计可控生成模型【coming soon】 * 创建新的训练范式【coming soon】 @@ -86,4 +86,4 @@ graph LR; 本节总结了开发者常见的问题,如果你在使用和开发中遇到了问题,请参考本节内容,如果仍无法解决,请到 GitHub 上给我们提 issue。 -* [常见问题](/docs/zh/QA.md) +* [常见问题](./QA.md) diff --git a/docs/zh/Research_Tutorial/train_from_scratch.md b/docs/zh/Research_Tutorial/train_from_scratch.md new file mode 100644 index 0000000..c89dce4 --- /dev/null +++ b/docs/zh/Research_Tutorial/train_from_scratch.md @@ -0,0 +1,477 @@ +# 从零开始训练模型 + +DiffSynth-Studio 的训练引擎支持从零开始训练基础模型,本文介绍如何从零开始训练一个参数量仅为 0.1B 的小型文生图模型。 + +## 1. 构建模型结构 + +### 1.1 Diffusion 模型 + +从 UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) 到 DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206),Diffusion 的主流模型结构经历了多次演变。通常,一个 Diffusion 模型的输入包括: + +* 图像张量(`latents`):图像的编码,由 VAE 模型产生,含有部分噪声 +* 文本张量(`prompt_embeds`):文本的编码,由文本编码器产生 +* 时间步(`timestep`):标量,用于标记当前处于 Diffusion 过程的哪个阶段 + +模型的输出是与图像张量形状相同的张量,表示模型预测的去噪方向,关于 Diffusion 模型理论的细节,请参考 [Diffusion 模型基本原理](../Training/Understanding_Diffusion_models.md)。在本文中,我们构建一个仅含 0.1B 参数的 DiT 模型:`AAADiT`。 + +
+模型结构代码 + +```python +import torch, accelerate +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange, repeat + +from transformers import AutoProcessor, AutoTokenizer +from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model +from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task +from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit +from diffsynth.models.general_modules import TimestepEmbeddings +from diffsynth.models.z_image_text_encoder import ZImageTextEncoder +from diffsynth.models.flux2_vae import Flux2VAE + + +class AAAPositionalEmbedding(torch.nn.Module): + def __init__(self, height=16, width=16, dim=1024): + super().__init__() + self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width))) + self.text_emb = torch.nn.Parameter(torch.randn((dim,))) + + def forward(self, image, text): + height, width = image.shape[-2:] + image_emb = self.image_emb.to(device=image.device, dtype=image.dtype) + image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear") + image_emb = rearrange(image_emb, "B C H W -> B (H W) C") + text_emb = self.text_emb.to(device=text.device, dtype=text.dtype) + text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1]) + emb = torch.concat([image_emb, text_emb], dim=1) + return emb + + +class AAABlock(torch.nn.Module): + def __init__(self, dim=1024, num_heads=32): + super().__init__() + self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False) + self.to_q = torch.nn.Linear(dim, dim) + self.to_k = torch.nn.Linear(dim, dim) + self.to_v = torch.nn.Linear(dim, dim) + self.to_out = torch.nn.Linear(dim, dim) + self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False) + self.ff = torch.nn.Sequential( + torch.nn.Linear(dim, dim*3), + torch.nn.SiLU(), + torch.nn.Linear(dim*3, dim), + ) + self.to_gate = torch.nn.Linear(dim, dim * 2) + self.num_heads = num_heads + + def attention(self, emb, pos_emb): + emb = self.norm_attn(emb + pos_emb) + q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb) + emb = attention_forward( + q, k, v, + q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)", + dims={"n": self.num_heads}, + ) + emb = self.to_out(emb) + return emb + + def feed_forward(self, emb, pos_emb): + emb = self.norm_mlp(emb + pos_emb) + emb = self.ff(emb) + return emb + + def forward(self, emb, pos_emb, t_emb): + gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1) + emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn) + emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp) + return emb + + +class AAADiT(torch.nn.Module): + def __init__(self, dim=1024): + super().__init__() + self.pos_embedder = AAAPositionalEmbedding(dim=dim) + self.timestep_embedder = TimestepEmbeddings(256, dim) + self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim)) + self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim)) + self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)]) + self.proj_out = torch.nn.Linear(dim, 128) + + def forward( + self, + latents, + prompt_embeds, + timestep, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + pos_emb = self.pos_embedder(latents, prompt_embeds) + t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1) + image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C")) + text = self.text_embedder(prompt_embeds) + emb = torch.concat([image, text], dim=1) + for block_id, block in enumerate(self.blocks): + emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + emb=emb, + pos_emb=pos_emb, + t_emb=t_emb, + ) + emb = emb[:, :latents.shape[-1] * latents.shape[-2]] + emb = self.proj_out(emb) + emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1]) + return emb +``` + +
+ +### 1.2 编解码器模型 + +除了用于去噪的 Diffusion 模型以外,我们还需要另外两个模型: + +* 文本编码器:用于将文本编码为张量。我们采用 [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) 模型。 +* VAE 编解码器:编码器部分用于将图像编码为张量,解码器部分用于将图像张量解码为图像。我们采用 [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 中的 VAE 模型。 + +这两个模型的结构都已集成在 DiffSynth-Studio 中,分别位于 [/diffsynth/models/z_image_text_encoder.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/models/z_image_text_encoder.py) 和 [/diffsynth/models/flux2_vae.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/models/flux2_vae.py),因此我们不需要修改任何代码。 + +## 2. 构建 Pipeline + +我们在文档 [接入 Pipeline](../Developer_Guide/Building_a_Pipeline.md) 中介绍了如何构建一个模型 Pipeline,对于本文中的模型,我们也需要构建一个 Pipeline,连接文本编码器、Diffusion 模型、VAE 编解码器。 + +
+Pipeline 代码 + +```python +class AAAImagePipeline(BasePipeline): + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("FLUX.2") + self.text_encoder: ZImageTextEncoder = None + self.dit: AAADiT = None + self.vae: Flux2VAE = None + self.tokenizer: AutoProcessor = None + self.in_iteration_models = ("dit",) + self.units = [ + AAAUnit_PromptEmbedder(), + AAAUnit_NoiseInitializer(), + AAAUnit_InputImageEmbedder(), + ] + self.model_fn = model_fn_aaa + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = None, + vram_limit: float = None, + ): + # Initialize pipeline + pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder") + pipe.dit = model_pool.fetch_model("aaa_dit") + pipe.vae = model_pool.fetch_model("flux2_vae") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 30, + # Progress bar + progress_bar_cmd = tqdm, + ): + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16) + + # Parameters + inputs_posi = {"prompt": prompt} + inputs_nega = {"negative_prompt": negative_prompt} + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + image = self.vae.decode(inputs_shared["latents"]) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class AAAUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_embeds",), + onload_model_names=("text_encoder",) + ) + self.hidden_states_layers = (-1,) + + def process(self, pipe: AAAImagePipeline, prompt): + pipe.load_models_to_device(self.onload_model_names) + text = pipe.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device) + output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False) + prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1) + return {"prompt_embeds": prompt_embeds} + + +class AAAUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} + + +class AAAUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: AAAImagePipeline, input_image, noise): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(input_image) + input_latents = pipe.vae.encode(image) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +def model_fn_aaa( + dit: AAADiT, + latents=None, + prompt_embeds=None, + timestep=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + model_output = dit( + latents, + prompt_embeds, + timestep, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + return model_output +``` + +
+ +## 3. 准备数据集 + +为了快速验证训练效果,我们使用数据集 [宝可梦-第一世代](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1),这个数据集转载自开源项目 [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh),包含从妙蛙种子到梦幻的 151 个第一世代宝可梦。如果你想使用其他数据集,请参考文档 [准备数据集](../Pipeline_Usage/Model_Training.md#准备数据集) 和 [`diffsynth.core.data`](../API_Reference/core/data.md)。 + +```shell +modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data +``` + +### 4. 开始训练 + +训练过程可使用 Pipeline 快速实现,我们已将完整的代码放在 [../Research_Tutorial/train_from_scratch.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/docs/zh/Research_Tutorial/train_from_scratch.py),可直接通过 `python docs/zh/Research_Tutorial/train_from_scratch.py` 开始单 GPU 训练。 + +如需开启多 GPU 并行训练,请运行 `accelerate config` 设置相关参数,然后使用命令 `accelerate launch docs/zh/Research_Tutorial/train_from_scratch.py` 开始训练。 + +这个训练脚本没有设置停止条件,请在需要时手动关闭。模型在训练大约 6 万步后收敛,单 GPU 训练需要 10~20 小时。 + + +
+训练代码 + +```python +class AAATrainingModule(DiffusionTrainingModule): + def __init__(self, device): + super().__init__() + self.pipe = AAAImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=device, + model_configs=[ + ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"), + ) + self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device) + self.pipe.freeze_except(["dit"]) + self.pipe.scheduler.set_timesteps(1000, training=True) + + def forward(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + "cfg_scale": 1, + "use_gradient_checkpointing": False, + "use_gradient_checkpointing_offload": False, + } + for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi) + return loss + + +if __name__ == "__main__": + accelerator = accelerate.Accelerator(gradient_accumulation_steps=1) + dataset = UnifiedDataset( + base_path="data/images", + metadata_path="data/metadata_merged.csv", + max_data_items=10000000, + data_file_keys=("image",), + main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256) + ) + model = AAATrainingModule(device=accelerator.device) + model_logger = ModelLogger( + "models/AAA/v1", + remove_prefix_in_ckpt="pipe.dit.", + ) + launch_training_task( + accelerator, dataset, model, model_logger, + learning_rate=2e-4, + num_workers=4, + save_steps=50000, + num_epochs=999999, + ) +``` + +
+ +## 5. 验证训练效果 + +如果你不想等待模型训练完成,可以直接下载[我们预先训练好的模型](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel)。 + +```shell +modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel +``` + +加载模型 + +```python +from diffsynth import load_model + +pipe = AAAImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"), +) +pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda") +``` + +模型推理,生成第一世代宝可梦“御三家”,此时模型生成的图像内容与训练数据基本一致。 + +```python +for seed, prompt in enumerate([ + "green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws", + "orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws", + "蓝色,米色,棕色,乌龟,水系,龟壳,大眼睛,短四肢,卷曲尾巴", +]): + image = pipe( + prompt=prompt, + negative_prompt=" ", + num_inference_steps=30, + cfg_scale=10, + seed=seed, + height=256, width=256, + ) + image.save(f"image_{seed}.jpg") +``` + +|![Image](https://github.com/user-attachments/assets/3c620fbf-5d28-4a1a-b887-519d85ac7d1c)|![Image](https://github.com/user-attachments/assets/909efd4c-9e61-4b33-9321-39da0e499b00)|![Image](https://github.com/user-attachments/assets/f3474bcd-b474-4a90-a1ea-579f67e161e3)| +|-|-|-| + +模型推理,生成具有“锐利爪子”的宝可梦,此时不同的随机种子能够产生不同的图像结果。 + +```python +for seed, prompt in enumerate([ + "sharp claws", + "sharp claws", + "sharp claws", +]): + image = pipe( + prompt=prompt, + negative_prompt=" ", + num_inference_steps=30, + cfg_scale=10, + seed=seed+4, + height=256, width=256, + ) + image.save(f"image_sharp_claws_{seed}.jpg") +``` + +|![Image](https://github.com/user-attachments/assets/94862edd-96ae-4276-a38f-795249f11a13)|![Image](https://github.com/user-attachments/assets/b2291f23-20ba-42de-8bfd-76cb4afc6eea)|![Image](https://github.com/user-attachments/assets/f2aab9a4-85ec-498e-8039-648b1289796e)| +|-|-|-| + +现在,我们获得了一个 0.1B 的小型文生图模型,这个模型已经能够生成 151 个宝可梦,但无法生成其他图像内容。如果在此基础上增加数据量、模型参数量、GPU 数量,你就可以训练出一个更强大的文生图模型! diff --git a/docs/zh/Research_Tutorial/train_from_scratch.py b/docs/zh/Research_Tutorial/train_from_scratch.py new file mode 100644 index 0000000..328c24d --- /dev/null +++ b/docs/zh/Research_Tutorial/train_from_scratch.py @@ -0,0 +1,341 @@ +import torch, accelerate +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange, repeat + +from transformers import AutoProcessor, AutoTokenizer +from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model +from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task +from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit +from diffsynth.models.general_modules import TimestepEmbeddings +from diffsynth.models.z_image_text_encoder import ZImageTextEncoder +from diffsynth.models.flux2_vae import Flux2VAE + + +class AAAPositionalEmbedding(torch.nn.Module): + def __init__(self, height=16, width=16, dim=1024): + super().__init__() + self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width))) + self.text_emb = torch.nn.Parameter(torch.randn((dim,))) + + def forward(self, image, text): + height, width = image.shape[-2:] + image_emb = self.image_emb.to(device=image.device, dtype=image.dtype) + image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear") + image_emb = rearrange(image_emb, "B C H W -> B (H W) C") + text_emb = self.text_emb.to(device=text.device, dtype=text.dtype) + text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1]) + emb = torch.concat([image_emb, text_emb], dim=1) + return emb + + +class AAABlock(torch.nn.Module): + def __init__(self, dim=1024, num_heads=32): + super().__init__() + self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False) + self.to_q = torch.nn.Linear(dim, dim) + self.to_k = torch.nn.Linear(dim, dim) + self.to_v = torch.nn.Linear(dim, dim) + self.to_out = torch.nn.Linear(dim, dim) + self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False) + self.ff = torch.nn.Sequential( + torch.nn.Linear(dim, dim*3), + torch.nn.SiLU(), + torch.nn.Linear(dim*3, dim), + ) + self.to_gate = torch.nn.Linear(dim, dim * 2) + self.num_heads = num_heads + + def attention(self, emb, pos_emb): + emb = self.norm_attn(emb + pos_emb) + q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb) + emb = attention_forward( + q, k, v, + q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)", + dims={"n": self.num_heads}, + ) + emb = self.to_out(emb) + return emb + + def feed_forward(self, emb, pos_emb): + emb = self.norm_mlp(emb + pos_emb) + emb = self.ff(emb) + return emb + + def forward(self, emb, pos_emb, t_emb): + gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1) + emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn) + emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp) + return emb + + +class AAADiT(torch.nn.Module): + def __init__(self, dim=1024): + super().__init__() + self.pos_embedder = AAAPositionalEmbedding(dim=dim) + self.timestep_embedder = TimestepEmbeddings(256, dim) + self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim)) + self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim)) + self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)]) + self.proj_out = torch.nn.Linear(dim, 128) + + def forward( + self, + latents, + prompt_embeds, + timestep, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + pos_emb = self.pos_embedder(latents, prompt_embeds) + t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1) + image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C")) + text = self.text_embedder(prompt_embeds) + emb = torch.concat([image, text], dim=1) + for block_id, block in enumerate(self.blocks): + emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + emb=emb, + pos_emb=pos_emb, + t_emb=t_emb, + ) + emb = emb[:, :latents.shape[-1] * latents.shape[-2]] + emb = self.proj_out(emb) + emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1]) + return emb + + +class AAAImagePipeline(BasePipeline): + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("FLUX.2") + self.text_encoder: ZImageTextEncoder = None + self.dit: AAADiT = None + self.vae: Flux2VAE = None + self.tokenizer: AutoProcessor = None + self.in_iteration_models = ("dit",) + self.units = [ + AAAUnit_PromptEmbedder(), + AAAUnit_NoiseInitializer(), + AAAUnit_InputImageEmbedder(), + ] + self.model_fn = model_fn_aaa + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = None, + vram_limit: float = None, + ): + # Initialize pipeline + pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder") + pipe.dit = model_pool.fetch_model("aaa_dit") + pipe.vae = model_pool.fetch_model("flux2_vae") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 30, + # Progress bar + progress_bar_cmd = tqdm, + ): + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16) + + # Parameters + inputs_posi = {"prompt": prompt} + inputs_nega = {"negative_prompt": negative_prompt} + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + image = self.vae.decode(inputs_shared["latents"]) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class AAAUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_embeds",), + onload_model_names=("text_encoder",) + ) + self.hidden_states_layers = (-1,) + + def process(self, pipe: AAAImagePipeline, prompt): + pipe.load_models_to_device(self.onload_model_names) + text = pipe.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device) + output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False) + prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1) + return {"prompt_embeds": prompt_embeds} + + +class AAAUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} + + +class AAAUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: AAAImagePipeline, input_image, noise): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(input_image) + input_latents = pipe.vae.encode(image) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +def model_fn_aaa( + dit: AAADiT, + latents=None, + prompt_embeds=None, + timestep=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + model_output = dit( + latents, + prompt_embeds, + timestep, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + return model_output + + +class AAATrainingModule(DiffusionTrainingModule): + def __init__(self, device): + super().__init__() + self.pipe = AAAImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=device, + model_configs=[ + ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"), + ) + self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device) + self.pipe.freeze_except(["dit"]) + self.pipe.scheduler.set_timesteps(1000, training=True) + + def forward(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + "cfg_scale": 1, + "use_gradient_checkpointing": False, + "use_gradient_checkpointing_offload": False, + } + for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi) + return loss + + +if __name__ == "__main__": + accelerator = accelerate.Accelerator(gradient_accumulation_steps=1) + dataset = UnifiedDataset( + base_path="data/images", + metadata_path="data/metadata_merged.csv", + max_data_items=10000000, + data_file_keys=("image",), + main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256) + ) + model = AAATrainingModule(device=accelerator.device) + model_logger = ModelLogger( + "models/AAA/v1", + remove_prefix_in_ckpt="pipe.dit.", + ) + launch_training_task( + accelerator, dataset, model, model_logger, + learning_rate=2e-4, + num_workers=4, + save_steps=50000, + num_epochs=999999, + ) \ No newline at end of file diff --git a/docs/zh/Training/Differential_LoRA.md b/docs/zh/Training/Differential_LoRA.md index 2489ea0..3993c6e 100644 --- a/docs/zh/Training/Differential_LoRA.md +++ b/docs/zh/Training/Differential_LoRA.md @@ -8,8 +8,8 @@ 假设我们有两张内容相似的图像:图 1 和图 2。例如两张图中分别有一辆车,但图 1 中画面细节更少,图 2 中画面细节更多。在差分 LoRA 训练中,我们进行两步训练: -* 以图 1 为训练数据,以[标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md)的方式,训练 LoRA 1 -* 以图 2 为训练数据,将 LoRA 1 融入基础模型后,以[标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md)的方式,训练 LoRA 2 +* 以图 1 为训练数据,以[标准监督训练](../Training/Supervised_Fine_Tuning.md)的方式,训练 LoRA 1 +* 以图 2 为训练数据,将 LoRA 1 融入基础模型后,以[标准监督训练](../Training/Supervised_Fine_Tuning.md)的方式,训练 LoRA 2 在第一步训练中,由于训练数据仅有一张图,LoRA 模型很容易过拟合,因此训练完成后,LoRA 1 会让模型毫不犹豫地生成图 1,无论随机种子是什么。在第二步训练中,LoRA 模型再次过拟合,因此训练完成后,在 LoRA 1 和 LoRA 2 的共同作用下,模型会毫不犹豫地生成图 2。简言之: diff --git a/docs/zh/Training/Direct_Distill.md b/docs/zh/Training/Direct_Distill.md index 946a767..4a9ae79 100644 --- a/docs/zh/Training/Direct_Distill.md +++ b/docs/zh/Training/Direct_Distill.md @@ -44,7 +44,7 @@ loss = torch.nn.functional.mse_loss(image_1, image_2) ## 在训练框架中使用蒸馏加速训练 -首先,需要生成训练数据,请参考[模型推理](/docs/zh/Pipeline_Usage/Model_Inference.md)部分编写推理代码,以足够多的推理步数生成训练数据。 +首先,需要生成训练数据,请参考[模型推理](../Pipeline_Usage/Model_Inference.md)部分编写推理代码,以足够多的推理步数生成训练数据。 以 Qwen-Image 为例,以下代码可以生成一张图片: @@ -67,7 +67,7 @@ image = pipe(prompt, seed=0, num_inference_steps=40) image.save("image.jpg") ``` -然后,我们把必要的信息编写成[元数据文件](/docs/zh/API_Reference/core/data.md#元数据): +然后,我们把必要的信息编写成[元数据文件](../API_Reference/core/data.md#元数据): ```csv image,prompt,seed,rand_device,num_inference_steps,cfg_scale @@ -86,11 +86,11 @@ modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir bash examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh ``` -请注意,在[训练脚本参数](/docs/zh/Pipeline_Usage/Model_Training.md#脚本参数)中,数据集的图像分辨率设置要避免触发缩放处理。当设定 `--height` 和 `--width` 以启用固定分辨率时,所有训练数据必须是以完全一致的宽高生成的;当设定 `--max_pixels` 以启用动态分辨率时,`--max_pixels` 的数值必须大于或等于任一训练图像的像素面积。 +请注意,在[训练脚本参数](../Pipeline_Usage/Model_Training.md#脚本参数)中,数据集的图像分辨率设置要避免触发缩放处理。当设定 `--height` 和 `--width` 以启用固定分辨率时,所有训练数据必须是以完全一致的宽高生成的;当设定 `--max_pixels` 以启用动态分辨率时,`--max_pixels` 的数值必须大于或等于任一训练图像的像素面积。 ## 训练框架设计思路 -直接蒸馏与[标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md)相比,仅训练的损失函数不同,直接蒸馏的损失函数是 `diffsynth.diffusion.loss` 中的 `DirectDistillLoss`。 +直接蒸馏与[标准监督训练](../Training/Supervised_Fine_Tuning.md)相比,仅训练的损失函数不同,直接蒸馏的损失函数是 `diffsynth.diffusion.loss` 中的 `DirectDistillLoss`。 ## 未来工作 diff --git a/docs/zh/Training/FP8_Precision.md b/docs/zh/Training/FP8_Precision.md index a1f428a..09162a1 100644 --- a/docs/zh/Training/FP8_Precision.md +++ b/docs/zh/Training/FP8_Precision.md @@ -1,12 +1,12 @@ # 在训练中启用 FP8 精度 -尽管 `DiffSynth-Studio` 在模型推理中支持[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),但其中的大部分减少显存占用的技术不适合用于训练中,Offload 会导致极为缓慢的训练过程。 +尽管 `DiffSynth-Studio` 在模型推理中支持[显存管理](../Pipeline_Usage/VRAM_management.md),但其中的大部分减少显存占用的技术不适合用于训练中,Offload 会导致极为缓慢的训练过程。 -FP8 精度是唯一可在训练过程中启用的显存管理策略,但本框架目前不支持原生 FP8 精度训练,原因详见 [Q&A: 为什么训练框架不支持原生 FP8 精度训练?](/docs/zh/QA.md#为什么训练框架不支持原生-fp8-精度训练),仅支持将参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)以 FP8 精度进行存储。 +FP8 精度是唯一可在训练过程中启用的显存管理策略,但本框架目前不支持原生 FP8 精度训练,原因详见 [Q&A: 为什么训练框架不支持原生 FP8 精度训练?](../QA.md#为什么训练框架不支持原生-fp8-精度训练),仅支持将参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)以 FP8 精度进行存储。 ## 启用 FP8 -在我们提供的训练脚本中,通过参数 `--fp8_models` 即可快速设置以 FP8 精度存储的模型。以 Qwen-Image 的 LoRA 训练为例,我们提供了启用 FP8 训练的脚本,位于 [`/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh`](/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh)。训练完成后,可通过脚本 [`/examples/qwen_image/model_training/special/fp8_training/validate.py`](/examples/qwen_image/model_training/special/fp8_training/validate.py) 验证训练效果。 +在我们提供的训练脚本中,通过参数 `--fp8_models` 即可快速设置以 FP8 精度存储的模型。以 Qwen-Image 的 LoRA 训练为例,我们提供了启用 FP8 训练的脚本,位于 [`/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh)。训练完成后,可通过脚本 [`/examples/qwen_image/model_training/special/fp8_training/validate.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/special/fp8_training/validate.py) 验证训练效果。 请注意,这种 FP8 显存管理策略不支持梯度更新,当某个模型被设置为可训练时,不能为这个模型开启 FP8 精度,支持开启 FP8 的模型包括两类: diff --git a/docs/zh/Training/Split_Training.md b/docs/zh/Training/Split_Training.md index f98d56e..6e87a84 100644 --- a/docs/zh/Training/Split_Training.md +++ b/docs/zh/Training/Split_Training.md @@ -8,7 +8,7 @@ 在大部分模型的训练过程中,大量计算发生在“前处理”中,即“与去噪模型无关的计算”,包括 VAE 编码、文本编码等。当对应的模型参数固定时,这部分计算的结果是重复的,在多个 epoch 中每个数据样本的计算结果完全相同,因此我们提供了“拆分训练”功能,该功能可以自动分析并拆分训练过程。 -对于普通文生图模型的标准监督训练,拆分过程是非常简单的,只需要把所有 [`Pipeline Units`](/docs/zh/Developer_Guide/Building_a_Pipeline.md#units) 的计算拆分到第一阶段,将计算结果存储到硬盘中,然后在第二阶段从硬盘中读取这些结果并进行后续计算即可。但如果前处理过程中需要梯度回传,情况就变得极其复杂,为此,我们引入了一个计算图拆分算法用于分析如何拆分计算。 +对于普通文生图模型的标准监督训练,拆分过程是非常简单的,只需要把所有 [`Pipeline Units`](../Developer_Guide/Building_a_Pipeline.md#units) 的计算拆分到第一阶段,将计算结果存储到硬盘中,然后在第二阶段从硬盘中读取这些结果并进行后续计算即可。但如果前处理过程中需要梯度回传,情况就变得极其复杂,为此,我们引入了一个计算图拆分算法用于分析如何拆分计算。 ## 计算图拆分算法 @@ -16,7 +16,7 @@ ## 使用拆分训练 -拆分训练已支持[标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md)和[直接蒸馏训练](/docs/zh/Training/Direct_Distill.md),在训练命令中通过 `--task` 参数控制,以 Qwen-Image 模型的 LoRA 训练为例,拆分前的训练命令为: +拆分训练已支持[标准监督训练](../Training/Supervised_Fine_Tuning.md)和[直接蒸馏训练](../Training/Direct_Distill.md),在训练命令中通过 `--task` 参数控制,以 Qwen-Image 模型的 LoRA 训练为例,拆分前的训练命令为: ```shell accelerate launch examples/qwen_image/model_training/train.py \ diff --git a/docs/zh/Training/Supervised_Fine_Tuning.md b/docs/zh/Training/Supervised_Fine_Tuning.md index f2f8aa3..eb14557 100644 --- a/docs/zh/Training/Supervised_Fine_Tuning.md +++ b/docs/zh/Training/Supervised_Fine_Tuning.md @@ -1,10 +1,10 @@ # 标准监督训练 -在理解 [Diffusion 模型基本原理](/docs/zh/Training/Understanding_Diffusion_models.md)之后,本文档介绍框架如何实现 Diffusion 模型的训练。本文档介绍框架的原理,帮助开发者编写新的训练代码,如需使用我们提供的默认训练功能,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md)。 +在理解 [Diffusion 模型基本原理](../Training/Understanding_Diffusion_models.md)之后,本文档介绍框架如何实现 Diffusion 模型的训练。本文档介绍框架的原理,帮助开发者编写新的训练代码,如需使用我们提供的默认训练功能,请参考[模型训练](../Pipeline_Usage/Model_Training.md)。 回顾前文中的模型训练伪代码,当我们实际编写代码时,情况会变得极为复杂。部分模型需要输入额外的引导条件并进行预处理,例如 ControlNet;部分模型需要与去噪模型进行交叉式的计算,例如 VACE;部分模型因显存需求过大,需要开启 Gradient Checkpointing,例如 Qwen-Image 的 DiT。 -为了实现严格的推理和训练一致性,我们对 `Pipeline` 等组件进行了抽象封装,在训练过程中大量复用推理代码。请参考[接入 Pipeline](/docs/zh/Developer_Guide/Building_a_Pipeline.md) 了解 `Pipeline` 组件的设计。接下来我们介绍训练框架如何利用 `Pipeline` 组件构建训练算法。 +为了实现严格的推理和训练一致性,我们对 `Pipeline` 等组件进行了抽象封装,在训练过程中大量复用推理代码。请参考[接入 Pipeline](../Developer_Guide/Building_a_Pipeline.md) 了解 `Pipeline` 组件的设计。接下来我们介绍训练框架如何利用 `Pipeline` 组件构建训练算法。 ## 框架设计思路 @@ -48,13 +48,13 @@ class QwenImageTrainingModule(DiffusionTrainingModule): ) ``` -加载模型的逻辑与推理时基本一致,支持从远程和本地路径加载模型,详见[模型推理](/docs/zh/Pipeline_Usage/Model_Inference.md),但请注意不要启用[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)。 +加载模型的逻辑与推理时基本一致,支持从远程和本地路径加载模型,详见[模型推理](../Pipeline_Usage/Model_Inference.md),但请注意不要启用[显存管理](../Pipeline_Usage/VRAM_management.md)。 `switch_pipe_to_training_mode` 可以将模型切换到训练模式,详见 `switch_pipe_to_training_mode`。 ### `forward` -在 `forward` 中需计算损失函数值,先进行前处理,然后经过 `Pipeline` 的 [`model_fn`](/docs/zh/Developer_Guide/Building_a_Pipeline.md#model_fn) 计算损失函数。 +在 `forward` 中需计算损失函数值,先进行前处理,然后经过 `Pipeline` 的 [`model_fn`](../Developer_Guide/Building_a_Pipeline.md#model_fn) 计算损失函数。 ```python def forward(self, data): @@ -90,7 +90,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule): 训练框架还需其他模块,包括: * accelerator: `accelerate` 提供的训练启动器,详见 [`accelerate`](https://huggingface.co/docs/accelerate/index) -* dataset: 通用数据集,详见 [`diffsynth.core.data`](/docs/zh/API_Reference/core/data.md) +* dataset: 通用数据集,详见 [`diffsynth.core.data`](../API_Reference/core/data.md) * model_logger: 模型记录器,详见 `diffsynth.diffusion.logger` ```python diff --git a/docs/zh/Training/Understanding_Diffusion_models.md b/docs/zh/Training/Understanding_Diffusion_models.md index 576edc9..1ac1423 100644 --- a/docs/zh/Training/Understanding_Diffusion_models.md +++ b/docs/zh/Training/Understanding_Diffusion_models.md @@ -6,7 +6,7 @@ Diffusion 模型通过多步迭代式地去噪(denoise)生成清晰的图像或视频内容,我们从一个数据样本 $x_0$ 的生成过程开始讲起。直观地,在完整的一轮 denoise 过程中,我们从随机高斯噪声 $x_T$ 开始,通过迭代依次得到 $x_{T-1}$、$x_{T-2}$、$x_{T-3}$、$\cdots$,在每一步中逐渐减少噪声含量,最终得到不含噪声的数据样本 $x_0$。 -(图) +![Image](https://github.com/user-attachments/assets/6471ae4c-a635-4924-8b36-b0bd4d42043d) 这个过程是很直观的,但如果要理解其中的细节,我们就需要回答这几个问题: @@ -28,7 +28,7 @@ Diffusion 模型通过多步迭代式地去噪(denoise)生成清晰的图像 那么在中间的某一步,我们可以直接合成含噪声的数据样本 $x_t=(1-\sigma_t)x_0+\sigma_t x_T$。 -(图) +![Image](https://github.com/user-attachments/assets/e25a2f71-123c-4e18-8b34-3a066af15667) ## 迭代去噪的计算是如何进行的? @@ -40,11 +40,10 @@ Diffusion 模型通过多步迭代式地去噪(denoise)生成清晰的图像 其中,引导条件 $c$ 是新引入的参数,它是由用户输入的,可以是用于描述图像内容的文本,也可以是用于勾勒图像结构的线稿图。 -(图) - 而模型的输出 $\hat \epsilon(x_t,c,t)$,则近似地等于 $x_T-x_0$,也就是整个扩散过程(去噪过程的反向过程)的方向。 接下来我们分析一步迭代中发生的计算,在时间步 $t$,模型通过计算得到近似的 $x_T-x_0$ 后,我们计算下一步的 $x_{t-1}$: + $$ \begin{aligned} x_{t-1}&=x_t + (\sigma_{t-1} - \sigma_t) \cdot \hat \epsilon(x_t,c,t)\\ @@ -53,6 +52,7 @@ x_{t-1}&=x_t + (\sigma_{t-1} - \sigma_t) \cdot \hat \epsilon(x_t,c,t)\\ &=(1-\sigma_{t-1})x_0+\sigma_{t-1}x_T \end{aligned} $$ + 完美!与时间步 $t-1$ 时的噪声含量定义完美契合。 > (这部分可能有点难懂,请不必担心,首次阅读本文时建议跳过这部分,不影响后文的阅读。) @@ -89,8 +89,6 @@ $$ 训练过程不同于生成过程,如果我们在训练过程中保留多步迭代,那么梯度需经过多步回传,带来的时间和空间复杂度是灾难性的。为了提高计算效率,我们在训练中随机选择某一时间步 $t$ 进行训练。 -(图) - 以下是训练过程的伪代码 > 从数据集获取数据样本 $x_0$ 和引导条件 $c$ @@ -111,7 +109,7 @@ $$ 从理论到实践,还需要填充更多细节。现代 Diffusion 模型架构已经发展成熟,主流的架构沿用了 Latent Diffusion 所提出的“三段式”架构,包括数据编解码器、引导条件编码器、去噪模型三部分。 -(图) +![Image](https://github.com/user-attachments/assets/43855430-6427-4aca-83a0-f684e01438b1) ### 数据编解码器 @@ -140,4 +138,4 @@ $$ ## 本项目如何封装和实现模型训练? -请阅读下一文档:[标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md) +请阅读下一文档:[标准监督训练](../Training/Supervised_Fine_Tuning.md) diff --git a/docs/zh/conf.py b/docs/zh/conf.py new file mode 100644 index 0000000..6c5ec30 --- /dev/null +++ b/docs/zh/conf.py @@ -0,0 +1,124 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys + +# import sphinx_book_theme + +sys.path.insert(0, os.path.abspath('../../')) +# -- Project information ----------------------------------------------------- + +project = 'diffsynth' +copyright = '2022-2025, Alibaba ModelScope' +author = 'ModelScope Authors' +version_file = '../../diffsynth/version.py' +html_theme = 'sphinx_rtd_theme' +language = 'zh_CN' + + +def get_version(): + with open(version_file, 'r', encoding='utf-8') as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +# The full version, including alpha/beta/rc tags +version = get_version() +release = version + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.napoleon', + 'sphinx.ext.autosummary', + 'sphinx.ext.autodoc', + 'sphinx.ext.viewcode', + 'sphinx_markdown_tables', + 'sphinx_copybutton', + "sphinx_rtd_theme", + 'sphinx.ext.mathjax', + 'myst_parser', +] +# build the templated autosummary files +autosummary_generate = True +numpydoc_show_class_members = False + +# Enable overriding of function signatures in the first line of the docstring. +autodoc_docstring_signature = True + +# Disable docstring inheritance +autodoc_inherit_docstrings = False + +# Show type hints in the description +autodoc_typehints = 'description' + +# Add parameter types if the parameter is documented in the docstring +autodoc_typehints_description_target = 'documented_params' + +autodoc_default_options = { + 'member-order': 'bysource', +} + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +source_suffix = ['.rst', '.md'] + +# The master toctree document. +root_doc = 'index' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ['build'] +# A list of glob-style patterns [1] that are used to find source files. +# They are matched against the source file names relative to the source directory, +# using slashes as directory separators on all platforms. +# The default is **, meaning that all files are recursively included from the source directory. +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +# html_theme = 'sphinx_book_theme' +# html_theme_path = [sphinx_book_theme.get_html_theme_path()] +# html_theme_options = {} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] +# html_css_files = ['css/readthedocs.css'] + +# -- Options for HTMLHelp output --------------------------------------------- +# Output file base name for HTML help builder. + +# -- Extension configuration ------------------------------------------------- +# Ignore >>> when copying code +copybutton_prompt_text = r'>>> |\.\.\. ' +copybutton_prompt_is_regexp = True + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = {'https://docs.python.org/': None} + +myst_enable_extensions = [ + 'amsmath', + 'dollarmath', + 'colon_fence', +] diff --git a/docs/zh/index.rst b/docs/zh/index.rst new file mode 100644 index 0000000..4e82d3e --- /dev/null +++ b/docs/zh/index.rst @@ -0,0 +1,77 @@ +欢迎来到 DiffSynth-Studio 的文档 +===================== + +.. toctree:: + :maxdepth: 2 + :caption: 文档介绍 + + README + +.. toctree:: + :maxdepth: 2 + :caption: 上手使用 + + Pipeline_Usage/Setup + Pipeline_Usage/Model_Inference + Pipeline_Usage/VRAM_management + Pipeline_Usage/Model_Training + Pipeline_Usage/Environment_Variables + Pipeline_Usage/GPU_support + +.. toctree:: + :maxdepth: 2 + :caption: 模型详解 + + Model_Details/FLUX + Model_Details/Wan + Model_Details/Qwen-Image + Model_Details/FLUX2 + Model_Details/Z-Image + +.. toctree:: + :maxdepth: 2 + :caption: 训练框架 + + Training/Understanding_Diffusion_models + Training/Supervised_Fine_Tuning + Training/FP8_Precision + Training/Direct_Distill + Training/Split_Training + Training/Differential_LoRA + +.. toctree:: + :maxdepth: 2 + :caption: 模型接入 + + Developer_Guide/Integrating_Your_Model + Developer_Guide/Building_a_Pipeline + Developer_Guide/Enabling_VRAM_management + Developer_Guide/Training_Diffusion_Models + +.. toctree:: + :maxdepth: 2 + :caption: API 参考 + + API_Reference/core/attention + API_Reference/core/data + API_Reference/core/gradient + API_Reference/core/loader + API_Reference/core/vram + +.. toctree:: + :maxdepth: 2 + :caption: 学术导引 + + Research_Tutorial/train_from_scratch + +.. toctree:: + :maxdepth: 2 + :caption: 常见问题 + + QA + +Indices and tables +================== +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/examples/flux/model_training/full/accelerate_config_zero3.yaml b/examples/flux/model_training/full/accelerate_config_zero3.yaml new file mode 100644 index 0000000..e6a8d27 --- /dev/null +++ b/examples/flux/model_training/full/accelerate_config_zero3.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py b/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py index dbdc8e4..b1f6f40 100644 --- a/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py +++ b/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py @@ -21,6 +21,7 @@ pipe = Flux2ImagePipeline.from_pretrained( ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), ], tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles." image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4) diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py b/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py index dc7b9a7..f79d8b3 100644 --- a/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py +++ b/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py @@ -21,6 +21,7 @@ pipe = Flux2ImagePipeline.from_pretrained( ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), ], tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles." image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4) diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py b/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py index 5a1517f..4538fdb 100644 --- a/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py +++ b/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py @@ -21,6 +21,7 @@ pipe = Flux2ImagePipeline.from_pretrained( ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), ], tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles." image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4) diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py b/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py index e0df8a6..65a59f6 100644 --- a/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py +++ b/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py @@ -21,6 +21,7 @@ pipe = Flux2ImagePipeline.from_pretrained( ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), ], tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles." image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4) diff --git a/examples/flux2/model_training/full/accelerate_config_zero3.yaml b/examples/flux2/model_training/full/accelerate_config_zero3.yaml new file mode 100644 index 0000000..e6a8d27 --- /dev/null +++ b/examples/flux2/model_training/full/accelerate_config_zero3.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/flux2/model_training/special/npu_training/FLUX.2-dev-Lora-NPU.sh b/examples/flux2/model_training/special/npu_training/FLUX.2-dev-Lora-NPU.sh new file mode 100644 index 0000000..ed678f2 --- /dev/null +++ b/examples/flux2/model_training/special/npu_training/FLUX.2-dev-Lora-NPU.sh @@ -0,0 +1,36 @@ +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export CPU_AFFINITY_CONF=1 + +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-dev:text_encoder/*.safetensors,black-forest-labs/FLUX.2-dev:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.2-dev-LoRA-splited-cache" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_out,single_transformer_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --task "sft:data_process" + +accelerate launch --config_file examples/flux2/model_training/full/accelerate_config_zero3.yaml examples/flux2/model_training/train.py \ + --dataset_base_path "./models/train/FLUX.2-dev-LoRA-splited-cache" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-dev:transformer/*.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.2-dev-LoRA-splited" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_out,single_transformer_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --initialize_model_on_cpu \ + --task "sft:train" diff --git a/examples/flux2/model_training/special/npu_training/FLUX.2-klein-9B-NPU.sh b/examples/flux2/model_training/special/npu_training/FLUX.2-klein-9B-NPU.sh new file mode 100644 index 0000000..57755ac --- /dev/null +++ b/examples/flux2/model_training/special/npu_training/FLUX.2-klein-9B-NPU.sh @@ -0,0 +1,34 @@ +# This script is tested on 8*910B(NPU) +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export CPU_AFFINITY_CONF=1 + +accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \ + --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.2-klein-9B_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing + +# Edit +# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \ +# --dataset_base_path data/example_image_dataset \ +# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \ +# --data_file_keys "image,edit_image" \ +# --extra_inputs "edit_image" \ +# --max_pixels 1048576 \ +# --dataset_repeat 50 \ +# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \ +# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \ +# --learning_rate 1e-5 \ +# --num_epochs 2 \ +# --remove_prefix_in_ckpt "pipe.dit." \ +# --output_path "./models/train/FLUX.2-klein-9B_full" \ +# --trainable_models "dit" \ +# --use_gradient_checkpointing diff --git a/examples/flux2/model_training/train.py b/examples/flux2/model_training/train.py index ea727b8..6101687 100644 --- a/examples/flux2/model_training/train.py +++ b/examples/flux2/model_training/train.py @@ -85,6 +85,7 @@ def flux2_parser(): parser = add_general_config(parser) parser = add_image_size_config(parser) parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") + parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.") return parser @@ -126,7 +127,7 @@ if __name__ == "__main__": fp8_models=args.fp8_models, offload_models=args.offload_models, task=args.task, - device=accelerator.device, + device="cpu" if args.initialize_model_on_cpu else accelerator.device, ) model_logger = ModelLogger( args.output_path, diff --git a/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py b/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py new file mode 100644 index 0000000..b8e0811 --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py @@ -0,0 +1,69 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from PIL import Image +from modelscope import dataset_snapshot_download + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +) + +prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"] +) +image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height)) +# first frame +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_distilled_pipeline=True, + input_images=[image], + input_images_indexes=[0], + input_images_strength=1.0, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_distilled_i2av_first.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py b/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py new file mode 100644 index 0000000..1614c1a --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py @@ -0,0 +1,55 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from PIL import Image +from modelscope import dataset_snapshot_download + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +) +prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" +negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +height, width, num_frames = 512 * 2, 768 * 2, 121 +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"] +) +image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height)) +# first frame +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=False, + input_images=[image], + input_images_indexes=[0], + input_images_strength=1.0, + num_inference_steps=40, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_onestage_i2av_first.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py b/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py new file mode 100644 index 0000000..e73ef3d --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py @@ -0,0 +1,72 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from PIL import Image +from modelscope import dataset_snapshot_download + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) + +prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +height, width, num_frames = 512 * 2, 768 * 2, 121 +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"] +) +image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height)) +# first frame +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=42, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, + num_inference_steps=40, + input_images=[image], + input_images_indexes=[0], + input_images_strength=1.0, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_twostage_i2av_first.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py new file mode 100644 index 0000000..c1dc94b --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py @@ -0,0 +1,62 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In", origin_file_pattern="ltx-2-19b-lora-camera-control-dolly-in.safetensors"), +) + +prompt = "Dolly-in shot: A cheerful girl smiles brightly and says, 'I enjoy working with Diffsynth-Studio, it's a perfect framework.' The camera smoothly moves closer to her face, highlighting her enthusiasm and sincerity." +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_dolly_in_lora.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py new file mode 100644 index 0000000..f6b3f0a --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py @@ -0,0 +1,62 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left", origin_file_pattern="ltx-2-19b-lora-camera-control-dolly-left.safetensors"), +) + +prompt = "Dolly-left shot: A joyful young woman sits at a minimalist desk with a laptop running Diffsynth-Studio, code and generative visuals glowing on screen. She turns slightly toward the camera and says with a smile, 'I enjoy working with Diffsynth-Studio, it's a perfect framework.' As she speaks, the camera smoothly dollies left, revealing a wall of framed open-source project posters, a whiteboard covered in neural network sketches, and a shelf stacked with AI/graphics books beside her." +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_dolly_left_lora.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py new file mode 100644 index 0000000..6f8fd72 --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py @@ -0,0 +1,63 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out", origin_file_pattern="ltx-2-19b-lora-camera-control-dolly-out.safetensors"), +) + +prompt = "Dolly-out shot: A joyful young woman smiles warmly and says: 'I enjoy working with Diffsynth-Studio, it's a perfect framework.' As she speaks, the camera slowly dollies out, revealing a bright, modern creative studio filled with plants, whiteboards full of diagrams, and soft natural light from large windows." + +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path=f'ltx2_dolly_out.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py new file mode 100644 index 0000000..2de3233 --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py @@ -0,0 +1,62 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right", origin_file_pattern="ltx-2-19b-lora-camera-control-dolly-right.safetensors"), +) + +prompt = "Dolly-right shot: A happy girl looks up and says happily, 'I enjoy working with Diffsynth-Studio, it's a perfect framework.' She sits before a sunlit café table, her open laptop displaying the Github interface. The camera glides right to show a barista crafting coffee in the background, shelves of artisan beans, and a chalkboard menu softly blurred in the bokeh." +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_dolly_right_lora.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py new file mode 100644 index 0000000..571fd6b --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py @@ -0,0 +1,72 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down", origin_file_pattern="ltx-2-19b-lora-camera-control-jib-down.safetensors"), +) +prompt = ( + "A girl is very happy, standing on a clean studio floor with soft ambient lighting. " + "She is speaking directly to the camera: “I enjoy working with Diffsynth-Studio, it's a perfect framework.” " + "The shot begins with a medium close-up framing her from the waist up. As the camera performs a smooth jib-down movement—" + "descending vertically downward—it gradually reveals more of the lower portion of the scene. " + "During the descent, the following elements become visible near the bottom of the frame: " + "- The polished concrete floor with subtle reflections of the girl’s shoes, " + "- A small branded mat labeled “Diffsynth-Studio” placed just beneath her feet, " + "- The lower part of a sleek workstation desk with a glowing logo on its front panel, partially hidden at the start but fully revealed as the camera lowers. " + "This downward motion provides a dynamic reveal of contextual details that reinforce the professional and creative environment, " + "while maintaining focus on the girl’s enthusiastic expression throughout." +) +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_camera_jib_down.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py new file mode 100644 index 0000000..18905fe --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py @@ -0,0 +1,66 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up", origin_file_pattern="ltx-2-19b-lora-camera-control-jib-up.safetensors"), +) +prompt = ( + "A girl stands happily at a sleek desk with a glowing 'Diffsynth-Studio' logo, saying: “I enjoy working with Diffsynth-Studio, it's a perfect framework.” " + "The shot starts low—framing her waist, shoes, and a branded floor mat—and smoothly jib-ups upward. " + "As the camera rises, it reveals her smiling face, upper body, and behind her: a bright creative studio with wall art and a large window showing daylight sky. " + "The final frame fully shows the inspiring workspace above the initial view, ensuring spatial continuity." +) +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_camera_jib_up.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py new file mode 100644 index 0000000..ffa9b38 --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py @@ -0,0 +1,61 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="Lightricks/LTX-2-19b-LoRA-Camera-Control-Static", origin_file_pattern="ltx-2-19b-lora-camera-control-static.safetensors"), +) +prompt = "A beautiful sunset over the ocean." +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_camera_static.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py b/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py new file mode 100644 index 0000000..2b87dd3 --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py @@ -0,0 +1,57 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +) + +prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_distilled_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_distilled.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py b/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py new file mode 100644 index 0000000..ade78d0 --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py @@ -0,0 +1,42 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +) +prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" +negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +height, width, num_frames = 512, 768, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_onestage.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py b/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py new file mode 100644 index 0000000..84bbc0c --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py @@ -0,0 +1,58 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) + +prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_twostage.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py b/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py new file mode 100644 index 0000000..7020b40 --- /dev/null +++ b/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py @@ -0,0 +1,70 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from PIL import Image +from modelscope import dataset_snapshot_download + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"] +) +image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height)) +# first frame +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_distilled_pipeline=True, + input_images=[image], + input_images_indexes=[0], + input_images_strength=1.0, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_distilled_i2av_first.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py b/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py new file mode 100644 index 0000000..48ca23b --- /dev/null +++ b/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py @@ -0,0 +1,56 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from PIL import Image +from modelscope import dataset_snapshot_download + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" +negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +height, width, num_frames = 512 * 2, 768 * 2, 121 +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"] +) +image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height)) +# first frame +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=False, + input_images=[image], + input_images_indexes=[0], + input_images_strength=1.0, + num_inference_steps=40, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_onestage_i2av_first.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py b/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py new file mode 100644 index 0000000..5411b8c --- /dev/null +++ b/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py @@ -0,0 +1,72 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from PIL import Image +from modelscope import dataset_snapshot_download + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"] +) +image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height)) +# first frame +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=42, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, + num_inference_steps=40, + input_images=[image], + input_images_indexes=[0], + input_images_strength=1.0, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_twostage_i2av_first.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py new file mode 100644 index 0000000..b15e4cf --- /dev/null +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py @@ -0,0 +1,62 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In", origin_file_pattern="ltx-2-19b-lora-camera-control-dolly-in.safetensors"), +) + +prompt = "Dolly-in shot: A cheerful girl smiles brightly and says, 'I enjoy working with Diffsynth-Studio, it's a perfect framework.' The camera smoothly moves closer to her face, highlighting her enthusiasm and sincerity." +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_dolly_in_lora.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py new file mode 100644 index 0000000..4a7a5aa --- /dev/null +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py @@ -0,0 +1,62 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left", origin_file_pattern="ltx-2-19b-lora-camera-control-dolly-left.safetensors"), +) + +prompt = "Dolly-left shot: A joyful young woman sits at a minimalist desk with a laptop running Diffsynth-Studio, code and generative visuals glowing on screen. She turns slightly toward the camera and says with a smile, 'I enjoy working with Diffsynth-Studio, it's a perfect framework.' As she speaks, the camera smoothly dollies left, revealing a wall of framed open-source project posters, a whiteboard covered in neural network sketches, and a shelf stacked with AI/graphics books beside her." +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_dolly_left_lora.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py new file mode 100644 index 0000000..9ae6884 --- /dev/null +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py @@ -0,0 +1,63 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out", origin_file_pattern="ltx-2-19b-lora-camera-control-dolly-out.safetensors"), +) + +prompt = "Dolly-out shot: A joyful young woman smiles warmly and says: 'I enjoy working with Diffsynth-Studio, it's a perfect framework.' As she speaks, the camera slowly dollies out, revealing a bright, modern creative studio filled with plants, whiteboards full of diagrams, and soft natural light from large windows." + +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path=f'ltx2_dolly_out.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py new file mode 100644 index 0000000..ab9f9ae --- /dev/null +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py @@ -0,0 +1,62 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right", origin_file_pattern="ltx-2-19b-lora-camera-control-dolly-right.safetensors"), +) + +prompt = "Dolly-right shot: A happy girl looks up and says happily, 'I enjoy working with Diffsynth-Studio, it's a perfect framework.' She sits before a sunlit café table, her open laptop displaying the Github interface. The camera glides right to show a barista crafting coffee in the background, shelves of artisan beans, and a chalkboard menu softly blurred in the bokeh." +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_dolly_right_lora.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py new file mode 100644 index 0000000..9fc6e41 --- /dev/null +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py @@ -0,0 +1,72 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down", origin_file_pattern="ltx-2-19b-lora-camera-control-jib-down.safetensors"), +) +prompt = ( + "A girl is very happy, standing on a clean studio floor with soft ambient lighting. " + "She is speaking directly to the camera: “I enjoy working with Diffsynth-Studio, it's a perfect framework.” " + "The shot begins with a medium close-up framing her from the waist up. As the camera performs a smooth jib-down movement—" + "descending vertically downward—it gradually reveals more of the lower portion of the scene. " + "During the descent, the following elements become visible near the bottom of the frame: " + "- The polished concrete floor with subtle reflections of the girl’s shoes, " + "- A small branded mat labeled “Diffsynth-Studio” placed just beneath her feet, " + "- The lower part of a sleek workstation desk with a glowing logo on its front panel, partially hidden at the start but fully revealed as the camera lowers. " + "This downward motion provides a dynamic reveal of contextual details that reinforce the professional and creative environment, " + "while maintaining focus on the girl’s enthusiastic expression throughout." +) +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_camera_jib_down.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py new file mode 100644 index 0000000..628e7c3 --- /dev/null +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py @@ -0,0 +1,66 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up", origin_file_pattern="ltx-2-19b-lora-camera-control-jib-up.safetensors"), +) +prompt = ( + "A girl stands happily at a sleek desk with a glowing 'Diffsynth-Studio' logo, saying: “I enjoy working with Diffsynth-Studio, it's a perfect framework.” " + "The shot starts low—framing her waist, shoes, and a branded floor mat—and smoothly jib-ups upward. " + "As the camera rises, it reveals her smiling face, upper body, and behind her: a bright creative studio with wall art and a large window showing daylight sky. " + "The final frame fully shows the inspiring workspace above the initial view, ensuring spatial continuity." +) +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_camera_jib_up.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py new file mode 100644 index 0000000..b6394bc --- /dev/null +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py @@ -0,0 +1,61 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="Lightricks/LTX-2-19b-LoRA-Camera-Control-Static", origin_file_pattern="ltx-2-19b-lora-camera-control-static.safetensors"), +) +prompt = "A beautiful sunset over the ocean." +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_camera_static.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py new file mode 100644 index 0000000..d8b6a5d --- /dev/null +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py @@ -0,0 +1,58 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_distilled_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_distilled.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py new file mode 100644 index 0000000..894c417 --- /dev/null +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py @@ -0,0 +1,43 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" +negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +height, width, num_frames = 512, 768, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_onestage.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py new file mode 100644 index 0000000..65650d0 --- /dev/null +++ b/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py @@ -0,0 +1,59 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.float8_e5m2, + "offload_device": "cpu", + "onload_dtype": torch.float8_e5m2, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e5m2, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_twostage.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-InpaintCanny.py b/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-InpaintCanny.py new file mode 100644 index 0000000..f8af9e8 --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-InpaintCanny.py @@ -0,0 +1,49 @@ +import torch +from PIL import Image +from modelscope import dataset_snapshot_download +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", origin_file_pattern="model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="canny/*.jpg" +) +prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。" + +controlnet_canny_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328)) + +controlnet_inpaint_image = Image.open("./data/example_image_dataset/canny/image_2.jpg").convert("RGB").resize((1328, 1328)) +# generate a centered square mask +inpaint_mask = Image.new("L", controlnet_inpaint_image.size, 0) +mask_size = 512 +left = (controlnet_inpaint_image.width - mask_size) // 2 +top = (controlnet_inpaint_image.height - mask_size) // 2 +right = left + mask_size +bottom = top + mask_size +inpaint_mask.paste(255, (left, top, right, bottom)) +inpaint_mask = inpaint_mask.resize((1328, 1328)).convert("RGB") + +image = pipe( + prompt, seed=0, + input_image=controlnet_inpaint_image, inpaint_mask=inpaint_mask, + blockwise_controlnet_inputs=[ + ControlNetInput(image=controlnet_inpaint_image, inpaint_mask=inpaint_mask, controlnet_id=0), + ControlNetInput(image=controlnet_canny_image, controlnet_id=1), + ], + num_inference_steps=40, +) +image.save("image.jpg") diff --git a/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py b/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py new file mode 100644 index 0000000..67eca5a --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py @@ -0,0 +1,47 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from modelscope import snapshot_download +from PIL import Image +import torch + +# Load models +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) +lora = ModelConfig( + model_id="DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA", + origin_file_pattern="model.safetensors" +) +pipe.load_lora(pipe.dit, lora) + +# Load images +snapshot_download( + "DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA", + local_dir="./data", + allow_file_pattern="assets/*" +) +edit_image = [ + Image.open("data/assets/image1_original.png"), + Image.open("data/assets/image1_edit_1.png"), + Image.open("data/assets/image2_original.png") +] +prompt = "Edit image 3 based on the transformation from image 1 to image 2." +negative_prompt = "泛黄,AI感,不真实,丑陋,油腻的皮肤,异常的肢体,不协调的肢体" + +# Generate +image_4 = pipe( + prompt=prompt, negative_prompt=negative_prompt, + edit_image=edit_image, + seed=1, + num_inference_steps=50, + height=1280, + width=720, + zero_cond_t=True, +) +image_4.save("image.png") \ No newline at end of file diff --git a/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py b/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py new file mode 100644 index 0000000..098a77c --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py @@ -0,0 +1,53 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, FlowMatchScheduler +from modelscope import dataset_snapshot_download +from PIL import Image +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) + +lora = ModelConfig( + model_id="lightx2v/Qwen-Image-Edit-2511-Lightning", + origin_file_pattern="Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors" +) +pipe.load_lora(pipe.dit, lora, alpha=8/64) +pipe.scheduler = FlowMatchScheduler("Qwen-Image-Lightning") + + +dataset_snapshot_download( + "DiffSynth-Studio/example_image_dataset", + allow_file_pattern="qwen_image_edit/*", + local_dir="data/example_image_dataset", +) + +prompt = "生成这两个人的合影" +edit_image = [ + Image.open("data/example_image_dataset/qwen_image_edit/image1.jpg"), + Image.open("data/example_image_dataset/qwen_image_edit/image2.jpg"), +] +image = pipe( + prompt, + edit_image=edit_image, + seed=1, + num_inference_steps=4, + height=1152, + width=896, + edit_image_auto_resize=True, + zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511 + cfg_scale=1.0, +) +image.save("image.jpg") + +# Qwen-Image-Edit-2511 is a multi-image editing model. +# Please use a list to input `edit_image`, even if the input contains only one image. +# edit_image = [Image.open("image.jpg")] +# Please do not input the image directly. +# edit_image = Image.open("image.jpg") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-InpaintCanny.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-InpaintCanny.py new file mode 100644 index 0000000..38fdcc1 --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-InpaintCanny.py @@ -0,0 +1,59 @@ +import torch +from PIL import Image +from modelscope import dataset_snapshot_download +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", origin_file_pattern="model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="canny/*.jpg" +) +prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。" + +controlnet_canny_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328)) + +controlnet_inpaint_image = Image.open("./data/example_image_dataset/canny/image_2.jpg").convert("RGB").resize((1328, 1328)) +# generate a centered square mask +inpaint_mask = Image.new("L", controlnet_inpaint_image.size, 0) +mask_size = 512 +left = (controlnet_inpaint_image.width - mask_size) // 2 +top = (controlnet_inpaint_image.height - mask_size) // 2 +right = left + mask_size +bottom = top + mask_size +inpaint_mask.paste(255, (left, top, right, bottom)) +inpaint_mask = inpaint_mask.resize((1328, 1328)).convert("RGB") + +image = pipe( + prompt, seed=0, + input_image=controlnet_inpaint_image, inpaint_mask=inpaint_mask, + blockwise_controlnet_inputs=[ + ControlNetInput(image=controlnet_inpaint_image, inpaint_mask=inpaint_mask, controlnet_id=0), + ControlNetInput(image=controlnet_canny_image, controlnet_id=1), + ], + num_inference_steps=40, +) +image.save("image.jpg") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-ICEdit.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-ICEdit.py new file mode 100644 index 0000000..8a90f6e --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-ICEdit.py @@ -0,0 +1,58 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from modelscope import snapshot_download +from PIL import Image +import torch + +# Load models +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) + +lora = ModelConfig( + model_id="DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA", + origin_file_pattern="model.safetensors" +) +pipe.load_lora(pipe.dit, lora) + +# Load images +snapshot_download( + "DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA", + local_dir="./data", + allow_file_pattern="assets/*" +) +edit_image = [ + Image.open("data/assets/image1_original.png"), + Image.open("data/assets/image1_edit_1.png"), + Image.open("data/assets/image2_original.png") +] +prompt = "Edit image 3 based on the transformation from image 1 to image 2." +negative_prompt = "泛黄,AI感,不真实,丑陋,油腻的皮肤,异常的肢体,不协调的肢体" + +# Generate +image_4 = pipe( + prompt=prompt, negative_prompt=negative_prompt, + edit_image=edit_image, + seed=1, + num_inference_steps=50, + height=1280, + width=720, + zero_cond_t=True, +) +image_4.save("image.png") \ No newline at end of file diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py new file mode 100644 index 0000000..cbe43a2 --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py @@ -0,0 +1,63 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, FlowMatchScheduler +from modelscope import dataset_snapshot_download +from PIL import Image +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) + +lora = ModelConfig( + model_id="lightx2v/Qwen-Image-Edit-2511-Lightning", + origin_file_pattern="Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors" +) +pipe.load_lora(pipe.dit, lora, alpha=8/64) +pipe.scheduler = FlowMatchScheduler("Qwen-Image-Lightning") + + +dataset_snapshot_download( + "DiffSynth-Studio/example_image_dataset", + allow_file_pattern="qwen_image_edit/*", + local_dir="data/example_image_dataset", +) + +prompt = "生成这两个人的合影" +edit_image = [ + Image.open("data/example_image_dataset/qwen_image_edit/image1.jpg"), + Image.open("data/example_image_dataset/qwen_image_edit/image2.jpg"), +] +image = pipe( + prompt, + edit_image=edit_image, + seed=1, + num_inference_steps=4, + height=1152, + width=896, + edit_image_auto_resize=True, + zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511 + cfg_scale=1.0, +) +image.save("image.jpg") + +# Qwen-Image-Edit-2511 is a multi-image editing model. +# Please use a list to input `edit_image`, even if the input contains only one image. +# edit_image = [Image.open("image.jpg")] +# Please do not input the image directly. +# edit_image = Image.open("image.jpg") diff --git a/examples/qwen_image/model_training/full/accelerate_config_zero3.yaml b/examples/qwen_image/model_training/full/accelerate_config_zero3.yaml new file mode 100644 index 0000000..e6a8d27 --- /dev/null +++ b/examples/qwen_image/model_training/full/accelerate_config_zero3.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/qwen_image/model_training/special/npu_training/Qwen-Image-Edit-2509-NPU.sh b/examples/qwen_image/model_training/special/npu_training/Qwen-Image-Edit-2509-NPU.sh new file mode 100644 index 0000000..02de9e9 --- /dev/null +++ b/examples/qwen_image/model_training/special/npu_training/Qwen-Image-Edit-2509-NPU.sh @@ -0,0 +1,20 @@ +# This script was tested using zero3 and on 8*910B(NPU) +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export CPU_AFFINITY_CONF=1 + +accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero3.yaml examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \ + --data_file_keys "image,edit_image" \ + --extra_inputs "edit_image" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-Edit-2509_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --find_unused_parameters \ + --initialize_model_on_cpu diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 8f38d04..ecb4239 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -101,6 +101,7 @@ def qwen_image_parser(): parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.") parser.add_argument("--zero_cond_t", default=False, action="store_true", help="A special parameter introduced by Qwen-Image-Edit-2511. Please enable it for this model.") + parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.") return parser @@ -151,7 +152,7 @@ if __name__ == "__main__": fp8_models=args.fp8_models, offload_models=args.offload_models, task=args.task, - device=accelerator.device, + device="cpu" if args.initialize_model_on_cpu else accelerator.device, zero_cond_t=args.zero_cond_t, ) model_logger = ModelLogger( diff --git a/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh b/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh index 10c4a5a..9eec871 100644 --- a/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh +++ b/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh @@ -7,6 +7,7 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate --num_frames 81 \ --dataset_repeat 100 \ --model_id_with_origin_paths "Wan-AI/Wan2.2-S2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/model.safetensors,Wan-AI/Wan2.2-S2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-S2V-14B:Wan2.1_VAE.pth" \ + --audio_processor_path "Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/" \ --learning_rate 1e-5 \ --num_epochs 1 \ --trainable_models "dit" \ diff --git a/examples/wanvideo/model_training/full/accelerate_config_zero3.yaml b/examples/wanvideo/model_training/full/accelerate_config_zero3.yaml new file mode 100644 index 0000000..e6a8d27 --- /dev/null +++ b/examples/wanvideo/model_training/full/accelerate_config_zero3.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh index 510796b..ec5bb87 100644 --- a/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh +++ b/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh @@ -7,6 +7,7 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate --num_frames 81 \ --dataset_repeat 100 \ --model_id_with_origin_paths "Wan-AI/Wan2.2-S2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/model.safetensors,Wan-AI/Wan2.2-S2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-S2V-14B:Wan2.1_VAE.pth" \ + --audio_processor_path "Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/" \ --learning_rate 1e-4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.dit." \ diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index 4973438..d4074a6 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -33,7 +33,7 @@ class WanTrainingModule(DiffusionTrainingModule): # Load models model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) tokenizer_config = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/") if tokenizer_path is None else ModelConfig(tokenizer_path) - audio_processor_config = ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/") if audio_processor_path is None else ModelConfig(audio_processor_path) + audio_processor_config = self.parse_path_or_model_id(audio_processor_path) self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, audio_processor_config=audio_processor_config) self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) diff --git a/examples/z_image/model_inference/Z-Image-i2L.py b/examples/z_image/model_inference/Z-Image-i2L.py new file mode 100644 index 0000000..82b7ace --- /dev/null +++ b/examples/z_image/model_inference/Z-Image-i2L.py @@ -0,0 +1,61 @@ +from diffsynth.pipelines.z_image import ( + ZImagePipeline, ModelConfig, + ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode +) +from modelscope import snapshot_download +from safetensors.torch import save_file +import torch +from PIL import Image + +# Use `vram_config` to enable LoRA hot-loading +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cuda", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} + +# Load models +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Z-Image-i2L", origin_file_pattern="model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) + +# Load images +snapshot_download( + model_id="DiffSynth-Studio/Z-Image-i2L", + allow_file_pattern="assets/style/*", + local_dir="data/Z-Image-i2L_style_input" +) +images = [Image.open(f"data/Z-Image-i2L_style_input/assets/style/1/{i}.jpg") for i in range(4)] + +# Image to LoRA +with torch.no_grad(): + embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images) + lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"] +save_file(lora, "lora.safetensors") + +# Generate images +prompt = "a cat" +negative_prompt = "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符" +image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=0, cfg_scale=4, num_inference_steps=50, + positive_only_lora=lora, + sigma_shift=8 +) +image.save("image.jpg") diff --git a/examples/z_image/model_inference/Z-Image.py b/examples/z_image/model_inference/Z-Image.py new file mode 100644 index 0000000..6dca342 --- /dev/null +++ b/examples/z_image/model_inference/Z-Image.py @@ -0,0 +1,17 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights." +image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=4) +image.save("image_Z-Image.jpg") diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base-i2L.py b/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base-i2L.py index 7378ada..6da6960 100644 --- a/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base-i2L.py +++ b/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base-i2L.py @@ -33,6 +33,7 @@ pipe = ZImagePipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L", origin_file_pattern="model.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) # Load images diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base.py b/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base.py index 0af1e53..b9fa293 100644 --- a/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base.py +++ b/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base.py @@ -22,6 +22,7 @@ pipe = ZImagePipeline.from_pretrained( ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights." image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4) diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py index cd4276f..61ea96f 100644 --- a/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py +++ b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py @@ -24,6 +24,7 @@ pipe = ZImagePipeline.from_pretrained( ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) dataset_snapshot_download( diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py index f325508..54811c0 100644 --- a/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py +++ b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py @@ -24,6 +24,7 @@ pipe = ZImagePipeline.from_pretrained( ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) # Control diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py index 6fe170f..0c81bd6 100644 --- a/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py +++ b/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py @@ -24,6 +24,7 @@ pipe = ZImagePipeline.from_pretrained( ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), ], tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) # Control diff --git a/examples/z_image/model_inference_low_vram/Z-Image-i2L.py b/examples/z_image/model_inference_low_vram/Z-Image-i2L.py new file mode 100644 index 0000000..a98537c --- /dev/null +++ b/examples/z_image/model_inference_low_vram/Z-Image-i2L.py @@ -0,0 +1,62 @@ +from diffsynth.pipelines.z_image import ( + ZImagePipeline, ModelConfig, + ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode +) +from modelscope import snapshot_download +from safetensors.torch import save_file +import torch +from PIL import Image + +# Use `vram_config` to enable LoRA hot-loading +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} + +# Load models +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Z-Image-i2L", origin_file_pattern="model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + vram_limit=0, +) + +# Load images +snapshot_download( + model_id="DiffSynth-Studio/Z-Image-i2L", + allow_file_pattern="assets/style/*", + local_dir="data/Z-Image-i2L_style_input" +) +images = [Image.open(f"data/Z-Image-i2L_style_input/assets/style/1/{i}.jpg") for i in range(4)] + +# Image to LoRA +with torch.no_grad(): + embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images) + lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"] +save_file(lora, "lora.safetensors") + +# Generate images +prompt = "a cat" +negative_prompt = "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符" +image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=0, cfg_scale=4, num_inference_steps=50, + positive_only_lora=lora, + sigma_shift=8 +) +image.save("image.jpg") diff --git a/examples/z_image/model_inference_low_vram/Z-Image.py b/examples/z_image/model_inference_low_vram/Z-Image.py new file mode 100644 index 0000000..5eee761 --- /dev/null +++ b/examples/z_image/model_inference_low_vram/Z-Image.py @@ -0,0 +1,27 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights." +image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=4) +image.save("image_Z-Image.jpg") diff --git a/examples/z_image/model_training/full/Z-Image.sh b/examples/z_image/model_training/full/Z-Image.sh new file mode 100644 index 0000000..2136324 --- /dev/null +++ b/examples/z_image/model_training/full/Z-Image.sh @@ -0,0 +1,14 @@ +# This example is tested on 8*A100 +accelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "Tongyi-MAI/Z-Image:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Z-Image_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 diff --git a/examples/z_image/model_training/full/accelerate_config_zero3.yaml b/examples/z_image/model_training/full/accelerate_config_zero3.yaml new file mode 100644 index 0000000..e6a8d27 --- /dev/null +++ b/examples/z_image/model_training/full/accelerate_config_zero3.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/z_image/model_training/lora/Z-Image.sh b/examples/z_image/model_training/lora/Z-Image.sh new file mode 100644 index 0000000..b660eef --- /dev/null +++ b/examples/z_image/model_training/lora/Z-Image.sh @@ -0,0 +1,15 @@ +accelerate launch examples/z_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Tongyi-MAI/Z-Image:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Z-Image_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 diff --git a/examples/z_image/model_training/validate_full/Z-Image.py b/examples/z_image/model_training/validate_full/Z-Image.py new file mode 100644 index 0000000..b2a1d8e --- /dev/null +++ b/examples/z_image/model_training/validate_full/Z-Image.py @@ -0,0 +1,20 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +from diffsynth.core import load_state_dict +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("./models/train/Z-Image_full/epoch-1.safetensors", torch_dtype=torch.bfloat16) +pipe.dit.load_state_dict(state_dict) +prompt = "a dog" +image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=4) +image.save("image.jpg") diff --git a/examples/z_image/model_training/validate_lora/Z-Image.py b/examples/z_image/model_training/validate_lora/Z-Image.py new file mode 100644 index 0000000..d12356f --- /dev/null +++ b/examples/z_image/model_training/validate_lora/Z-Image.py @@ -0,0 +1,18 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +import torch + + +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "./models/train/Z-Image_lora/epoch-4.safetensors") +prompt = "a dog" +image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=4) +image.save("image.jpg") diff --git a/pyproject.toml b/pyproject.toml index de82279..9a5075b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "diffsynth" -version = "2.0.3" +version = "2.0.4" description = "Enjoy the magic of Diffusion models!" authors = [{name = "ModelScope Team"}] license = {text = "Apache-2.0"}