support wan i2v training

This commit is contained in:
Artiprocher
2025-03-13 15:14:10 +08:00
parent 490d420d82
commit a25bd74d8b
3 changed files with 155 additions and 564 deletions

View File

@@ -10,6 +10,13 @@ cd DiffSynth-Studio
pip install -e .
```
Wan-Video supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority.
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
* [Sage Attention](https://github.com/thu-ml/SageAttention)
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
## Inference
### Wan-Video-1.3B-T2V
@@ -44,13 +51,17 @@ https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
**In the sample code, we use the same settings as the T2V 14B model, with FP8 quantization enabled by default. However, we found that this model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
![Image](https://github.com/user-attachments/assets/adf8047f-7943-4aaa-a555-2b32dc415f39)
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
## Train
We support Wan-Video LoRA training and full training. Here is a tutorial.
We support Wan-Video LoRA training and full training. Here is a tutorial. This is an experimental feature. Below is a video sample generated from the character Keqing LoRA:
https://github.com/user-attachments/assets/9bd8e30b-97e8-44f9-bb6f-da004ba376a9
Step 1: Install additional packages
@@ -67,7 +78,7 @@ data/example_dataset/
├── metadata.csv
└── train
├── video_00001.mp4
└── video_00002.mp4
└── image_00002.jpg
```
`metadata.csv`:
@@ -75,9 +86,11 @@ data/example_dataset/
```
file_name,text
video_00001.mp4,"video description"
video_00001.mp4,"video description"
image_00002.jpg,"video description"
```
We support both images and videos. An image is treated as a single frame of video.
Step 3: Data process
```shell
@@ -119,8 +132,8 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
--steps_per_epoch 500 \
--max_epochs 10 \
--learning_rate 1e-4 \
--lora_rank 4 \
--lora_alpha 4 \
--lora_rank 16 \
--lora_alpha 16 \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--accumulate_grad_batches 1 \
--use_gradient_checkpointing
@@ -142,48 +155,12 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
--use_gradient_checkpointing
```
Step 4-1: I2V LoRA-training
```shell
# cache latents
CUDA_VISIBLE_DEVICES="0" python train_wan_i2v.py \
--task data_process \
--dataset_path data/fps24_V6 \
--output_path ./output \
--text_encoder_path "./models/Wan-AI/Wan2.1-I2V-14B-720P/models_t5_umt5-xxl-enc-bf16.pth" \
--vae_path "./models/Wan-AI/Wan2.1-I2V-14B-720P/Wan2.1_VAE.pth" \
--image_encoder_path "./models/Wan-AI/Wan2.1-I2V-14B-720P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
--tiled \
--num_frames 121 \
--height 309 \
--width 186
```
If you wish to train the 14B model, please separate the safetensor files with a comma. For example: `models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors`.
If you wish to train the image-to-video model, please add an extra parameter `--image_encoder_path "models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"`.
For LoRA training, the Wan-1.3B-T2V model requires 16G of VRAM for processing 81 frames at 480P, while the Wan-14B-T2V model requires 60G of VRAM for the same configuration. To further reduce VRAM requirements by 20%-30%, you can include the parameter `--use_gradient_checkpointing_offload`.
```shell
# run I2V training
CUDA_VISIBLE_DEVICES="0" python train_wan_i2v.py \
--task train \
--train_architecture lora \
--dataset_path data/kling_hips_fps24_V6 \
--output_path ./output \
--dit_path "[
\"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors\",
\"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors\",
\"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors\",
\"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors\",
\"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors\",
\"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors\",
\"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors\"
]" \
--steps_per_epoch 500 \
--max_epochs 10 \
--learning_rate 1e-4 \
--lora_rank 4 \
--lora_alpha 4 \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--accumulate_grad_batches 1 \
--use_gradient_checkpointing
```
Step 5: Test
Test LoRA: