mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
wan-refactor
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
# Wan-Video
|
||||
# Wan 2.1
|
||||
|
||||
Wan-Video is a collection of video synthesis models open-sourced by Alibaba.
|
||||
[切换到中文](./README_zh.md)
|
||||
|
||||
Wan 2.1 is a collection of video synthesis models open-sourced by Alibaba.
|
||||
|
||||
Before using this model, please install DiffSynth-Studio from **source code**.
|
||||
|
||||
@@ -10,267 +12,378 @@ cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## Model Zoo
|
||||
| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./model_inference/Wan2.1-T2V-14B.py)|[code](./model_training/full/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./model_training/lora/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-480P.py)|[code](./model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-720P.py)|[code](./model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|
||||
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](./model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](./model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](./model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](./model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](./model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](./model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-14B-InP.py)|[code](./model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](./model_inference/Wan2.1-Fun-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|
||||
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||
|
||||
|Developer|Name|Link|Scripts|
|
||||
|-|-|-|-|
|
||||
|Wan Team|1.3B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|[wan_1.3b_text_to_video.py](./wan_1.3b_text_to_video.py)|
|
||||
|Wan Team|14B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|[wan_14b_text_to_video.py](./wan_14b_text_to_video.py)|
|
||||
|Wan Team|14B image-to-video 480P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|
||||
|Wan Team|14B image-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|
||||
|Wan Team|14B first-last-frame-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|[wan_14B_flf2v.py](./wan_14B_flf2v.py)|
|
||||
|DiffSynth-Studio Team|1.3B aesthetics LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1).|
|
||||
|DiffSynth-Studio Team|1.3B Highres-fix LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1).|
|
||||
|DiffSynth-Studio Team|1.3B ExVideo LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1).|
|
||||
|DiffSynth-Studio Team|1.3B Speed Control adapter|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|[wan_1.3b_motion_controller.py](./wan_1.3b_motion_controller.py)|
|
||||
|PAI Team|1.3B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|
||||
|PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|
||||
|PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|
||||
|PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|
||||
|IIC Team|1.3B VACE|[Link](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|[wan_1.3b_vace.py](./wan_1.3b_vace.py)|
|
||||
|
||||
Base model features
|
||||
## Model Inference
|
||||
|
||||
||Text-to-video|Image-to-video|End frame|Control|Reference image|
|
||||
|-|-|-|-|-|-|
|
||||
|1.3B text-to-video|✅|||||
|
||||
|14B text-to-video|✅|||||
|
||||
|14B image-to-video 480P||✅||||
|
||||
|14B image-to-video 720P||✅||||
|
||||
|14B first-last-frame-to-video 720P||✅|✅|||
|
||||
|1.3B InP||✅|✅|||
|
||||
|14B InP||✅|✅|||
|
||||
|1.3B Control||||✅||
|
||||
|14B Control||||✅||
|
||||
|1.3B VACE||||✅|✅|
|
||||
The following sections will help you understand our functionalities and write inference code.
|
||||
|
||||
Adapter model compatibility
|
||||
<details>
|
||||
|
||||
||1.3B text-to-video|1.3B InP|1.3B VACE|
|
||||
|-|-|-|-|
|
||||
|1.3B aesthetics LoRA|✅||✅|
|
||||
|1.3B Highres-fix LoRA|✅||✅|
|
||||
|1.3B ExVideo LoRA|✅||✅|
|
||||
|1.3B Speed Control adapter|✅|✅|✅|
|
||||
<summary>Loading the Model</summary>
|
||||
|
||||
## VRAM Usage
|
||||
The model is loaded using `from_pretrained`:
|
||||
|
||||
* Fine-grained offload: We recommend that users adjust the `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
|
||||
```python
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
* FP8 Quantization: You only need to adjust the `torch_dtype` in the `ModelManager` (not the pipeline!).
|
||||
Here, `torch_dtype` and `device` specify the computation precision and device respectively. The `model_configs` can be used to configure model paths in various ways:
|
||||
|
||||
We present a detailed table here. The model (14B text-to-video) is tested on a single A100.
|
||||
* Downloading the model from [ModelScope](https://modelscope.cn/) and loading it. In this case, both `model_id` and `origin_file_pattern` need to be specified, for example:
|
||||
|
||||
|`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting|
|
||||
|-|-|-|-|-|
|
||||
|torch.bfloat16|None (unlimited)|18.5s/it|48G||
|
||||
|torch.bfloat16|7*10**9 (7B)|20.8s/it|24G||
|
||||
|torch.bfloat16|0|23.4s/it|10G||
|
||||
|torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes|
|
||||
|torch.float8_e4m3fn|0|24.0s/it|10G||
|
||||
```python
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
```
|
||||
|
||||
**We found that 14B image-to-video model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
|
||||
* Loading the model from a local file path. In this case, the `path` parameter needs to be specified, for example:
|
||||
|
||||
## Efficient Attention Implementation
|
||||
```python
|
||||
ModelConfig(path="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors")
|
||||
```
|
||||
|
||||
DiffSynth-Studio supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority. However, we recommend to use the default torch SDPA.
|
||||
For models that are loaded from multiple files, simply use a list, for example:
|
||||
|
||||
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
|
||||
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
|
||||
* [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
|
||||
```python
|
||||
ModelConfig(path=[
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
|
||||
])
|
||||
```
|
||||
|
||||
## Acceleration
|
||||
The `from_pretrained` function also provides additional parameters to control the behavior during model loading:
|
||||
|
||||
We support multiple acceleration solutions:
|
||||
* [TeaCache](https://github.com/ali-vilab/TeaCache): See [wan_1.3b_text_to_video_accelerate.py](./wan_1.3b_text_to_video_accelerate.py).
|
||||
* `tokenizer_config`: Path to the tokenizer of the Wan model. Default value is `ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*")`.
|
||||
* `local_model_path`: Path where downloaded models are saved. Default value is `"./models"`.
|
||||
* `skip_download`: Whether to skip downloading models. Default value is `False`. When your network cannot access [ModelScope](https://modelscope.cn/), manually download the necessary files and set this to `True`.
|
||||
* `redirect_common_files`: Whether to redirect duplicate model files. Default value is `True`. Since the Wan series models include multiple base models, some modules like text encoder are shared across these models. To avoid redundant downloads, we redirect the model paths.
|
||||
* `use_usp`: Whether to enable Unified Sequence Parallel. Default value is `False`. Used for multi-GPU parallel inference.
|
||||
|
||||
* [Unified Sequence Parallel](https://github.com/xdit-project/xDiT): See [wan_14b_text_to_video_usp.py](./wan_14b_text_to_video_usp.py)
|
||||
</details>
|
||||
|
||||
```bash
|
||||
<details>
|
||||
|
||||
<summary>VRAM Management</summary>
|
||||
|
||||
DiffSynth-Studio provides fine-grained VRAM management for the Wan model, allowing it to run on devices with limited VRAM. You can enable offloading functionality via the following code, which moves parts of the model to system memory on devices with limited VRAM:
|
||||
|
||||
```python
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
FP8 quantization is also supported:
|
||||
|
||||
```python
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
Both FP8 quantization and offloading can be enabled simultaneously:
|
||||
|
||||
```python
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
FP8 quantization significantly reduces VRAM usage but does not accelerate computations. Some models may experience issues such as blurry, torn, or distorted outputs due to insufficient precision when using FP8 quantization. Use FP8 quantization with caution.
|
||||
|
||||
The `enable_vram_management` function provides the following parameters to control VRAM usage:
|
||||
|
||||
* `vram_limit`: VRAM usage limit (in GB). By default, it uses all available VRAM on the device. Note that this is not an absolute limit; if the specified VRAM is insufficient but more VRAM is actually available, inference will proceed using the minimum required VRAM.
|
||||
* `vram_buffer`: Size of the VRAM buffer (in GB). Default is 0.5GB. Since certain large neural network layers may consume more VRAM unpredictably during their execution phase, a VRAM buffer is necessary. Ideally, this should match the maximum VRAM consumed by any single layer in the model.
|
||||
* `num_persistent_param_in_dit`: Number of persistent parameters in DiT models. By default, there is no limit. We plan to remove this parameter in the future, so please avoid relying on it.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Inference Acceleration</summary>
|
||||
|
||||
Wan supports multiple acceleration techniques, including:
|
||||
|
||||
* **Efficient attention implementations**: If any of these attention implementations are installed in your Python environment, they will be automatically enabled in the following priority:
|
||||
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
|
||||
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
|
||||
* [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default setting; we recommend installing `torch>=2.5.0`)
|
||||
* **Unified Sequence Parallel**: Sequence parallelism based on [xDiT](https://github.com/xdit-project/xDiT). Please refer to [this example](./acceleration/unified_sequence_parallel.py), and run it using the command:
|
||||
|
||||
```shell
|
||||
pip install xfuser>=0.4.3
|
||||
torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
|
||||
torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
|
||||
```
|
||||
|
||||
* Tensor Parallel: See [wan_14b_text_to_video_tensor_parallel.py](./wan_14b_text_to_video_tensor_parallel.py).
|
||||
* **TeaCache**: Acceleration technique [TeaCache](https://github.com/ali-vilab/TeaCache). Please refer to [this example](./acceleration/teacache.py).
|
||||
|
||||
## Gallery
|
||||
</details>
|
||||
|
||||
1.3B text-to-video.
|
||||
|
||||
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
|
||||
<details>
|
||||
|
||||
Put sunglasses on the dog.
|
||||
<summary>Input Parameters</summary>
|
||||
|
||||
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
|
||||
The pipeline accepts the following input parameters during inference:
|
||||
|
||||
14B text-to-video.
|
||||
* `prompt`: Prompt describing the content to appear in the video.
|
||||
* `negative_prompt`: Negative prompt describing content that should not appear in the video. Default is `""`.
|
||||
* `input_image`: Input image, applicable for image-to-video models such as [`Wan-AI/Wan2.1-I2V-14B-480P`](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) and [`PAI/Wan2.1-Fun-1.3B-InP`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP), as well as first-and-last-frame models like [`Wan-AI/Wan2.1-FLF2V-14B-720P`](Wan-AI/Wan2.1-FLF2V-14B-720P).
|
||||
* `end_image`: End frame, applicable for first-and-last-frame models such as [`Wan-AI/Wan2.1-FLF2V-14B-720P`](Wan-AI/Wan2.1-FLF2V-14B-720P).
|
||||
* `input_video`: Input video used for video-to-video generation. Applicable to any Wan series model and must be used together with `denoising_strength`.
|
||||
* `denoising_strength`: Denoising strength in range [0, 1]. A smaller value results in a video closer to `input_video`.
|
||||
* `control_video`: Control video, applicable to Wan models with control capabilities such as [`PAI/Wan2.1-Fun-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control).
|
||||
* `reference_image`: Reference image, applicable to Wan models supporting reference images such as [`PAI/Wan2.1-Fun-V1.1-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control).
|
||||
* `camera_control_direction`: Camera control direction, optional values are "Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown". Applicable to Camera-Control models, such as [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://www.modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera).
|
||||
* `camera_control_speed`: Camera control speed. Applicable to Camera-Control models, such as [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://www.modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera).
|
||||
* `camera_control_origin`: Origin coordinate of the camera control sequence. Please refer to the [original paper](https://arxiv.org/pdf/2404.02101) for proper configuration. Applicable to Camera-Control models, such as [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://www.modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera).
|
||||
* `vace_video`: Input video for VACE models, applicable to the VACE series such as [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview).
|
||||
* `vace_video_mask`: Mask video for VACE models, applicable to the VACE series such as [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview).
|
||||
* `vace_reference_image`: Reference image for VACE models, applicable to the VACE series such as [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview).
|
||||
* `vace_scale`: Influence of the VACE model on the base model, default is 1. Higher values increase control strength but may lead to visual artifacts or breakdowns.
|
||||
* `seed`: Random seed. Default is `None`, meaning fully random.
|
||||
* `rand_device`: Device used to generate random Gaussian noise matrix. Default is `"cpu"`. When set to `"cuda"`, different GPUs may produce different generation results.
|
||||
* `height`: Frame height, default is 480. Must be a multiple of 16; if not, it will be rounded up.
|
||||
* `width`: Frame width, default is 832. Must be a multiple of 16; if not, it will be rounded up.
|
||||
* `num_frames`: Number of frames, default is 81. Must be a multiple of 4 plus 1; if not, it will be rounded up, minimum is 1.
|
||||
* `cfg_scale`: Classifier-free guidance scale, default is 5. Higher values increase adherence to the prompt but may cause visual artifacts.
|
||||
* `cfg_merge`: Whether to merge both sides of classifier-free guidance for unified inference. Default is `False`. This parameter currently only works for basic text-to-video and image-to-video models.
|
||||
* `num_inference_steps`: Number of inference steps, default is 50.
|
||||
* `sigma_shift`: Parameter from Rectified Flow theory, default is 5. Higher values make the model stay longer at the initial denoising stage. Increasing this may improve video quality but may also cause inconsistency between generated videos and training data due to deviation from training behavior.
|
||||
* `motion_bucket_id`: Motion intensity, range [0, 100], applicable to motion control modules such as [`DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1`](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1). Larger values indicate more intense motion.
|
||||
* `tiled`: Whether to enable tiled VAE inference, default is `False`. Setting to `True` significantly reduces VRAM usage during VAE encoding/decoding but introduces small errors and slightly increases inference time.
|
||||
* `tile_size`: Tile size during VAE encoding/decoding, default is (30, 52), only effective when `tiled=True`.
|
||||
* `tile_stride`: Stride of tiles during VAE encoding/decoding, default is (15, 26), only effective when `tiled=True`. Must be less than or equal to `tile_size`.
|
||||
* `sliding_window_size`: Sliding window size for DiT part. Experimental feature, effects are unstable.
|
||||
* `sliding_window_stride`: Sliding window stride for DiT part. Experimental feature, effects are unstable.
|
||||
* `tea_cache_l1_thresh`: Threshold for TeaCache. Larger values result in faster speed but lower quality. Note that after enabling TeaCache, the inference speed is not uniform, so the remaining time shown on the progress bar becomes inaccurate.
|
||||
* `tea_cache_model_id`: TeaCache parameter template, options include `"Wan2.1-T2V-1.3B"`, `"Wan2.1-T2V-14B"`, `"Wan2.1-I2V-14B-480P"`, `"Wan2.1-I2V-14B-720P"`.
|
||||
* `progress_bar_cmd`: Progress bar implementation, default is `tqdm.tqdm`. You can set it to `lambda x:x` to disable the progress bar.
|
||||
|
||||
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
|
||||
</details>
|
||||
|
||||
14B image-to-video.
|
||||
## Model Training
|
||||
|
||||
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
|
||||
Wan series models are trained using a unified script located at [`./model_training/train.py`](./model_training/train.py).
|
||||
|
||||
14B first-last-frame-to-video
|
||||
<details>
|
||||
|
||||
|First frame|Last frame|Video|
|
||||
|-|-|-|
|
||||
|||https://github.com/user-attachments/assets/2a6a2681-622c-4512-b852-5f22e73830b1|
|
||||
<summary>Script Parameters</summary>
|
||||
|
||||
## Train
|
||||
The script includes the following parameters:
|
||||
|
||||
We support Wan-Video LoRA training and full training. Here is a tutorial. This is an experimental feature. Below is a video sample generated from the character Keqing LoRA:
|
||||
* Dataset
|
||||
* `--dataset_base_path`: Base path of the dataset.
|
||||
* `--dataset_metadata_path`: Path to the metadata file of the dataset.
|
||||
* `--height`: Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
|
||||
* `--width`: Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
|
||||
* `--num_frames`: Number of frames per video. Frames are sampled from the video prefix.
|
||||
* `--data_file_keys`: Data file keys in the metadata. Comma-separated.
|
||||
* `--dataset_repeat`: Number of times to repeat the dataset per epoch.
|
||||
* Models
|
||||
* `--model_paths`: Paths to load models. In JSON format.
|
||||
* `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.
|
||||
* Training
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--output_path`: Output save path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in ckpt.
|
||||
* Trainable Modules
|
||||
* `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
|
||||
* `--lora_base_model`: Which model LoRA is added to.
|
||||
* `--lora_target_modules`: Which layers LoRA is added to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* Extra Inputs
|
||||
* `--extra_inputs`: Additional model inputs, comma-separated.
|
||||
* VRAM Management
|
||||
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
|
||||
|
||||
https://github.com/user-attachments/assets/9bd8e30b-97e8-44f9-bb6f-da004ba376a9
|
||||
Additionally, the training framework is built upon [`accelerate`](https://huggingface.co/docs/accelerate/index). Before starting training, run `accelerate config` to configure GPU-related parameters. For certain training scripts (e.g., full fine-tuning of 14B models), we provide recommended `accelerate` configuration files, which can be found in the corresponding training scripts.
|
||||
|
||||
Step 1: Install additional packages
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 1: Prepare the Dataset</summary>
|
||||
|
||||
The dataset consists of a series of files. We recommend organizing your dataset as follows:
|
||||
|
||||
```
|
||||
pip install peft lightning pandas
|
||||
```
|
||||
|
||||
Step 2: Prepare your dataset
|
||||
|
||||
You need to manage the training videos as follows:
|
||||
|
||||
```
|
||||
data/example_dataset/
|
||||
data/example_video_dataset/
|
||||
├── metadata.csv
|
||||
└── train
|
||||
├── video_00001.mp4
|
||||
└── image_00002.jpg
|
||||
├── video1.mp4
|
||||
└── video2.mp4
|
||||
```
|
||||
|
||||
`metadata.csv`:
|
||||
Here, `video1.mp4` and `video2.mp4` are training video files, and `metadata.csv` is the metadata list, for example:
|
||||
|
||||
```
|
||||
file_name,text
|
||||
video_00001.mp4,"video description"
|
||||
image_00002.jpg,"video description"
|
||||
video,prompt
|
||||
video1.mp4,"from sunset to night, a small town, light, house, river"
|
||||
video2.mp4,"a dog is running"
|
||||
```
|
||||
|
||||
We support both images and videos. An image is treated as a single frame of video.
|
||||
|
||||
Step 3: Data process
|
||||
We have prepared a sample video dataset to help you test. You can download it using the following command:
|
||||
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
|
||||
--task data_process \
|
||||
--dataset_path data/example_dataset \
|
||||
--output_path ./models \
|
||||
--text_encoder_path "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth" \
|
||||
--vae_path "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth" \
|
||||
--tiled \
|
||||
--num_frames 81 \
|
||||
--height 480 \
|
||||
--width 832
|
||||
modelscope download --dataset DiffSynth-Studio/example_video_dataset README.md --local_dir ./data/example_video_dataset
|
||||
```
|
||||
|
||||
After that, some cached files will be stored in the dataset folder.
|
||||
The dataset supports mixed training of videos and images. Supported video formats include `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"`, and supported image formats include `"jpg", "jpeg", "png", "webp"`.
|
||||
|
||||
The resolution of videos can be controlled via script parameters `--height`, `--width`, and `--num_frames`. For each video, the first `num_frames` frames will be used for training; therefore, an error will occur if the video length is less than `num_frames`. Image files will be treated as single-frame videos. When both `--height` and `--width` are left empty, dynamic resolution will be enabled, meaning training will use the actual resolution of each video or image in the dataset.
|
||||
|
||||
**We strongly recommend using fixed-resolution training and avoiding mixing images and videos in the same dataset due to load balancing issues in multi-GPU training.**
|
||||
|
||||
When the model requires additional inputs, such as the `control_video` needed by control-capable models like [`PAI/Wan2.1-Fun-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control), please add corresponding columns in the metadata file, for example:
|
||||
|
||||
```
|
||||
data/example_dataset/
|
||||
├── metadata.csv
|
||||
└── train
|
||||
├── video_00001.mp4
|
||||
├── video_00001.mp4.tensors.pth
|
||||
├── video_00002.mp4
|
||||
└── video_00002.mp4.tensors.pth
|
||||
video,prompt,control_video
|
||||
video1.mp4,"from sunset to night, a small town, light, house, river",video1_softedge.mp4
|
||||
```
|
||||
|
||||
Step 4: Train
|
||||
If additional inputs contain video or image files, their column names need to be specified in the `--data_file_keys` parameter. The default value of this parameter is `"image,video"`, meaning it parses columns named `image` and `video`. You can extend this list based on the additional input requirements, for example: `--data_file_keys "image,video,control_video"`, and also enable `--input_contains_control_video`.
|
||||
|
||||
LoRA training:
|
||||
</details>
|
||||
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
|
||||
--task train \
|
||||
--train_architecture lora \
|
||||
--dataset_path data/example_dataset \
|
||||
--output_path ./models \
|
||||
--dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" \
|
||||
--steps_per_epoch 500 \
|
||||
--max_epochs 10 \
|
||||
--learning_rate 1e-4 \
|
||||
--lora_rank 16 \
|
||||
--lora_alpha 16 \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--accumulate_grad_batches 1 \
|
||||
--use_gradient_checkpointing
|
||||
```
|
||||
|
||||
Full training:
|
||||
<details>
|
||||
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
|
||||
--task train \
|
||||
--train_architecture full \
|
||||
--dataset_path data/example_dataset \
|
||||
--output_path ./models \
|
||||
--dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" \
|
||||
--steps_per_epoch 500 \
|
||||
--max_epochs 10 \
|
||||
--learning_rate 1e-4 \
|
||||
--accumulate_grad_batches 1 \
|
||||
--use_gradient_checkpointing
|
||||
```
|
||||
<summary>Step 2: Load the Model</summary>
|
||||
|
||||
If you wish to train the 14B model, please separate the safetensor files with a comma. For example: `models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors`.
|
||||
|
||||
If you wish to train the image-to-video model, please add an extra parameter `--image_encoder_path "models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"`.
|
||||
|
||||
For LoRA training, the Wan-1.3B-T2V model requires 16G of VRAM for processing 81 frames at 480P, while the Wan-14B-T2V model requires 60G of VRAM for the same configuration. To further reduce VRAM requirements by 20%-30%, you can include the parameter `--use_gradient_checkpointing_offload`.
|
||||
|
||||
Step 5: Test
|
||||
|
||||
Test LoRA:
|
||||
Similar to the model loading logic during inference, you can configure the model to be loaded directly via its model ID. For instance, during inference we load the model using:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
||||
|
||||
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
||||
model_manager.load_models([
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
|
||||
])
|
||||
model_manager.load_lora("models/lightning_logs/version_1/checkpoints/epoch=0-step=500.ckpt", lora_alpha=1.0)
|
||||
pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||
|
||||
video = pipe(
|
||||
prompt="...",
|
||||
negative_prompt="...",
|
||||
num_inference_steps=50,
|
||||
seed=0, tiled=True
|
||||
)
|
||||
save_video(video, "video.mp4", fps=30, quality=5)
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
]
|
||||
```
|
||||
|
||||
Test fine-tuned base model:
|
||||
During training, simply use the following parameter to load the corresponding model:
|
||||
|
||||
```shell
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth"
|
||||
```
|
||||
|
||||
If you want to load the model from local files, for example during inference:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
||||
|
||||
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
||||
model_manager.load_models([
|
||||
"models/lightning_logs/version_1/checkpoints/epoch=0-step=500.ckpt",
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
|
||||
])
|
||||
pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||
|
||||
video = pipe(
|
||||
prompt="...",
|
||||
negative_prompt="...",
|
||||
num_inference_steps=50,
|
||||
seed=0, tiled=True
|
||||
)
|
||||
save_video(video, "video.mp4", fps=30, quality=5)
|
||||
model_configs=[
|
||||
ModelConfig(path=[
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
|
||||
]),
|
||||
ModelConfig(path="models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(path="models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth"),
|
||||
]
|
||||
```
|
||||
|
||||
Then during training, set the parameter as:
|
||||
|
||||
```shell
|
||||
--model_paths '[
|
||||
[
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors"
|
||||
],
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth"
|
||||
]' \
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 3: Configure Trainable Modules</summary>
|
||||
|
||||
The training framework supports full fine-tuning of base models or LoRA-based training. Here are some examples:
|
||||
|
||||
* Full fine-tuning of the DiT module: `--trainable_models dit`
|
||||
* Training a LoRA model for the DiT module: `--lora_base_model dit --lora_target_modules "q,k,v,o,ffn.0,ffn.2" --lora_rank 32`
|
||||
* Training both a LoRA model for DiT and the Motion Controller (yes, you can train such advanced structures): `--trainable_models motion_controller --lora_base_model dit --lora_target_modules "q,k,v,o,ffn.0,ffn.2" --lora_rank 32`
|
||||
|
||||
Additionally, since multiple modules (text encoder, dit, vae) are loaded in the training script, you need to remove prefixes when saving model files. For example, when fully fine-tuning the DiT module or training a LoRA version of DiT, please set `--remove_prefix_in_ckpt pipe.dit.`
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 4: Launch the Training Process</summary>
|
||||
|
||||
We have prepared training commands for each model. Please refer to the table at the beginning of this document.
|
||||
|
||||
Note that full fine-tuning of the 14B model requires 8 GPUs, each with at least 80GB VRAM. During full fine-tuning of these 14B models, you must install `deepspeed` (`pip install deepspeed`). We have provided recommended [configuration files](./model_training/full/accelerate_config_14B.yaml), which will be loaded automatically in the corresponding training scripts. These scripts have been tested on 8*A100.
|
||||
|
||||
The default video resolution in the training script is `480*832*81`. Increasing the resolution may cause out-of-memory errors. To reduce VRAM usage, add the parameter `--use_gradient_checkpointing_offload`.
|
||||
|
||||
</details>
|
||||
|
||||
Reference in New Issue
Block a user