Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-22 00:38:11 +00:00)
Compare commits
3 Commits
dpo-refine ... dpo
| Author | SHA1 | Date |
|---|---|---|
| | 1e8c201d3b | |
| | d96709fb6a | |
| | bf7b339efb | |
@@ -95,9 +95,7 @@ image.save("image.jpg")
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
@@ -207,12 +205,10 @@ save_video(video, "video1.mp4", fps=15, quality=5)
| Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
@@ -385,7 +381,7 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
## Update History

- **September 23, 2025**: [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) is released! This model is jointly developed and open-sourced by us and the Taobao Design Team. It is built upon Qwen-Image, specifically designed for e-commerce poster scenarios, and supports precise partition layout control. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py).
- **September 22, 2025**: We have supported Direct Preference Optimization (DPO) training for Qwen-Image. Please refer to the [example code](examples/qwen_image/model_training/lora/Qwen-Image-DPO.sh) for the training script; a hedged sketch of the general loss shape is given right after this list.
- **September 9, 2025**: Our training framework now supports multiple training modes and has been adapted for Qwen-Image. In addition to the standard SFT training mode, Direct Distill is now also supported; please refer to [our example code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will continue to improve it to support more comprehensive model training capabilities.
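For orientation only, the sketch below shows the general shape of a Diffusion-DPO-style preference loss of the kind the September 22 entry refers to. It is a minimal illustration under assumed interfaces (a trainable `model` and a frozen `ref_model` that each return a per-sample denoising error for a batch); it is not the repository's actual implementation, which lives in the linked training script.

```python
# Hypothetical sketch of a DPO-style preference loss for a diffusion model.
# `model` and `ref_model` are assumed callables returning a per-sample
# denoising error (e.g. MSE between predicted and true noise); these names
# do not refer to a real DiffSynth-Studio API.
import torch
import torch.nn.functional as F

def dpo_loss(model, ref_model, chosen, rejected, beta=0.1):
    err_chosen, err_rejected = model(chosen), model(rejected)
    with torch.no_grad():
        ref_err_chosen, ref_err_rejected = ref_model(chosen), ref_model(rejected)
    # Reward the trainable model for improving (relative to the reference)
    # more on the preferred sample than on the rejected one.
    margin = (ref_err_chosen - err_chosen) - (ref_err_rejected - err_rejected)
    return -F.logsigmoid(beta * margin).mean()
```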
@@ -97,9 +97,7 @@ image.save("image.jpg")
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
@@ -207,12 +205,10 @@ save_video(video, "video1.mp4", fps=15, quality=5)
| Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
@@ -401,7 +397,7 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
## Update History

- **September 23, 2025**: [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) is released! This model is jointly developed and open-sourced by us and the Taotian Experience Design Team. Built upon Qwen-Image and designed specifically for e-commerce poster scenarios, it supports precise partition layout control. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py).
- **September 22, 2025**: We have added Direct Preference Optimization (DPO) training for Qwen-Image; see the [example code](examples/qwen_image/model_training/lora/Qwen-Image-DPO.sh) for the training script.
- **September 9, 2025**: Our training framework now supports multiple training modes and has been adapted for Qwen-Image. Besides the standard SFT training mode, Direct Distill is also supported; please refer to [our example code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will continue to improve it to support more comprehensive model training capabilities. A rough sketch of such a distillation objective is given right after this list.
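As with the DPO sketch above, the following is only a hypothetical illustration of a direct step-distillation objective of the kind the September 9 entry describes: a few-step student is regressed onto what a frozen multi-step teacher sampler produces. The function and argument names are assumptions, not the repository's API.

```python
# Hypothetical sketch of direct step distillation; all names are illustrative.
import torch
import torch.nn.functional as F

def direct_distill_loss(student, teacher_sampler, noise, prompt_emb):
    with torch.no_grad():
        target = teacher_sampler(noise, prompt_emb)  # full multi-step teacher output
    prediction = student(noise, prompt_emb)          # single student forward pass
    return F.mse_loss(prediction, target)
```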
@@ -63,7 +63,6 @@ from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38
from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..models.wan_video_vace import VaceWanModel
from ..models.wav2vec import WanS2VAudioEncoder
from ..models.wan_video_animate_adapter import WanAnimateAdapter
from ..models.step1x_connector import Qwen2Connector
@@ -143,6 +142,7 @@ model_loader_configs = [
    (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
@@ -176,7 +176,6 @@ model_loader_configs = [
    (None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
    (None, "a9e54e480a628f0b956a688a81c33bab", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
    (None, "06be60f3a4526586d8431cd038a71486", ["wans2v_audio_encoder"], [WanS2VAudioEncoder], "civitai"),
    (None, "31fa352acb8a1b1d33cd8764273d80a2", ["wan_video_dit", "wan_video_animate_adapter"], [WanModel, WanAnimateAdapter], "civitai"),
]
huggingface_model_loader_configs = [
    # These configs are provided for detecting model type automatically.
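Each tuple above appears to pair a fingerprint of a checkpoint's state dict with the model name(s), class(es), and source used to load it, so that model types can be detected automatically. The sketch below illustrates one way such a table could be consulted; the hashing scheme and helper names are assumptions, not DiffSynth-Studio's real loader.

```python
# Illustrative only: the fingerprinting scheme here is an assumption,
# not the library's actual implementation.
import hashlib

def state_dict_fingerprint(state_dict):
    # Hash the sorted parameter names so that checkpoints sharing an
    # architecture map to the same fingerprint.
    return hashlib.md5(",".join(sorted(state_dict)).encode()).hexdigest()

def detect_model(state_dict, model_loader_configs):
    fingerprint = state_dict_fingerprint(state_dict)
    for _, known_hash, model_names, model_classes, source in model_loader_configs:
        if known_hash == fingerprint:
            return model_names, model_classes, source
    return None  # unrecognized checkpoint
```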
|
|||||||
@@ -1,670 +0,0 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from typing import Tuple, Optional, List
from einops import rearrange


MEMORY_LAYOUT = {
    "flash": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}


def attention(
    q,
    k,
    v,
    mode="torch",
    drop_rate=0,
    attn_mask=None,
    causal=False,
    max_seqlen_q=None,
    batch_size=1,
):
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]

    if mode == "torch":
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)

    x = post_attn_layout(x)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out

class CausalConv1d(nn.Module):

    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
        super().__init__()

        self.pad_mode = pad_mode
        padding = (kernel_size - 1, 0)  # T
        self.time_causal_padding = padding

        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)

class FaceEncoder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()

        self.num_heads = num_heads
        self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
        self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.act = nn.SiLU()
        self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
        self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)

        self.out_proj = nn.Linear(1024, hidden_dim)
        self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))

    def forward(self, x):
        x = rearrange(x, "b t c -> b c t")
        b, c, t = x.shape

        x = self.conv1_local(x)
        x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)

        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, "b t c -> b c t")
        x = self.conv2(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, "b t c -> b c t")
        x = self.conv3(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.norm3(x)
        x = self.act(x)
        x = self.out_proj(x)
        x = rearrange(x, "(b n) t c -> b t n c", b=b)
        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
        x = torch.cat([x, padding], dim=-2)
        x_local = x.clone()

        return x_local

class RMSNorm(nn.Module):
    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output


def get_norm_layer(norm_layer):
    """
    Get the normalization layer.

    Args:
        norm_layer (str): The type of normalization layer.

    Returns:
        norm_layer (nn.Module): The normalization layer.
    """
    if norm_layer == "layer":
        return nn.LayerNorm
    elif norm_layer == "rms":
        return RMSNorm
    else:
        raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")

class FaceAdapter(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        heads_num: int,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        num_adapter_layers: int = 1,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.hidden_size = hidden_dim
        self.heads_num = heads_num
        self.fuser_blocks = nn.ModuleList(
            [
                FaceBlock(
                    self.hidden_size,
                    self.heads_num,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    **factory_kwargs,
                )
                for _ in range(num_adapter_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        motion_embed: torch.Tensor,
        idx: int,
        freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
        freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)

class FaceBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qk_scale: float = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        self.scale = qk_scale or head_dim**-0.5

        self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
        self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)

        self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)

        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )

        self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

    def forward(
        self,
        x: torch.Tensor,
        motion_vec: torch.Tensor,
        motion_mask: Optional[torch.Tensor] = None,
        use_context_parallel=False,
    ) -> torch.Tensor:
        B, T, N, C = motion_vec.shape
        T_comp = T

        x_motion = self.pre_norm_motion(motion_vec)
        x_feat = self.pre_norm_feat(x)

        kv = self.linear1_kv(x_motion)
        q = self.linear1_q(x_feat)

        k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
        q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)

        # Apply QK-Norm if needed.
        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        k = rearrange(k, "B L N H D -> (B L) H N D")
        v = rearrange(v, "B L N H D -> (B L) H N D")

        q = rearrange(q, "B (L S) H D -> (B L) H S D", L=T_comp)
        # Compute attention.
        attn = F.scaled_dot_product_attention(q, k, v)

        attn = rearrange(attn, "(B L) H S D -> B (L S) (H D)", L=T_comp)

        output = self.linear2(attn)

        if motion_mask is not None:
            output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)

        return output

def custom_qr(input_tensor):
    original_dtype = input_tensor.dtype
    if original_dtype == torch.bfloat16:
        q, r = torch.linalg.qr(input_tensor.to(torch.float32))
        return q.to(original_dtype), r.to(original_dtype)
    return torch.linalg.qr(input_tensor)


def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    return F.leaky_relu(input + bias, negative_slope) * scale


def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    _, minor, in_h, in_w = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, minor, in_h, 1, in_w, 1)
    out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
    out = out.view(-1, minor, in_h * up_y, in_w * up_x)

    out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
              max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]

    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
                      in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
    return out[:, :, ::down_y, ::down_x]


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])


def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[None, :] * k[:, None]
    k /= k.sum()
    return k

class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
        return out


class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)

        self.register_buffer('kernel', kernel)

        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)


class ScaledLeakyReLU(nn.Module):
    def __init__(self, negative_slope=0.2):
        super().__init__()

        self.negative_slope = negative_slope

    def forward(self, input):
        return F.leaky_relu(input, negative_slope=self.negative_slope)

class EqualConv2d(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))
        else:
            self.bias = None

    def forward(self, input):
        return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
        )


class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)
        else:
            out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')

class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
                                  bias=bias and not activate))

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)


class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out

class EncoderApp(nn.Module):
    def __init__(self, size, w_dim=512):
        super(EncoderApp, self).__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256,
            128: 128,
            256: 64,
            512: 32,
            1024: 16
        }

        self.w_dim = w_dim
        log_size = int(math.log(size, 2))

        self.convs = nn.ModuleList()
        self.convs.append(ConvLayer(3, channels[size], 1))

        in_channel = channels[size]
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            self.convs.append(ResBlock(in_channel, out_channel))
            in_channel = out_channel

        self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))

    def forward(self, x):
        res = []
        h = x
        for conv in self.convs:
            h = conv(h)
            res.append(h)

        return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]

class Encoder(nn.Module):
    def __init__(self, size, dim=512, dim_motion=20):
        super(Encoder, self).__init__()

        # appearance network
        self.net_app = EncoderApp(size, dim)

        # motion network
        fc = [EqualLinear(dim, dim)]
        for i in range(3):
            fc.append(EqualLinear(dim, dim))

        fc.append(EqualLinear(dim, dim_motion))
        self.fc = nn.Sequential(*fc)

    def enc_app(self, x):
        h_source = self.net_app(x)
        return h_source

    def enc_motion(self, x):
        h, _ = self.net_app(x)
        h_motion = self.fc(h)
        return h_motion

class Direction(nn.Module):
    def __init__(self, motion_dim):
        super(Direction, self).__init__()
        self.weight = nn.Parameter(torch.randn(512, motion_dim))

    def forward(self, input):
        weight = self.weight + 1e-8
        Q, R = custom_qr(weight)
        if input is None:
            return Q
        else:
            input_diag = torch.diag_embed(input)  # alpha, diagonal matrix
            out = torch.matmul(input_diag, Q.T)
            out = torch.sum(out, dim=1)
            return out


class Synthesis(nn.Module):
    def __init__(self, motion_dim):
        super(Synthesis, self).__init__()
        self.direction = Direction(motion_dim)


class Generator(nn.Module):
    def __init__(self, size, style_dim=512, motion_dim=20):
        super().__init__()

        self.enc = Encoder(size, style_dim, motion_dim)
        self.dec = Synthesis(motion_dim)

    def get_motion(self, img):
        # motion_feat = self.enc.enc_motion(img)
        motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
        motion = self.dec.direction(motion_feat)
        return motion

-class WanAnimateAdapter(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.pose_patch_embedding = torch.nn.Conv3d(16, 5120, kernel_size=(1, 2, 2), stride=(1, 2, 2))
-        self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
-        self.face_adapter = FaceAdapter(heads_num=40, hidden_dim=5120, num_adapter_layers=40 // 5)
-        self.face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4)
-
-    def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
-        pose_latents = self.pose_patch_embedding(pose_latents)
-        x[:, :, 1:] += pose_latents
-
-        b, c, T, h, w = face_pixel_values.shape
-        face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
-
-        encode_bs = 8
-        face_pixel_values_tmp = []
-        for i in range(math.ceil(face_pixel_values.shape[0] / encode_bs)):
-            face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i * encode_bs:(i + 1) * encode_bs]))
-
-        motion_vec = torch.cat(face_pixel_values_tmp)
-
-        motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
-        motion_vec = self.face_encoder(motion_vec)
-
-        B, L, H, C = motion_vec.shape
-        pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
-        motion_vec = torch.cat([pad_face, motion_vec], dim=1)
-        return x, motion_vec
-
-    def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):
-        if block_idx % 5 == 0:
-            adapter_args = [x, motion_vec, motion_masks, False]
-            residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
-            x = residual_out + x
-        return x
-
-    @staticmethod
-    def state_dict_converter():
-        return WanAnimateAdapterStateDictConverter()
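`after_transformer_block` fuses the face motion features into every fifth DiT block as a residual, so the number of fuser blocks matches `num_adapter_layers = 40 // 5` in the constructor. A small sketch of the block-index mapping this relies on (assuming the 40-layer configuration above):

```python
num_layers = 40            # DiT transformer blocks in the 5120-dim model config
adapter_every = 5
fuser_ids = [b // adapter_every for b in range(num_layers) if b % adapter_every == 0]
print(fuser_ids)           # [0, 1, 2, 3, 4, 5, 6, 7] -> exactly 40 // 5 = 8 fuser blocks
```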
-class WanAnimateAdapterStateDictConverter:
-    def __init__(self):
-        pass
-
-    def from_diffusers(self, state_dict):
-        return state_dict
-
-    def from_civitai(self, state_dict):
-        state_dict_ = {}
-        for name, param in state_dict.items():
-            if name.startswith("pose_patch_embedding.") or name.startswith("face_adapter") or name.startswith("face_encoder") or name.startswith("motion_encoder"):
-                state_dict_[name] = param
-        return state_dict_
@@ -342,7 +342,9 @@ class WanModel(torch.nn.Module):
             y_camera = self.control_adapter(control_camera_latents_input)
             x = [u + v for u, v in zip(x, y_camera)]
             x = x[0].unsqueeze(0)
-        return x
+        grid_size = x.shape[2:]
+        x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
+        return x, grid_size  # x, grid_size: (f, h, w)

    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
        return rearrange(
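The new `patchify` contract flattens the latent grid into tokens and hands the `(f, h, w)` grid shape to `unpatchify`, which inverts the flattening. A shape-only sketch of the round trip (tensor sizes are assumed for illustration):

```python
import torch
from einops import rearrange

x = torch.randn(1, 5120, 21, 30, 52)                      # b c f h w after patch embedding
grid_size = x.shape[2:]                                    # (f, h, w) = (21, 30, 52)
tokens = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
f, h, w = grid_size
x_back = rearrange(tokens, 'b (f h w) c -> b c f h w', f=f, h=h, w=w)
print(tokens.shape, torch.equal(x, x_back))                # torch.Size([1, 32760, 5120]) True
```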
@@ -494,7 +496,6 @@ class WanModelStateDictConverter:

    def from_civitai(self, state_dict):
        state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
-        state_dict = {name: param for name, param in state_dict.items() if name.split(".")[0] not in ["pose_patch_embedding", "face_adapter", "face_encoder", "motion_encoder"]}
        if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
            config = {
                "has_image_input": False,
@@ -551,6 +552,20 @@ class WanModelStateDictConverter:
                "num_layers": 30,
                "eps": 1e-6
            }
+        elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
+            config = {
+                "has_image_input": True,
+                "patch_size": [1, 2, 2],
+                "in_dim": 36,
+                "dim": 5120,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 40,
+                "num_layers": 40,
+                "eps": 1e-6
+            }
        elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
            # 1.3B PAI control
            config = {
@@ -194,12 +194,11 @@ class QwenImagePipeline(BasePipeline):
        enable_vram_management(self.vae, module_map=module_map, module_config=model_config)


-    def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, auto_offload=True, enable_dit_fp8_computation=False):
+    def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
        self.vram_management_enabled = True
-        if vram_limit is None and auto_offload:
+        if vram_limit is None:
            vram_limit = self.get_vram()
-        if vram_limit is not None:
-            vram_limit = vram_limit - vram_buffer
+        vram_limit = vram_limit - vram_buffer

        if self.text_encoder is not None:
            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding
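On the new side of this hunk the `auto_offload` switch is gone: the VRAM budget always falls back to `get_vram()` and the buffer is always subtracted. A self-contained sketch of that arithmetic (the `detected_vram` value is an assumption standing in for `get_vram()`):

```python
def effective_vram_limit(vram_limit=None, vram_buffer=0.5, detected_vram=24.0):
    # Mirrors the new logic: fall back to the detected budget,
    # then always reserve vram_buffer for activations.
    if vram_limit is None:
        vram_limit = detected_vram
    return vram_limit - vram_buffer

print(effective_vram_limit())               # 23.5
print(effective_vram_limit(vram_limit=8))   # 7.5
```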
@@ -524,63 +523,37 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
        selected = hidden_states[bool_mask]
        split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
        return split_result

-    def calculate_dimensions(self, target_area, ratio):
-        import math
-        width = math.sqrt(target_area * ratio)
-        height = width / ratio
-        width = round(width / 32) * 32
-        height = round(height / 32) * 32
-        return width, height
-
-    def resize_image(self, image, target_area=384*384):
-        width, height = self.calculate_dimensions(target_area, image.size[0] / image.size[1])
-        return image.resize((width, height))
-
-    def encode_prompt(self, pipe: QwenImagePipeline, prompt):
-        template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
-        drop_idx = 34
-        txt = [template.format(e) for e in prompt]
-        model_inputs = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
-        if model_inputs.input_ids.shape[1] >= 1024:
-            print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {model_inputs['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.")
-        hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, output_hidden_states=True,)[-1]
-        split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
-        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
-        return split_hidden_states
-
-    def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image):
-        template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
-        drop_idx = 64
-        txt = [template.format(e) for e in prompt]
-        model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
-        hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
-        split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
-        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
-        return split_hidden_states
-
-    def encode_prompt_edit_multi(self, pipe: QwenImagePipeline, prompt, edit_image):
-        template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
-        drop_idx = 64
-        img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
-        base_img_prompt = "".join([img_prompt_template.format(i + 1) for i in range(len(edit_image))])
-        txt = [template.format(base_img_prompt + e) for e in prompt]
-        edit_image = [self.resize_image(image) for image in edit_image]
-        model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
-        hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
-        split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
-        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
-        return split_hidden_states
-
    def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
-        if pipe.text_encoder is not None:
+        if pipe.text_encoder is not None and prompt is not None:
            prompt = [prompt]
+            # If edit_image is None, use the default template for Qwen-Image, otherwise use the template for Qwen-Image-Edit
            if edit_image is None:
-                split_hidden_states = self.encode_prompt(pipe, prompt)
-            elif isinstance(edit_image, Image.Image):
-                split_hidden_states = self.encode_prompt_edit(pipe, prompt, edit_image)
+                template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+                drop_idx = 34
            else:
-                split_hidden_states = self.encode_prompt_edit_multi(pipe, prompt, edit_image)
+                template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+                drop_idx = 64
+            txt = [template.format(e) for e in prompt]

+            # Qwen-Image-Edit model
+            if pipe.processor is not None:
+                model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
+            # Qwen-Image model
+            elif pipe.tokenizer is not None:
+                model_inputs = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
+                if model_inputs.input_ids.shape[1] >= 1024:
+                    print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {model_inputs['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.")
+            else:
+                assert False, "QwenImagePipeline requires either tokenizer or processor to be loaded."

+            if 'pixel_values' in model_inputs:
+                hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
+            else:
+                hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, output_hidden_states=True,)[-1]

+            split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+            split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
            attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
            max_seq_len = max([e.size(0) for e in split_hidden_states])
            prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
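The `calculate_dimensions` helper on the removed side of this hunk derives a 32-pixel-aligned resolution from a target pixel area and an aspect ratio. A quick worked example of that arithmetic (values chosen for illustration):

```python
import math

def calculate_dimensions(target_area, ratio):
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    return round(width / 32) * 32, round(height / 32) * 32

# 384*384 target area at a 16:9 ratio -> 512 x 288, both multiples of 32.
print(calculate_dimensions(384 * 384, 16 / 9))  # (512, 288)
```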
@@ -738,23 +711,12 @@ class QwenImageUnit_EditImageEmbedder(PipelineUnit):
    def process(self, pipe: QwenImagePipeline, edit_image, tiled, tile_size, tile_stride, edit_image_auto_resize=False):
        if edit_image is None:
            return {}
+        resized_edit_image = self.edit_image_auto_resize(edit_image) if edit_image_auto_resize else edit_image
        pipe.load_models_to_device(['vae'])
-        if isinstance(edit_image, Image.Image):
-            resized_edit_image = self.edit_image_auto_resize(edit_image) if edit_image_auto_resize else edit_image
-            edit_image = pipe.preprocess_image(resized_edit_image).to(device=pipe.device, dtype=pipe.torch_dtype)
-            edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
-        else:
-            resized_edit_image, edit_latents = [], []
-            for image in edit_image:
-                if edit_image_auto_resize:
-                    image = self.edit_image_auto_resize(image)
-                resized_edit_image.append(image)
-                image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
-                latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
-                edit_latents.append(latents)
+        edit_image = pipe.preprocess_image(resized_edit_image).to(device=pipe.device, dtype=pipe.torch_dtype)
+        edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return {"edit_latents": edit_latents, "edit_image": resized_edit_image}


class QwenImageUnit_ContextImageEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
@@ -807,10 +769,9 @@ def model_fn_qwen_image(
        context_image = rearrange(context_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=context_latents.shape[2]//2, W=context_latents.shape[3]//2, P=2, Q=2)
        image = torch.cat([image, context_image], dim=1)
    if edit_latents is not None:
-        edit_latents_list = edit_latents if isinstance(edit_latents, list) else [edit_latents]
-        img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list]
-        edit_image = [rearrange(e, "B C (H P) (W Q) -> B (H W) (C P Q)", H=e.shape[2]//2, W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list]
-        image = torch.cat([image] + edit_image, dim=1)
+        img_shapes += [(edit_latents.shape[0], edit_latents.shape[2]//2, edit_latents.shape[3]//2)]
+        edit_image = rearrange(edit_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=edit_latents.shape[2]//2, W=edit_latents.shape[3]//2, P=2, Q=2)
+        image = torch.cat([image, edit_image], dim=1)

    image = dit.img_in(image)
    conditioning = dit.time_text_embed(timestep, image.dtype)
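The `rearrange` pattern used here packs each 2×2 patch of the latent grid into a single token, so the token count is (H/2)·(W/2) and the channel dimension grows by a factor of 4. A minimal sketch with an assumed 16-channel, 64×64 latent:

```python
import torch
from einops import rearrange

latents = torch.randn(1, 16, 64, 64)                                   # B C H W latent grid
tokens = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
print(tokens.shape)                                                    # torch.Size([1, 1024, 64])
# 32*32 = 1024 tokens, each carrying 16*2*2 = 64 channels.
```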
@@ -21,7 +21,6 @@ from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..models.wan_video_motion_controller import WanMotionControllerModel
-from ..models.wan_video_animate_adapter import WanAnimateAdapter
from ..schedulers.flow_match import FlowMatchScheduler
from ..prompters import WanPrompter
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
@@ -45,10 +44,8 @@ class WanVideoPipeline(BasePipeline):
        self.vae: WanVideoVAE = None
        self.motion_controller: WanMotionControllerModel = None
        self.vace: VaceWanModel = None
-        self.vace2: VaceWanModel = None
-        self.animate_adapter: WanAnimateAdapter = None
-        self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter")
-        self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter")
+        self.in_iteration_models = ("dit", "motion_controller", "vace")
+        self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
        self.unit_runner = PipelineUnitRunner()
        self.units = [
            WanVideoUnit_ShapeChecker(),
@@ -64,10 +61,6 @@ class WanVideoPipeline(BasePipeline):
            WanVideoUnit_FunCameraControl(),
            WanVideoUnit_SpeedControl(),
            WanVideoUnit_VACE(),
-            WanVideoPostUnit_AnimateVideoSplit(),
-            WanVideoPostUnit_AnimatePoseLatents(),
-            WanVideoPostUnit_AnimateFacePixelValues(),
-            WanVideoPostUnit_AnimateInpaint(),
            WanVideoUnit_UnifiedSequenceParallel(),
            WanVideoUnit_TeaCache(),
            WanVideoUnit_CfgMerger(),
@@ -76,34 +69,13 @@ class WanVideoPipeline(BasePipeline):
            WanVideoPostUnit_S2V(),
        ]
        self.model_fn = model_fn_wan_video


-    def load_lora(
-        self,
-        module: torch.nn.Module,
-        lora_config: Union[ModelConfig, str] = None,
-        alpha=1,
-        hotload=False,
-        state_dict=None,
-    ):
-        if state_dict is None:
-            if isinstance(lora_config, str):
-                lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
-            else:
-                lora_config.download_if_necessary()
-                lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
-        else:
-            lora = state_dict
-        if hotload:
-            for name, module in module.named_modules():
-                if isinstance(module, AutoWrappedLinear):
-                    lora_a_name = f'{name}.lora_A.default.weight'
-                    lora_b_name = f'{name}.lora_B.default.weight'
-                    if lora_a_name in lora and lora_b_name in lora:
-                        module.lora_A_weights.append(lora[lora_a_name] * alpha)
-                        module.lora_B_weights.append(lora[lora_b_name])
-        else:
-            loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
-            loader.load(module, lora, alpha=alpha)
+    def load_lora(self, module, path, alpha=1):
+        loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
+        lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
+        loader.load(module, lora, alpha=alpha)

    def training_loss(self, **inputs):
        max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
@@ -386,13 +358,8 @@ class WanVideoPipeline(BasePipeline):
        pipe.vae = model_manager.fetch_model("wan_video_vae")
        pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
        pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
-        vace = model_manager.fetch_model("wan_video_vace", index=2)
-        if isinstance(vace, list):
-            pipe.vace, pipe.vace2 = vace
-        else:
-            pipe.vace = vace
+        pipe.vace = model_manager.fetch_model("wan_video_vace")
        pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder")
-        pipe.animate_adapter = model_manager.fetch_model("wan_video_animate_adapter")

        # Size division factor
        if pipe.vae is not None:
@@ -445,11 +412,6 @@ class WanVideoPipeline(BasePipeline):
        vace_video_mask: Optional[Image.Image] = None,
        vace_reference_image: Optional[Image.Image] = None,
        vace_scale: Optional[float] = 1.0,
-        # Animate
-        animate_pose_video: Optional[list[Image.Image]] = None,
-        animate_face_video: Optional[list[Image.Image]] = None,
-        animate_inpaint_video: Optional[list[Image.Image]] = None,
-        animate_mask_video: Optional[list[Image.Image]] = None,
        # Randomness
        seed: Optional[int] = None,
        rand_device: Optional[str] = "cpu",
@@ -507,7 +469,6 @@ class WanVideoPipeline(BasePipeline):
            "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
            "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
            "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
-            "animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -520,7 +481,6 @@ class WanVideoPipeline(BasePipeline):
            if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2:
                self.load_models_to_device(self.in_iteration_models_2)
                models["dit"] = self.dit2
-                models["vace"] = self.vace2

            # Timestep
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
@@ -542,12 +502,8 @@ class WanVideoPipeline(BasePipeline):
                inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"]

        # VACE (TODO: remove it)
-        if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None):
-            if vace_reference_image is not None and isinstance(vace_reference_image, list):
-                f = len(vace_reference_image)
-            else:
-                f = 1
-            inputs_shared["latents"] = inputs_shared["latents"][:, :, f:]
+        if vace_reference_image is not None:
+            inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
        # post-denoising, pre-decoding processing logic
        for unit in self.post_units:
            inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -578,12 +534,11 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
    def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image):
        length = (num_frames - 1) // 4 + 1
        if vace_reference_image is not None:
-            f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1
-            length += f
+            length += 1
        shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)
        noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
        if vace_reference_image is not None:
-            noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2)
+            noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
        return {"noise": noise}


@@ -602,9 +557,7 @@ class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
            input_video = pipe.preprocess_video(input_video)
            input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
        if vace_reference_image is not None:
-            if not isinstance(vace_reference_image, list):
-                vace_reference_image = [vace_reference_image]
-            vace_reference_image = pipe.preprocess_video(vace_reference_image)
+            vace_reference_image = pipe.preprocess_video([vace_reference_image])
            vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)
            input_latents = torch.concat([vace_reference_latents, input_latents], dim=2)
        if pipe.scheduler.training:
@@ -896,23 +849,11 @@ class WanVideoUnit_VACE(PipelineUnit):
            if vace_reference_image is None:
                pass
            else:
-                if not isinstance(vace_reference_image, list):
-                    vace_reference_image = [vace_reference_image]
-                vace_reference_image = pipe.preprocess_video(vace_reference_image)
-
-                bs, c, f, h, w = vace_reference_image.shape
-                new_vace_ref_images = []
-                for j in range(f):
-                    new_vace_ref_images.append(vace_reference_image[0, :, j:j+1])
-                vace_reference_image = new_vace_ref_images
-
+                vace_reference_image = pipe.preprocess_video([vace_reference_image])
                vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
                vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
-                vace_reference_latents = [u.unsqueeze(0) for u in vace_reference_latents]
-                vace_video_latents = torch.concat((*vace_reference_latents, vace_video_latents), dim=2)
-                vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2)
+                vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
+                vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)

            vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
            return {"vace_context": vace_context, "vace_scale": vace_scale}
@@ -1062,95 +1003,6 @@ class WanVideoPostUnit_S2V(PipelineUnit):
        return {"latents": latents}


-class WanVideoPostUnit_AnimateVideoSplit(PipelineUnit):
-    def __init__(self):
-        super().__init__(input_params=("input_video", "animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video"))
-
-    def process(self, pipe: WanVideoPipeline, input_video, animate_pose_video, animate_face_video, animate_inpaint_video, animate_mask_video):
-        if input_video is None:
-            return {}
-        if animate_pose_video is not None:
-            animate_pose_video = animate_pose_video[:len(input_video) - 4]
-        if animate_face_video is not None:
-            animate_face_video = animate_face_video[:len(input_video) - 4]
-        if animate_inpaint_video is not None:
-            animate_inpaint_video = animate_inpaint_video[:len(input_video) - 4]
-        if animate_mask_video is not None:
-            animate_mask_video = animate_mask_video[:len(input_video) - 4]
-        return {"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video}
-
-
-class WanVideoPostUnit_AnimatePoseLatents(PipelineUnit):
-    def __init__(self):
-        super().__init__(
-            input_params=("animate_pose_video", "tiled", "tile_size", "tile_stride"),
-            onload_model_names=("vae",)
-        )
-
-    def process(self, pipe: WanVideoPipeline, animate_pose_video, tiled, tile_size, tile_stride):
-        if animate_pose_video is None:
-            return {}
-        pipe.load_models_to_device(self.onload_model_names)
-        animate_pose_video = pipe.preprocess_video(animate_pose_video)
-        pose_latents = pipe.vae.encode(animate_pose_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
-        return {"pose_latents": pose_latents}
-
-
-class WanVideoPostUnit_AnimateFacePixelValues(PipelineUnit):
-    def __init__(self):
-        super().__init__(take_over=True)
-
-    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
-        if inputs_shared.get("animate_face_video", None) is None:
-            return inputs_shared, inputs_posi, inputs_nega
-        inputs_posi["face_pixel_values"] = pipe.preprocess_video(inputs_shared["animate_face_video"])
-        inputs_nega["face_pixel_values"] = torch.zeros_like(inputs_posi["face_pixel_values"]) - 1
-        return inputs_shared, inputs_posi, inputs_nega
-
-
-class WanVideoPostUnit_AnimateInpaint(PipelineUnit):
-    def __init__(self):
-        super().__init__(
-            input_params=("animate_inpaint_video", "animate_mask_video", "input_image", "tiled", "tile_size", "tile_stride"),
-            onload_model_names=("vae",)
-        )
-
-    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
-        if mask_pixel_values is None:
-            msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
-        else:
-            msk = mask_pixel_values.clone()
-        msk[:, :mask_len] = 1
-        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
-        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
-        msk = msk.transpose(1, 2)[0]
-        return msk
-
-    def process(self, pipe: WanVideoPipeline, animate_inpaint_video, animate_mask_video, input_image, tiled, tile_size, tile_stride):
-        if animate_inpaint_video is None or animate_mask_video is None:
-            return {}
-        pipe.load_models_to_device(self.onload_model_names)
-
-        bg_pixel_values = pipe.preprocess_video(animate_inpaint_video)
-        y_reft = pipe.vae.encode(bg_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0].to(dtype=pipe.torch_dtype, device=pipe.device)
-        _, lat_t, lat_h, lat_w = y_reft.shape
-
-        ref_pixel_values = pipe.preprocess_video([input_image])
-        ref_latents = pipe.vae.encode(ref_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
-        mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=pipe.device)
-        y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=pipe.device)
-
-        mask_pixel_values = 1 - pipe.preprocess_video(animate_mask_video, max_value=1, min_value=0)
-        mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w")
-        mask_pixel_values = torch.nn.functional.interpolate(mask_pixel_values, size=(lat_h, lat_w), mode='nearest')
-        mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0]
-        msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, 0, mask_pixel_values=mask_pixel_values, device=pipe.device)
-
-        y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=pipe.device)
-        y = torch.concat([y_ref, y_reft], dim=1).unsqueeze(0)
-        return {"y": y}
-
-
class TeaCache:
    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        self.num_inference_steps = num_inference_steps
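`get_i2v_mask` above turns a per-pixel-frame mask into a 4-channel per-latent-frame mask, matching the VAE's 4x temporal compression (the first latent frame covers one pixel frame, every later one covers four). A shape-only sketch, with latent sizes assumed for illustration:

```python
import torch

lat_t, lat_h, lat_w = 21, 60, 104                         # assumed latent grid
msk = torch.zeros(1, (lat_t - 1) * 4 + 1, lat_h, lat_w)   # one entry per pixel frame: 81 frames
msk[:, :1] = 1                                            # mark the conditioning frame
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w).transpose(1, 2)[0]
print(msk.shape)  # torch.Size([4, 21, 60, 104]) -> 4 mask channels per latent frame
```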
@@ -1261,7 +1113,6 @@ def model_fn_wan_video(
    dit: WanModel,
    motion_controller: WanMotionControllerModel = None,
    vace: VaceWanModel = None,
-    animate_adapter: WanAnimateAdapter = None,
    latents: torch.Tensor = None,
    timestep: torch.Tensor = None,
    context: torch.Tensor = None,

@@ -1277,8 +1128,6 @@ def model_fn_wan_video(
    tea_cache: TeaCache = None,
    use_unified_sequence_parallel: bool = False,
    motion_bucket_id: Optional[torch.Tensor] = None,
-    pose_latents=None,
-    face_pixel_values=None,
    sliding_window_size: Optional[int] = None,
    sliding_window_stride: Optional[int] = None,
    cfg_merge: bool = False,
@@ -1369,17 +1218,9 @@ def model_fn_wan_video(
    if clip_feature is not None and dit.require_clip_embedding:
        clip_embdding = dit.img_emb(clip_feature)
        context = torch.cat([clip_embdding, context], dim=1)

-    # Camera control
-    x = dit.patchify(x, control_camera_latents_input)
-
-    # Animate
-    if pose_latents is not None and face_pixel_values is not None:
-        x, motion_vec = animate_adapter.after_patch_embedding(x, pose_latents, face_pixel_values)
-
-    # Patchify
-    f, h, w = x.shape[2:]
-    x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
+    # Add camera control
+    x, (f, h, w) = dit.patchify(x, control_camera_latents_input)

    # Reference image
    if reference_latents is not None:
@@ -1402,11 +1243,7 @@ def model_fn_wan_video(
        tea_cache_update = False

    if vace_context is not None:
-        vace_hints = vace(
-            x, vace_context, context, t_mod, freqs,
-            use_gradient_checkpointing=use_gradient_checkpointing,
-            use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
-        )
+        vace_hints = vace(x, vace_context, context, t_mod, freqs)

    # blocks
    if use_unified_sequence_parallel:
@@ -1424,7 +1261,6 @@ def model_fn_wan_video(
        return custom_forward

    for block_id, block in enumerate(dit.blocks):
-        # Block
        if use_gradient_checkpointing_offload:
            with torch.autograd.graph.save_on_cpu():
                x = torch.utils.checkpoint.checkpoint(

@@ -1440,18 +1276,12 @@ def model_fn_wan_video(
                )
        else:
            x = block(x, context, t_mod, freqs)

-        # VACE
        if vace_context is not None and block_id in vace.vace_layers_mapping:
            current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
            if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
                current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
                current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
            x = x + current_vace_hint * vace_scale

-        # Animate
-        if pose_latents is not None and face_pixel_values is not None:
-            x = animate_adapter.after_transformer_block(block_id, x, motion_vec)
        if tea_cache is not None:
            tea_cache.store(x)
@@ -269,10 +269,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
                (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
-                (("gif",), LoadGIF(
-                    num_frames, time_division_factor, time_division_remainder,
-                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
-                )),
+                (("gif",), LoadGIF(num_frames, time_division_factor, time_division_remainder) >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
                (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),

@@ -316,7 +313,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
        for key in self.data_file_keys:
            if key in data:
                if key in self.special_operator_map:
-                    data[key] = self.special_operator_map[key](data[key])
+                    data[key] = self.special_operator_map[key]
                elif key in self.data_file_keys:
                    data[key] = self.main_data_operator(data[key])
        return data
@@ -396,6 +396,15 @@ class DiffusionTrainingModule(torch.nn.Module):
            param.data = param.to(upcast_dtype)
        return model

+    def disable_all_lora_layers(self, model):
+        for name, module in model.named_modules():
+            if hasattr(module, 'enable_adapters'):
+                module.enable_adapters(False)
+
+    def enable_all_lora_layers(self, model):
+        for name, module in model.named_modules():
+            if hasattr(module, 'enable_adapters'):
+                module.enable_adapters(True)
+
    def mapping_lora_state_dict(self, state_dict):
        new_state_dict = {}
@@ -475,64 +484,6 @@ class DiffusionTrainingModule(torch.nn.Module):
        if len(load_result[1]) > 0:
            print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
        setattr(pipe, lora_base_model, model)

-    def disable_all_lora_layers(self, model):
-        for name, module in model.named_modules():
-            if hasattr(module, 'enable_adapters'):
-                module.enable_adapters(False)
-
-    def enable_all_lora_layers(self, model):
-        for name, module in model.named_modules():
-            if hasattr(module, 'enable_adapters'):
-                module.enable_adapters(True)
-
-
-class DPOLoss:
-    def __init__(self, beta=2500):
-        self.beta = beta
-
-    def sample_timestep(self, pipe):
-        timestep_id = torch.randint(0, pipe.scheduler.num_train_timesteps, (1,))
-        timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
-        return timestep
-
-    def training_loss_minimum(self, pipe, noise, timestep, **inputs):
-        inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
-        training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
-        noise_pred = pipe.model_fn(**inputs, timestep=timestep)
-        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
-        loss = loss * pipe.scheduler.training_weight(timestep)
-        return loss
-
-    def loss(self, model, data):
-        # Loss DPO: -logσ(−β(diff_policy − diff_ref))
-        # Prepare inputs
-        win_data = {key: data[key] for key in ["prompt", "image"]}
-        lose_data = {"prompt": data["prompt"], "image": data["lose_image"]}
-        inputs_win = model.forward_preprocess(win_data)
-        inputs_lose = model.forward_preprocess(lose_data)
-        inputs_win.pop('noise')
-        inputs_lose.pop('noise')
-        models = {name: getattr(model.pipe, name) for name in model.pipe.in_iteration_models}
-        # sample timestep and noise
-        timestep = self.sample_timestep(model.pipe)
-        noise = torch.rand_like(inputs_win["latents"])
-        # compute diff_policy = loss_win - loss_lose
-        loss_win = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_win)
-        loss_lose = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_lose)
-        diff_policy = loss_win - loss_lose
-        # compute diff_ref
-        # TODO: may support full model training
-        model.disable_all_lora_layers(model.pipe.dit)
-        # load the original model weights
-        with torch.no_grad():
-            loss_win_ref = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_win)
-            loss_lose_ref = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_lose)
-            diff_ref = loss_win_ref - loss_lose_ref
-        model.enable_all_lora_layers(model.pipe.dit)
-        # compute loss
-        loss = -1. * torch.nn.functional.logsigmoid(self.beta * (diff_ref - diff_policy)).mean()
-        return loss


class ModelLogger:
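The comment in the removed `DPOLoss.loss` spells out the Diffusion-DPO objective: with denoising errors for the preferred ("win") and rejected ("lose") samples under the trainable policy and under the frozen reference (LoRA disabled), the loss is -log σ(-β(diff_policy - diff_ref)), where diff = loss_win - loss_lose. A minimal sketch of the final combination step, assuming the four MSE terms have already been computed:

```python
import torch
import torch.nn.functional as F

def dpo_loss(loss_win, loss_lose, loss_win_ref, loss_lose_ref, beta=2500.0):
    # "How much better is the preferred sample reconstructed" under policy vs. reference.
    diff_policy = loss_win - loss_lose
    diff_ref = loss_win_ref - loss_lose_ref
    # Equivalent to -log sigmoid(-beta * (diff_policy - diff_ref)): improving on the
    # preferred sample relative to the frozen reference decreases the loss.
    return -F.logsigmoid(beta * (diff_ref - diff_policy)).mean()
```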
@@ -612,9 +563,9 @@ def launch_training_task(
        with accelerator.accumulate(model):
            optimizer.zero_grad()
            if dataset.load_from_cache:
-                loss = model({}, inputs=data)
+                loss = model({}, inputs=data, accelerator=accelerator)
            else:
-                loss = model(data)
+                loss = model(data, accelerator=accelerator)
            accelerator.backward(loss)
            optimizer.step()
            model_logger.on_step_end(accelerator, model, save_steps)
@@ -748,4 +699,5 @@ def qwen_image_parser():
    parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
    parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
    parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
+    parser.add_argument("--beta_dpo", type=float, default=1000, help="hyperparameter beta for DPO loss, only used when task is dpo.")
    return parser
@@ -153,6 +153,19 @@ class BasePipeline(torch.nn.Module):
        latents_next = scheduler.step(noise_pred, timestep, latents)
        return latents_next

+    def sample_timestep(self):
+        timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
+        timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
+        return timestep
+
+    def training_loss_minimum(self, noise, timestep, **inputs):
+        inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep)
+        training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep)
+        noise_pred = self.model_fn(**inputs, timestep=timestep)
+        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
+        loss = loss * self.scheduler.training_weight(timestep)
+        return loss
+

@dataclass
@@ -116,7 +116,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
-        bias: torch.Tensor = None,
+        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        device = input.device
        origin_dtype = input.dtype
@@ -75,7 +75,7 @@ class FluxTrainingModule(DiffusionTrainingModule):
        return {**inputs_shared, **inputs_posi}


-    def forward(self, data, inputs=None):
+    def forward(self, data, inputs=None, **kwargs):
        if inputs is None: inputs = self.forward_preprocess(data)
        models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
        loss = self.pipe.training_loss(**models, **inputs)
@@ -47,12 +47,10 @@ image.save("image.jpg")
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_inference_low_vram/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./model_inference/Qwen-Image-Edit.py)|[code](./model_inference_low_vram/Qwen-Image-Edit.py)|[code](./model_training/full/Qwen-Image-Edit.sh)|[code](./model_training/validate_full/Qwen-Image-Edit.py)|[code](./model_training/lora/Qwen-Image-Edit.sh)|[code](./model_training/validate_lora/Qwen-Image-Edit.py)|
-|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](./model_inference/Qwen-Image-Edit-2509.py)|[code](./model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](./model_training/full/Qwen-Image-Edit-2509.sh)|[code](./model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](./model_training/lora/Qwen-Image-Edit-2509.sh)|[code](./model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./model_inference/Qwen-Image-Distill-LoRA.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](./model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](./model_inference/Qwen-Image-EliGen-V2.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
-|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](./model_inference/Qwen-Image-EliGen-Poster.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|

@@ -47,12 +47,10 @@ image.save("image.jpg")
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_inference_low_vram/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./model_inference/Qwen-Image-Edit.py)|[code](./model_inference_low_vram/Qwen-Image-Edit.py)|[code](./model_training/full/Qwen-Image-Edit.sh)|[code](./model_training/validate_full/Qwen-Image-Edit.py)|[code](./model_training/lora/Qwen-Image-Edit.sh)|[code](./model_training/validate_lora/Qwen-Image-Edit.py)|
-|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](./model_inference/Qwen-Image-Edit-2509.py)|[code](./model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](./model_training/full/Qwen-Image-Edit-2509.sh)|[code](./model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](./model_training/lora/Qwen-Image-Edit-2509.sh)|[code](./model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./model_inference/Qwen-Image-Distill-LoRA.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](./model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](./model_inference/Qwen-Image-EliGen-V2.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
-|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](./model_inference/Qwen-Image-EliGen-Poster.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|

@@ -1,25 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from PIL import Image
import torch

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image-Edit-2509", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)

image_1 = pipe(prompt="一位少女", seed=0, num_inference_steps=40, height=1328, width=1024)
image_1.save("image1.jpg")

image_2 = pipe(prompt="一位老人", seed=0, num_inference_steps=40, height=1328, width=1024)
image_2.save("image2.jpg")

prompt = "生成这两个人的合影"
edit_image = [Image.open("image1.jpg"), Image.open("image2.jpg")]
image_3 = pipe(prompt, edit_image=edit_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=True)
image_3.save("image3.jpg")

@@ -10,6 +10,7 @@ pipe = QwenImagePipeline.from_pretrained(
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
+    tokenizer_config=None,
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
snapshot_download("DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix", local_dir="models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix", allow_file_pattern="model.safetensors")

@@ -9,6 +9,7 @@ pipe = QwenImagePipeline.from_pretrained(
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
+    tokenizer_config=None,
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"

@@ -1,114 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
from PIL import Image, ImageDraw, ImageFont
from modelscope import dataset_snapshot_download, snapshot_download
import random


def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
    # Create a blank image for overlays
    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))

    colors = [
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
    ]
    # Generate random colors for each mask
    if use_random_colors:
        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]

    # Font settings
    try:
        font = ImageFont.truetype("wqy-zenhei.ttc", font_size)  # Adjust as needed
    except IOError:
        font = ImageFont.load_default(font_size)

    # Overlay each mask onto the overlay image
    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
        # Convert mask to RGBA mode
        mask_rgba = mask.convert('RGBA')
        mask_data = mask_rgba.getdata()
        new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
        mask_rgba.putdata(new_data)

        # Draw the mask prompt text on the mask
        draw = ImageDraw.Draw(mask_rgba)
        mask_bbox = mask.getbbox()  # Get the bounding box of the mask
        text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)  # Adjust text position based on mask position
        draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)

        # Alpha composite the overlay with this mask
        overlay = Image.alpha_composite(overlay, mask_rgba)

    # Composite the overlay onto the original image
    result = Image.alpha_composite(image.convert('RGBA'), overlay)

    # Save or display the resulting image
    result.save(output_path)

    return result


def example(pipe, seeds, example_id, global_prompt, entity_prompts, height=784, width=1280):
    dataset_snapshot_download(
        dataset_id="DiffSynth-Studio/examples_in_diffsynth",
        local_dir="./",
        allow_file_pattern=f"data/examples/eligen/poster/example_{example_id}/*.png"
    )
    masks = [
        Image.open(f"./data/examples/eligen/poster/example_{example_id}/{i}.png").convert('RGB').resize((width, height))
        for i in range(len(entity_prompts))
    ]
    negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴"
    for seed in seeds:
        # generate image
        image = pipe(
            prompt=global_prompt,
            cfg_scale=4.0,
            negative_prompt=negative_prompt,
            num_inference_steps=40,
            seed=seed,
            height=height,
            width=width,
            eligen_entity_prompts=entity_prompts,
            eligen_entity_masks=masks,
        )
        image.save(f"eligen_poster_example_{example_id}_{seed}.png")
        image = Image.new("RGB", (width, height), (0, 0, 0))
        visualize_masks(image, masks, entity_prompts, f"eligen_poster_example_{example_id}_mask_{seed}.png")


pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
snapshot_download(
    "DiffSynth-Studio/Qwen-Image-EliGen-Poster",
    local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen-Poster",
    allow_file_pattern="model.safetensors",
)
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen-Poster/model.safetensors")
global_prompt = "一张以柔粉紫为背景的海报,左侧有大号粉紫色文字“Qwen-Image EliGen-Poster”,粉紫色椭圆框内白色小字:“图像精确分区控制模型”。右侧有一只小兔子在拆礼物,旁边站着一只头顶迷你烟花发射器的小龙(卡通Q版)。背景有一些白云点缀。整体风格卡通可爱,传达节日惊喜的主题。"
entity_prompts = ["粉紫色文字“Qwen-Image EliGen-Poster”", "粉紫色椭圆框内白色小字:“图像精确分区控制模型”", "一只小兔子在拆礼物,小兔子旁边站着一只头顶迷你烟花发射器的小龙(卡通Q版)"]
seed = [42]
example(pipe, seed, 1, global_prompt, entity_prompts)

@@ -1,26 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from PIL import Image
import torch

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image-Edit-2509", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
    ],
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
pipe.enable_vram_management()

image_1 = pipe(prompt="一位少女", seed=0, num_inference_steps=40, height=1328, width=1024)
image_1.save("image1.jpg")

image_2 = pipe(prompt="一位老人", seed=0, num_inference_steps=40, height=1328, width=1024)
image_2.save("image2.jpg")

prompt = "生成这两个人的合影"
edit_image = [Image.open("image1.jpg"), Image.open("image2.jpg")]
image_3 = pipe(prompt, edit_image=edit_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=True)
image_3.save("image3.jpg")

@@ -10,6 +10,7 @@ pipe = QwenImagePipeline.from_pretrained(
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
    ],
+    tokenizer_config=None,
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
pipe.enable_vram_management()

@@ -9,6 +9,7 @@ pipe = QwenImagePipeline.from_pretrained(
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
    ],
+    tokenizer_config=None,
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
pipe.enable_vram_management()

@@ -1,115 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
from PIL import Image, ImageDraw, ImageFont
from modelscope import dataset_snapshot_download, snapshot_download
import random


def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
    # Create a blank image for overlays
    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))

    colors = [
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
    ]
    # Generate random colors for each mask
    if use_random_colors:
        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]

    # Font settings
    try:
        font = ImageFont.truetype("wqy-zenhei.ttc", font_size)  # Adjust as needed
    except IOError:
        font = ImageFont.load_default(font_size)

    # Overlay each mask onto the overlay image
    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
        # Convert mask to RGBA mode
        mask_rgba = mask.convert('RGBA')
        mask_data = mask_rgba.getdata()
        new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
        mask_rgba.putdata(new_data)

        # Draw the mask prompt text on the mask
        draw = ImageDraw.Draw(mask_rgba)
        mask_bbox = mask.getbbox()  # Get the bounding box of the mask
        text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)  # Adjust text position based on mask position
        draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)

        # Alpha composite the overlay with this mask
        overlay = Image.alpha_composite(overlay, mask_rgba)

    # Composite the overlay onto the original image
    result = Image.alpha_composite(image.convert('RGBA'), overlay)

    # Save or display the resulting image
    result.save(output_path)

    return result


def example(pipe, seeds, example_id, global_prompt, entity_prompts, height=784, width=1280):
    dataset_snapshot_download(
        dataset_id="DiffSynth-Studio/examples_in_diffsynth",
        local_dir="./",
        allow_file_pattern=f"data/examples/eligen/poster/example_{example_id}/*.png"
    )
    masks = [
        Image.open(f"./data/examples/eligen/poster/example_{example_id}/{i}.png").convert('RGB').resize((width, height))
        for i in range(len(entity_prompts))
    ]
    negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴"
    for seed in seeds:
        # generate image
        image = pipe(
            prompt=global_prompt,
            cfg_scale=4.0,
            negative_prompt=negative_prompt,
            num_inference_steps=40,
            seed=seed,
            height=height,
            width=width,
            eligen_entity_prompts=entity_prompts,
            eligen_entity_masks=masks,
        )
        image.save(f"eligen_poster_example_{example_id}_{seed}.png")
        image = Image.new("RGB", (width, height), (0, 0, 0))
        visualize_masks(image, masks, entity_prompts, f"eligen_poster_example_{example_id}_mask_{seed}.png")


pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.enable_vram_management()
snapshot_download(
    "DiffSynth-Studio/Qwen-Image-EliGen-Poster",
    local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen-Poster",
    allow_file_pattern="model.safetensors",
)
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen-Poster/model.safetensors")
global_prompt = "一张以柔粉紫为背景的海报,左侧有大号粉紫色文字“Qwen-Image EliGen-Poster”,粉紫色椭圆框内白色小字:“图像精确分区控制模型”。右侧有一只小兔子在拆礼物,旁边站着一只头顶迷你烟花发射器的小龙(卡通Q版)。背景有一些白云点缀。整体风格卡通可爱,传达节日惊喜的主题。"
entity_prompts = ["粉紫色文字“Qwen-Image EliGen-Poster”", "粉紫色椭圆框内白色小字:“图像精确分区控制模型”", "一只小兔子在拆礼物,小兔子旁边站着一只头顶迷你烟花发射器的小龙(卡通Q版)"]
seed = [42]
example(pipe, seed, 1, global_prompt, entity_prompts)

@@ -1,15 +0,0 @@
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
  --data_file_keys "image,edit_image" \
  --extra_inputs "edit_image" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Qwen-Image-Edit-2509_full" \
  --trainable_models "dit" \
  --use_gradient_checkpointing \
  --find_unused_parameters

@@ -1,18 +1,25 @@
+# dataset format:
+# {
+#     "image": "path/to/win_image.png", # win image
+#     "lose_image": "path/to/lose_image.png", # lose image
+#     "prompt": "a photo of ...",
+# }
accelerate launch examples/qwen_image/model_training/train.py \
-  --dataset_base_path "data/example_image_dataset" \
+  --dataset_base_path data/example_image_dataset \
-  --dataset_metadata_path data/example_image_dataset/metadata_eligen.json \
+  --dataset_metadata_path data/example_image_dataset/dpo.jsonl \
-  --data_file_keys "image,eligen_entity_masks" \
+  --data_file_keys "image,lose_image" \
  --max_pixels 1048576 \
-  --dataset_repeat 50 \
+  --dataset_repeat 400 \
  --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
-  --output_path "./models/train/Qwen-Image-EliGen-Poster_lora" \
+  --output_path "./models/train/Qwen-Image_DPO_lora" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
  --lora_rank 32 \
-  --extra_inputs "eligen_entity_masks,eligen_entity_prompts" \
  --use_gradient_checkpointing \
-  --find_unused_parameters \
+  --dataset_num_workers 8 \
-  --lora_checkpoint "models/DiffSynth-Studio/Qwen-Image-EliGen-V2/model.safetensors"
+  --task dpo \
+  --beta_dpo 2500 \
+  --find_unused_parameters

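For reference, the dpo.jsonl metadata passed via --dataset_metadata_path holds one JSON object per line with the win image, lose image, and prompt keys described in the comments above, matching --data_file_keys "image,lose_image". The following is a minimal sketch for writing such a file; the image paths and prompts are placeholder values, not files shipped with the example dataset.

# Minimal sketch (the paths and prompts below are placeholders).
# Each line pairs a preferred ("win") image with a rejected ("lose") image
# for the same prompt, in the format expected by the DPO training script above.
import json

records = [
    {"image": "dpo/win_0.png", "lose_image": "dpo/lose_0.png", "prompt": "a photo of a cat"},
    {"image": "dpo/win_1.png", "lose_image": "dpo/lose_1.png", "prompt": "a photo of a dog"},
]
with open("data/example_image_dataset/dpo.jsonl", "w", encoding="utf-8") as f:
    for record in records:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
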
@@ -1,18 +0,0 @@
accelerate launch examples/qwen_image/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
  --data_file_keys "image,edit_image" \
  --extra_inputs "edit_image" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Qwen-Image-Edit-2509_lora" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
  --lora_rank 32 \
  --use_gradient_checkpointing \
  --dataset_num_workers 8 \
  --find_unused_parameters

@@ -1,13 +1,11 @@
-import torch, os, json
+import torch, os
-from diffsynth import load_state_dict
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.pipelines.flux_image_new import ControlNetInput
-from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task, DPOLoss
+from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task
from diffsynth.trainers.unified_dataset import UnifiedDataset
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class QwenImageTrainingModule(DiffusionTrainingModule):
    def __init__(
        self,

@@ -20,6 +18,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
        extra_inputs=None,
        enable_fp8_training=False,
        task="sft",
+        beta_dpo=1000.,
    ):
        super().__init__()
        # Load models

@@ -40,8 +39,9 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
        self.task = task
+        self.lora_base_model = lora_base_model
+        self.beta_dpo = beta_dpo

    def forward_preprocess(self, data):
        # CFG-sensitive parameters
        inputs_posi = {"prompt": data["prompt"]}

@@ -81,32 +81,62 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
        return {**inputs_shared, **inputs_posi}

-    def forward(self, data, inputs=None, return_inputs=False):
-        # DPO (DPO requires a special training loss)
-        if self.task == "dpo":
-            loss = DPOLoss().loss(self, data)
-            return loss
-        else:
-            # Inputs
-            if inputs is None:
-                inputs = self.forward_preprocess(data)
-            else:
-                inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
-            if return_inputs: return inputs
-            # Loss
-            if self.task == "sft":
-                models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
-                loss = self.pipe.training_loss(**models, **inputs)
-            elif self.task == "data_process":
-                loss = inputs
-            elif self.task == "direct_distill":
-                loss = self.pipe.direct_distill_loss(**inputs)
-            else:
-                raise NotImplementedError(f"Unsupported task: {self.task}.")
-            return loss
+    def forward_dpo(self, data, accelerator=None):
+        # Loss DPO: -logσ(−β(diff_policy − diff_ref))
+        # Prepare inputs
+        win_data = {key: data[key] for key in ["prompt", "image"]}
+        lose_data = {"prompt": None, "image": data["lose_image"]}
+        inputs_win = self.forward_preprocess(win_data)
+        inputs_lose = self.forward_preprocess(lose_data)
+        inputs_lose.update({key: inputs_win[key] for key in ["prompt", "prompt_emb", "prompt_emb_mask"]})
+        inputs_win.pop('noise')
+        inputs_lose.pop('noise')
+        models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
+        # sample timestep and noise
+        timestep = self.pipe.sample_timestep()
+        noise = torch.rand_like(inputs_win["latents"])
+        # compute diff_policy = loss_win - loss_lose
+        loss_win = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_win)
+        loss_lose = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_lose)
+        diff_policy = loss_win - loss_lose
+        # compute diff_ref
+        if self.lora_base_model is not None:
+            self.disable_all_lora_layers(accelerator.unwrap_model(self).pipe.dit)
+            # load the original model weights
+            with torch.no_grad():
+                loss_win_ref = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_win)
+                loss_lose_ref = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_lose)
+            diff_ref = loss_win_ref - loss_lose_ref
+            self.enable_all_lora_layers(accelerator.unwrap_model(self).pipe.dit)
+        else:
+            # TODO: may support full model training
+            raise NotImplementedError("DPO with full model training is not supported yet.")
+        # compute loss
+        loss = -1. * torch.nn.functional.logsigmoid(self.beta_dpo * (diff_ref - diff_policy)).mean()
+        return loss
+
+    def forward(self, data, inputs=None, return_inputs=False, accelerator=None, **kwargs):
+        if self.task == "dpo":
+            return self.forward_dpo(data, accelerator=accelerator)
+        # Inputs
+        if inputs is None:
+            inputs = self.forward_preprocess(data)
+        else:
+            inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
+        if return_inputs: return inputs
+        # Loss
+        if self.task == "sft":
+            models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
+            loss = self.pipe.training_loss(**models, **inputs)
+        elif self.task == "data_process":
+            loss = inputs
+        elif self.task == "direct_distill":
+            loss = self.pipe.direct_distill_loss(**inputs)
+        else:
+            raise NotImplementedError(f"Unsupported task: {self.task}.")
+        return loss

@@ -142,6 +172,7 @@ if __name__ == "__main__":
        extra_inputs=args.extra_inputs,
        enable_fp8_training=args.enable_fp8_training,
        task=args.task,
+        beta_dpo=args.beta_dpo,
    )
    model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
    launcher_map = {

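The forward_dpo method added above reduces to a Diffusion-DPO objective over denoising losses: the trainable (policy) model's win/lose loss gap is compared with the frozen reference gap (obtained with the LoRA layers disabled), and the difference is scaled by beta_dpo and pushed through a log-sigmoid. The following is a minimal sketch of that scalar objective on stand-in loss values, assuming nothing beyond plain PyTorch; the helper name dpo_objective and the dummy numbers are illustrative only.

# Minimal sketch of the objective used in forward_dpo, on stand-in loss values.
# loss_win / loss_lose come from the trainable (policy) model, *_ref from the
# frozen reference; beta_dpo matches the --beta_dpo training argument.
import torch
import torch.nn.functional as F

def dpo_objective(loss_win, loss_lose, loss_win_ref, loss_lose_ref, beta_dpo=2500.0):
    diff_policy = loss_win - loss_lose        # policy gap: smaller means "win" is preferred
    diff_ref = loss_win_ref - loss_lose_ref   # reference gap, removes bias shared by both images
    return -F.logsigmoid(beta_dpo * (diff_ref - diff_policy)).mean()

# Example with dummy per-sample losses
loss = dpo_objective(torch.tensor([0.10]), torch.tensor([0.16]),
                     torch.tensor([0.12]), torch.tensor([0.13]))
print(loss)
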
@@ -1,26 +0,0 @@
import torch
from PIL import Image
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth import load_state_dict

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image-Edit-2509", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=None,
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
state_dict = load_state_dict("models/train/Qwen-Image-Edit-2509_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)

prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2."
images = [
    Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)),
    Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)),
]
image = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024)
image.save("image.jpg")

@@ -0,0 +1,19 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch


pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image_DPO_lora/epoch-4.safetensors")
prompt = "黑板上写着“群起效尤,心灵手巧”,字的颜色分别是 “群”: 橙色、“起”: 黑色、“效”: 蓝色、“尤”: 绿色、“心”: 紫色、“灵”: 粉色、“手”: 红色、“巧”: 白色"
for seed in range(0, 5):
    image = pipe(prompt, seed=seed)
    image.save(f"image_dpo_{seed}.jpg")

@@ -1,24 +0,0 @@
import torch
from PIL import Image
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image-Edit-2509", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=None,
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Edit-2509_lora/epoch-4.safetensors")

prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2."
images = [
    Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)),
    Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)),
]
image = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024)
image.save("image.jpg")

@@ -1,29 +0,0 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
from PIL import Image


pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-EliGen-Poster_lora/epoch-4.safetensors")


entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
masks = [Image.open(f"data/example_image_dataset/eligen/{i}.png").convert('RGB') for i in range(len(entity_prompts))]

image = pipe(global_prompt,
    seed=0,
    height=1024,
    width=1024,
    eligen_entity_prompts=entity_prompts,
    eligen_entity_masks=masks)
image.save("Qwen-Image-EliGen-Poster.jpg")

@@ -48,12 +48,10 @@ save_video(video, "video1.mp4", fps=15, quality=5)

| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
-|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./model_inference/Wan2.2-Animate-14B.py)|[code](./model_training/full/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./model_training/lora/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
-|[Wan-AI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](./model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](./model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](./model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.2-Fun-A14B-InP.py)|[code](./model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](./model_inference/Wan2.2-Fun-A14B-Control.py)|[code](./model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|

@@ -48,12 +48,10 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|Model ID|Extra Parameters|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./model_inference/Wan2.2-Animate-14B.py)|[code](./model_training/full/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./model_training/lora/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](./model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](./model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](./model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.2-Fun-A14B-InP.py)|[code](./model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](./model_inference/Wan2.2-Fun-A14B-Control.py)|[code](./model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
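The extra parameters listed above map directly onto keyword arguments of the `WanVideoPipeline` call. A minimal sketch of that pattern follows (the model config is trimmed to the DiT weights, and the prompt and image path are placeholders; the linked scripts and the example files shown below give the complete setups):

import torch
from PIL import Image
from diffsynth import save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig

# Load the pipeline; the real examples also add the T5 text encoder and VAE configs.
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
    ],
)
pipe.enable_vram_management()

# The extra parameter from the table (`input_image` for Wan2.2-TI2V-5B) is passed as a keyword argument.
video = pipe(
    prompt="a cat walking on a beach",    # placeholder prompt
    input_image=Image.open("input.jpg"),  # placeholder path
    seed=0, tiled=True,
)
save_video(video, "video.mp4", fps=15, quality=5)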
@@ -1,62 +0,0 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download, snapshot_download


pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern="data/examples/wan/animate/*",
)

# Animate: drive the character in the input image with the pose and face videos.
input_image = Image.open("data/examples/wan/animate/animate_input_image.png")
animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4").raw_data()[:81-4]
animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4").raw_data()[:81-4]
video = pipe(
    prompt="视频中的人在做动作",
    seed=0, tiled=True,
    input_image=input_image,
    animate_pose_video=animate_pose_video,
    animate_face_video=animate_face_video,
    num_frames=81, height=720, width=1280,
    num_inference_steps=20, cfg_scale=1,
)
save_video(video, "video1.mp4", fps=15, quality=5)

# Replace: swap the character into the template video, using the relighting LoRA together with inpaint and mask videos.
snapshot_download("Wan-AI/Wan2.2-Animate-14B", allow_file_pattern="relighting_lora.ckpt", local_dir="models/Wan-AI/Wan2.2-Animate-14B")
lora_state_dict = load_state_dict("models/Wan-AI/Wan2.2-Animate-14B/relighting_lora.ckpt", torch_dtype=torch.float32, device="cuda")["state_dict"]
pipe.load_lora(pipe.dit, state_dict=lora_state_dict)
input_image = Image.open("data/examples/wan/animate/replace_input_image.png")
animate_pose_video = VideoData("data/examples/wan/animate/replace_pose_video.mp4").raw_data()[:81-4]
animate_face_video = VideoData("data/examples/wan/animate/replace_face_video.mp4").raw_data()[:81-4]
animate_inpaint_video = VideoData("data/examples/wan/animate/replace_inpaint_video.mp4").raw_data()[:81-4]
animate_mask_video = VideoData("data/examples/wan/animate/replace_mask_video.mp4").raw_data()[:81-4]
video = pipe(
    prompt="视频中的人在做动作",
    seed=0, tiled=True,
    input_image=input_image,
    animate_pose_video=animate_pose_video,
    animate_face_video=animate_face_video,
    animate_inpaint_video=animate_inpaint_video,
    animate_mask_video=animate_mask_video,
    num_frames=81, height=720, width=1280,
    num_inference_steps=20, cfg_scale=1,
)
save_video(video, "video2.mp4", fps=15, quality=5)
@@ -1,54 +0,0 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download


pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"]
)

# Depth video -> Video
control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832)
video = pipe(
    prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    vace_video=control_video,
    seed=1, tiled=True
)
save_video(video, "video1_14b.mp4", fps=15, quality=5)

# Reference image -> Video
video = pipe(
    prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
    seed=1, tiled=True
)
save_video(video, "video2_14b.mp4", fps=15, quality=5)

# Depth video + Reference image -> Video
video = pipe(
    prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    vace_video=control_video,
    vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
    seed=1, tiled=True
)
save_video(video, "video3_14b.mp4", fps=15, quality=5)
@@ -1,16 +0,0 @@
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata_animate.csv \
  --data_file_keys "video,animate_pose_video,animate_face_video" \
  --height 480 \
  --width 832 \
  --num_frames 81 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "Wan-AI/Wan2.2-Animate-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-Animate-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-Animate-14B:Wan2.1_VAE.pth,Wan-AI/Wan2.2-Animate-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.animate_adapter." \
  --output_path "./models/train/Wan2.2-Animate-14B_full" \
  --trainable_models "animate_adapter" \
  --extra_inputs "input_image,animate_pose_video,animate_face_video" \
  --use_gradient_checkpointing_offload
@@ -1,40 +0,0 @@
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \
  --data_file_keys "video,vace_video,vace_reference_image" \
  --height 480 \
  --width 832 \
  --num_frames 17 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.vace." \
  --output_path "./models/train/Wan2.2-VACE-Fun-A14B_high_noise_full" \
  --trainable_models "vace" \
  --extra_inputs "vace_video,vace_reference_image" \
  --use_gradient_checkpointing_offload \
  --max_timestep_boundary 0.358 \
  --min_timestep_boundary 0
# boundary corresponds to timesteps [900, 1000]

accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \
  --data_file_keys "video,vace_video,vace_reference_image" \
  --height 480 \
  --width 832 \
  --num_frames 17 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.vace." \
  --output_path "./models/train/Wan2.2-VACE-Fun-A14B_low_noise_full" \
  --trainable_models "vace" \
  --extra_inputs "vace_video,vace_reference_image" \
  --use_gradient_checkpointing_offload \
  --max_timestep_boundary 1 \
  --min_timestep_boundary 0.358
# boundary corresponds to timesteps [0, 900]
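The two commands above train the high-noise and the low-noise expert separately by restricting which diffusion timesteps each one sees through `--min_timestep_boundary` / `--max_timestep_boundary`. A rough sketch of that mechanism, assuming a 1000-step flow-matching schedule with sigma shift 5 (under that assumption the fraction 0.358 lands near timestep 900, matching the comments above; DiffSynth's exact sampling code may differ):

import torch

# Build a shifted flow-matching schedule: timesteps run from ~1000 (pure noise) down to ~0 (clean).
num_steps, shift = 1000, 5.0
t_raw = torch.linspace(1, 1 / num_steps, num_steps)
sigmas = shift * t_raw / (1 + (shift - 1) * t_raw)
timesteps = sigmas * num_steps

def sample_timestep(min_boundary, max_boundary):
    # Pick a random index inside the fractional window [min_boundary, max_boundary).
    lo, hi = int(min_boundary * num_steps), int(max_boundary * num_steps)
    idx = torch.randint(lo, max(hi, lo + 1), (1,)).item()
    return float(timesteps[min(idx, num_steps - 1)])

print(sample_timestep(0, 0.358))   # high-noise expert: roughly timesteps [900, 1000]
print(sample_timestep(0.358, 1))   # low-noise expert: roughly timesteps [0, 900]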
@@ -1,20 +0,0 @@
# A single 80 GB GPU cannot train the Wan2.2-Animate-14B LoRA; we tested on 8x80 GB GPUs.
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata_animate.csv \
  --data_file_keys "video,animate_pose_video,animate_face_video" \
  --height 480 \
  --width 832 \
  --num_frames 81 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "Wan-AI/Wan2.2-Animate-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-Animate-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-Animate-14B:Wan2.1_VAE.pth,Wan-AI/Wan2.2-Animate-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.2-Animate-14B_lora" \
  --lora_base_model "dit" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --extra_inputs "input_image,animate_pose_video,animate_face_video" \
  --use_gradient_checkpointing_offload
@@ -1,43 +0,0 @@
accelerate launch examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \
  --data_file_keys "video,vace_video,vace_reference_image" \
  --height 480 \
  --width 832 \
  --num_frames 17 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.vace." \
  --output_path "./models/train/Wan2.2-VACE-Fun-A14B_high_noise_lora" \
  --lora_base_model "vace" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --extra_inputs "vace_video,vace_reference_image" \
  --use_gradient_checkpointing_offload \
  --max_timestep_boundary 0.358 \
  --min_timestep_boundary 0
# boundary corresponds to timesteps [900, 1000]

accelerate launch examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \
  --data_file_keys "video,vace_video,vace_reference_image" \
  --height 480 \
  --width 832 \
  --num_frames 17 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.vace." \
  --output_path "./models/train/Wan2.2-VACE-Fun-A14B_low_noise_lora" \
  --lora_base_model "vace" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --extra_inputs "vace_video,vace_reference_image" \
  --use_gradient_checkpointing_offload \
  --max_timestep_boundary 1 \
  --min_timestep_boundary 0.358
# boundary corresponds to timesteps [0, 900]
@@ -2,7 +2,7 @@ import torch, os, json
 from diffsynth import load_state_dict
 from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
 from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser
-from diffsynth.trainers.unified_dataset import UnifiedDataset, LoadVideo, ImageCropAndResize, ToAbsolutePath
+from diffsynth.trainers.unified_dataset import UnifiedDataset
 os.environ["TOKENIZERS_PARALLELISM"] = "false"


@@ -82,7 +82,7 @@ class WanTrainingModule(DiffusionTrainingModule):
         return {**inputs_shared, **inputs_posi}


-    def forward(self, data, inputs=None):
+    def forward(self, data, inputs=None, **kwargs):
         if inputs is None: inputs = self.forward_preprocess(data)
         models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
         loss = self.pipe.training_loss(**models, **inputs)
@@ -108,9 +108,6 @@ if __name__ == "__main__":
             time_division_factor=4,
             time_division_remainder=1,
         ),
-        special_operator_map={
-            "animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16))
-        }
     )
     model = WanTrainingModule(
         model_paths=args.model_paths,
@@ -1,33 +0,0 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig


pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
    ],
)
state_dict = load_state_dict("models/train/Wan2.2-Animate-14B_full/epoch-1.safetensors")
pipe.animate_adapter.load_state_dict(state_dict, strict=False)
pipe.enable_vram_management()

input_image = VideoData("data/example_video_dataset/animate/animate_output.mp4", height=480, width=832)[0]
animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4", height=480, width=832).raw_data()[:81-4]
animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4", height=512, width=512).raw_data()[:81-4]
video = pipe(
    prompt="视频中的人在做动作",
    seed=0, tiled=True,
    input_image=input_image,
    animate_pose_video=animate_pose_video,
    animate_face_video=animate_face_video,
    num_frames=81, height=480, width=832,
    num_inference_steps=20, cfg_scale=1,
)
save_video(video, "video_Wan2.2-Animate-14B.mp4", fps=15, quality=5)
@@ -1,33 +0,0 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig


pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
    ],
)
state_dict = load_state_dict("models/train/Wan2.2-VACE-Fun-A14B_high_noise_full/epoch-1.safetensors")
pipe.vace.load_state_dict(state_dict)
state_dict = load_state_dict("models/train/Wan2.2-VACE-Fun-A14B_low_noise_full/epoch-1.safetensors")
pipe.vace2.load_state_dict(state_dict)
pipe.enable_vram_management()

video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832)
video = [video[i] for i in range(17)]
reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]

video = pipe(
    prompt="from sunset to night, a small town, light, house, river",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    vace_video=video, vace_reference_image=reference_image, num_frames=17,
    seed=1, tiled=True
)
save_video(video, "video_Wan2.2-VACE-A14B.mp4", fps=15, quality=5)
@@ -1,32 +0,0 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig


pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
    ],
)
pipe.load_lora(pipe.dit, "models/train/Wan2.2-Animate-14B_lora/epoch-4.safetensors", alpha=1)
pipe.enable_vram_management()

input_image = VideoData("data/example_video_dataset/animate/animate_output.mp4", height=480, width=832)[0]
animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4", height=480, width=832).raw_data()[:81-4]
animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4", height=512, width=512).raw_data()[:81-4]
video = pipe(
    prompt="视频中的人在做动作",
    seed=0, tiled=True,
    input_image=input_image,
    animate_pose_video=animate_pose_video,
    animate_face_video=animate_face_video,
    num_frames=81, height=480, width=832,
    num_inference_steps=20, cfg_scale=1,
)
save_video(video, "video_Wan2.2-Animate-14B.mp4", fps=15, quality=5)
@@ -1,31 +0,0 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig


pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
    ],
)
pipe.load_lora(pipe.vace, "models/train/Wan2.2-VACE-Fun-A14B_high_noise_lora/epoch-4.safetensors", alpha=1)
pipe.load_lora(pipe.vace2, "models/train/Wan2.2-VACE-Fun-A14B_low_noise_lora/epoch-4.safetensors", alpha=1)
pipe.enable_vram_management()

video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832)
video = [video[i] for i in range(17)]
reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]

video = pipe(
    prompt="from sunset to night, a small town, light, house, river",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    vace_video=video, vace_reference_image=reference_image, num_frames=17,
    seed=1, tiled=True
)
save_video(video, "video_Wan2.2-VACE-Fun-A14B.mp4", fps=15, quality=5)