Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-19 06:48:12 +00:00)

Compare commits: qwen-image...dpo-refine (164 commits)
Commits (abbreviated SHA1):

6737dbfc9f, 0a1c172a00, 77fac2a03f, 084bc2fc78, c63d474b60, 7540568156, c5d426c254, a36f2f6032, ed256ef8be, 15079a6cb8,
c084d6377b, e9bc42f233, 0d6de58af9, acbf932974, 9d64ed7042, 0b4b337e9a, 99908d9a1c, 73ced7a46d, 32b8b9b51e, f6534a5b63,
034c9b6c60, 76335e0fe5, c0b589d934, 833ba1e1fa, 7a5974d964, b0abdaffb4, e9f29bc402, 1a7f482fbd, 3a0d51d100, bffdb901ed,
d93e8738cd, 7e5ce5d5c9, 7aef554d83, 090074e395, 2dcdeefca8, 452a6ca5cf, d6cf20ef33, efdd6a59b6, 42ec7b08eb, d049fb6d1d,
144365b07d, cb8de6be1b, 8c13362dcf, c13fd7e0ee, 958ebf1352, b6da77e468, 260e32217f, 5cee326f92, 1d240994e7, a0bae07825,
ff71720297, dea85643e6, 6a46f32afe, 4641d0f360, 826bab5962, 5b6d112c15, febdaf6067, 0a78bb9d38, 9cea10cc69, caa17da5b9,
fdeb363fa2, 4147473c81, 8a0bd7c377, b541b9bed2, 419d47c195, ac2e859960, 6663dca015, 86e509ad31, 8fcfa1dd2d, 2b7a2548b4,
f0916e6bae, 822e80ec2f, 04e39f7de5, ce0b948655, c795e35142, f7c01f1367, cb49f0283f, 6a45815b23, 8dae8d7bc8, f6418004bb,
c4b97cd591, b6d1ff01e0, 0d81626fe7, e3f47a799b, e014cad820, 89bf3ce5cf, 3ebe118f23, 7f719cefe6, 46bd05b54d, 613dafbd09,
952933eeb1, c0172e70b1, 6ab426e641, d0467a7e8d, 36838a05ee, 5e6f9f89f1, 2dad9a319c, 9ec0652339, 7e348083ae, 29b12b2f4e,
b3f57ed920, c9fea729d8, 9d0683df25, 838b8109b1, 3a9621f6da, fff2c89360, ce61bef2b0, 123f6dbadb, f9ce261a0e, d93de98a21,
ad1da43476, 398b1dbd7a, 9f6922bba9, f11a91e610, 7ed09bb78d, ac931856d5, 2d09318236, 7dc49bd036, 4d16bdf853, 01a1f48f70,
6a9d875d65, f1c96d31b4, aafcca8d77, bf369cad4d, 024fdad76d, e1c2eda5f5, 0b574cc0c2, 3212c83398, 49f9a11eb3, fa36739f01,
42e9764b60, f7f5c07570, ec1a936624, 6e6136586c, 34766863f8, 1d76d5e828, 250540a398, 46f3c38c37, 9a8982efb1, 3c815cce4b,
39d199c8bb, f5506d1e13, 166a8734fe, b2273ec568, 89c4e3bdb6, 051ebf3439, 7cfadc2ca8, 32cf5d32ce, 4f7c3b6a1e, 57128dc89f,
d20680baae, 970403f78e, bee2a969e5, 2803ffcb38, d3224e1fdc, 3c2f85606f, 1f25ad416b, d0b9b25db7, ef09db69cd, a3b67436a6,
3915bc3ee6, 4299c999b5, 6bae70eee0, 6452edb738
README.md: 57 changed lines
@@ -64,6 +64,7 @@ Details: [./examples/qwen_image/](./examples/qwen_image/)

```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from PIL import Image
import torch

pipe = QwenImagePipeline.from_pretrained(
@@ -77,7 +78,10 @@ pipe = QwenImagePipeline.from_pretrained(
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
prompt = "A detailed portrait of a girl underwater, wearing a blue flowing dress, hair gently floating, clear light and shadow, surrounded by bubbles, calm expression, fine details, dreamy and beautiful."
image = pipe(prompt, seed=0, num_inference_steps=40)
image = pipe(
    prompt, seed=0, num_inference_steps=40,
    # edit_image=Image.open("xxx.jpg").resize((1328, 1328)) # For Qwen-Image-Edit
)
image.save("image.jpg")
```

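These two hunks add a `PIL` import and restructure the example call so that the optional `edit_image` argument (shown commented out above) can be passed for Qwen-Image-Edit. A minimal sketch of an editing call with that argument enabled, assuming the pipeline is loaded as in the hunk above (with Qwen-Image-Edit weights substituted) and `input.jpg` is a hypothetical local file:

```python
from PIL import Image

def edit_with_reference(pipe, prompt, image_path="input.jpg", seed=0):
    # Same call as the quick start, with the commented-out argument enabled;
    # the 1328x1328 resize mirrors the size used in the README snippet.
    edit_image = Image.open(image_path).resize((1328, 1328))
    return pipe(prompt, seed=seed, num_inference_steps=40, edit_image=edit_image)
```
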
@@ -87,10 +91,21 @@ image.save("image.jpg")

<summary>Model Overview</summary>

|Model ID|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|

</details>

@@ -192,9 +207,15 @@ save_video(video, "video1.mp4", fps=15, quality=5)

| Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
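The "Extra Parameters" column lists the keyword arguments each model's pipeline call accepts in addition to the prompt. As a sketch of the call pattern only (the pipeline object is assumed to be constructed as in the linked example scripts, which is not shown here), an image-to-video call for a model that takes `input_image` could look like:

```python
from PIL import Image

def image_to_video(pipe, prompt, image_path, seed=0):
    # `input_image` is the extra parameter listed for image-to-video models such
    # as Wan2.2-I2V-A14B; other models take the parameters named in their row.
    return pipe(prompt=prompt, input_image=Image.open(image_path), seed=seed)
```

The returned frames can then be written out with `save_video(video, "video1.mp4", fps=15, quality=5)`, as shown in the hunk context above.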
@@ -364,6 +385,32 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44

## Update History

- **September 23, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) is released! This model is jointly developed and open-sourced by us and the Taobao Design Team. The model is built upon Qwen-Image, specifically designed for e-commerce poster scenarios, and supports precise partition layout control. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py).

- **September 9, 2025**: Our training framework now supports multiple training modes and has been adapted for Qwen-Image. In addition to the standard SFT training mode, Direct Distill is now also supported; please refer to [our example code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will continue to improve it to support comprehensive model training capabilities.

- **August 28, 2025** We support Wan2.2-S2V, an audio-driven cinematic video generation model open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).

- **August 21, 2025**: [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) is released! Compared to the V1 version, the training dataset has been updated to the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), enabling generated images to better align with the inherent image distribution and style of Qwen-Image. Please refer to [our sample code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py).

- **August 21, 2025**: We open-sourced the [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) structure control LoRA model. Following the "In Context" routine, it supports various types of structural control conditions, including canny, depth, lineart, softedge, normal, and openpose. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py).

- **August 20, 2025** We open-sourced [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix), which improves the editing performance of Qwen-Image-Edit on low-resolution image inputs. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py).

- **August 19, 2025** 🔥 Qwen-Image-Edit is now open source. Welcome the new member to the image editing model family!

- **August 18, 2025** We trained and open-sourced the Inpaint ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).

- **August 15, 2025** We open-sourced the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset). This is an image dataset generated using the Qwen-Image model, with a total of 160,000 `1024 x 1024` images. It includes the general, English text rendering, and Chinese text rendering subsets. We provide caption, entity and control images annotations for each image. Developers can use this dataset to train models such as ControlNet and EliGen for the Qwen-Image model. We aim to promote technological development through open-source contributions!

- **August 13, 2025** We trained and open-sourced the ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py).

- **August 12, 2025** We trained and open-sourced the ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py).

- **August 11, 2025** We released another distilled acceleration model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA). It uses the same training process as [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), but the model structure is changed to LoRA. This makes it work better with other open-source models.

- **August 7, 2025** We open-sourced the entity control LoRA of Qwen-Image, [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen). Qwen-Image-EliGen is able to achieve entity-level controlled text-to-image generation. See the [paper](https://arxiv.org/abs/2501.01097) for technical details. Training dataset: [EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet).

- **August 5, 2025** We open-sourced the distilled acceleration model of Qwen-Image, [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), achieving approximately 5x speedup.

- **August 4, 2025** 🔥 Qwen-Image is now open source. Welcome the new member to the image generation model family!

README_zh.md: 57 changed lines
@@ -66,6 +66,7 @@ DiffSynth-Studio 为主流 Diffusion 模型(包括 FLUX、Wan 等)重新设

```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from PIL import Image
import torch

pipe = QwenImagePipeline.from_pretrained(
@@ -79,7 +80,10 @@ pipe = QwenImagePipeline.from_pretrained(
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40)
image = pipe(
    prompt, seed=0, num_inference_steps=40,
    # edit_image=Image.open("xxx.jpg").resize((1328, 1328)) # For Qwen-Image-Edit
)
image.save("image.jpg")
```

@@ -89,10 +93,21 @@ image.save("image.jpg")

<summary>Model Overview</summary>

|Model ID|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|

</details>

@@ -192,9 +207,15 @@ save_video(video, "video1.mp4", fps=15, quality=5)

|Model ID|Extra Parameters|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
@@ -380,6 +401,32 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44

## Update History

- **September 23, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) is released! This model was jointly developed and open-sourced by us and the Taobao experience design team. Built on Qwen-Image, it is designed for e-commerce poster scenarios and supports precise partition layout control. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py).

- **September 9, 2025** Our training framework now supports multiple training modes and has been adapted for Qwen-Image. In addition to the standard SFT training mode, Direct Distill is supported; please refer to [our example code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will keep improving it to support more comprehensive model training capabilities.

- **August 28, 2025** We support Wan2.2-S2V, an audio-driven cinematic video generation model. See [./examples/wanvideo/](./examples/wanvideo/).

- **August 21, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) is released! Compared with V1, the training dataset has been switched to [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), so generated images better match the native image distribution and style of Qwen-Image. Please refer to [our example code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py).

- **August 21, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) structure control LoRA model. It follows the In Context approach and supports multiple kinds of structural control conditions, including canny, depth, lineart, softedge, normal, and openpose. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py).

- **August 20, 2025** We open-sourced [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix), which improves the editing quality of Qwen-Image-Edit on low-resolution image inputs. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py).

- **August 19, 2025** 🔥 Qwen-Image-Edit is open source. Welcome the new member of the image editing model family!

- **August 18, 2025** We trained and open-sourced the inpainting ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint), which uses a lightweight architecture. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).

- **August 15, 2025** We open-sourced the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), an image dataset generated with the Qwen-Image model containing 160,000 `1024 x 1024` images across general, English text rendering, and Chinese text rendering subsets. Each image comes with caption, entity, and structural-control annotations. Developers can use this dataset to train models such as ControlNet and EliGen for Qwen-Image; we aim to advance the technology through open source!

- **August 13, 2025** We trained and open-sourced the ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth), which uses a lightweight architecture. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py).

- **August 12, 2025** We trained and open-sourced the ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny), which uses a lightweight architecture. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py).

- **August 11, 2025** We open-sourced the distilled acceleration model of Qwen-Image, [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA). It uses the same training pipeline as [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), but the model structure is changed to LoRA, so it is more compatible with other models in the open-source ecosystem.

- **August 7, 2025** We open-sourced the entity control LoRA model of Qwen-Image, [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen). Qwen-Image-EliGen enables entity-level controllable text-to-image generation. See the [paper](https://arxiv.org/abs/2501.01097) for technical details. Training dataset: [EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet).

- **August 5, 2025** We open-sourced the distilled acceleration model of Qwen-Image, [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), achieving roughly 5x speedup.

- **August 4, 2025** 🔥 Qwen-Image is open source. Welcome the new member of the image generation model family!

apps/gradio/qwen_image_eligen.py: 382 added lines (new file)
@@ -0,0 +1,382 @@
import os
import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import random
import json
import gradio as gr
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import dataset_snapshot_download, snapshot_download

# pip install pydantic==2.10.6
# pip install gradio==5.4.0

snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen", allow_file_pattern="model.safetensors")

dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/*")
example_json = 'data/examples/eligen/qwen-image/ui_examples.json'
with open(example_json, 'r') as f:
    examples = json.load(f)['examples']

# Attach the pre-drawn entity masks shipped with each example.
for idx in range(len(examples)):
    example_id = examples[idx]['example_id']
    entity_prompts = examples[idx]['local_prompt_list']
    examples[idx]['mask_lists'] = [Image.open(f"data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]

def create_canvas_data(background, masks):
    # Ensure the background is RGBA: the ImageEditor canvases are 4-channel.
    if background.shape[-1] == 3:
        background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)])
    layers = []
    for mask in masks:
        if mask is not None:
            mask_single_channel = mask if mask.ndim == 2 else mask[..., 0]
            layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), dtype=np.uint8)
            layer[..., -1] = mask_single_channel
            layers.append(layer)
        else:
            layers.append(np.zeros_like(background))

    composite = background.copy()
    for layer in layers:
        if layer.size > 0:
            composite = np.where(layer[..., -1:] > 0, layer, composite)
    return {
        "background": background,
        "layers": layers,
        "composite": composite,
    }
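# Note: the dict built by create_canvas_data follows gr.ImageEditor's value
# format: an RGBA "background" array, a list of RGBA "layers" whose alpha
# channel carries each entity mask, and a "composite" with the layers pasted
# over the background.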
def load_example(load_example_button):
    # The button label ends with the 1-based example index, e.g. "Load Example 3".
    example_idx = int(load_example_button.split()[-1]) - 1
    example = examples[example_idx]
    result = [
        50,  # number of inference steps used for the example presets
        example["global_prompt"],
        example["negative_prompt"],
        example["seed"],
        *example["local_prompt_list"],
    ]
    num_entities = len(example["local_prompt_list"])
    result += [""] * (config["max_num_painter_layers"] - num_entities)
    masks = []
    for mask in example["mask_lists"]:
        mask_single_channel = np.array(mask.convert("L"))
        masks.append(mask_single_channel)
    for _ in range(config["max_num_painter_layers"] - len(masks)):
        blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
        masks.append(blank_mask)
    background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255
    canvas_data_list = []
    for mask in masks:
        canvas_data = create_canvas_data(background, [mask])
        canvas_data_list.append(canvas_data)
    result.extend(canvas_data_list)
    return result

def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
    save_dir = os.path.join('workdirs/tmp_mask', random_dir)
    print(f'save to {save_dir}')
    os.makedirs(save_dir, exist_ok=True)
    for i, mask in enumerate(masks):
        save_path = os.path.join(save_dir, f'{i}.png')
        mask.save(save_path)
    sample = {
        "global_prompt": global_prompt,
        "mask_prompts": mask_prompts,
        "seed": seed,
    }
    with open(os.path.join(save_dir, f"prompts.json"), 'w', encoding='utf-8') as f:
        json.dump(sample, f, ensure_ascii=False, indent=4)

def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
    # Create a blank image for overlays
    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
    colors = [
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
    ]
    # Generate random colors for each mask
    if use_random_colors:
        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
    # Font settings
    try:
        font = ImageFont.truetype("wqy-zenhei.ttc", font_size)  # Adjust as needed
    except IOError:
        font = ImageFont.load_default(font_size)
    # Overlay each mask onto the overlay image
    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
        if mask is None:
            continue
        # Convert mask to RGBA mode
        mask_rgba = mask.convert('RGBA')
        mask_data = mask_rgba.getdata()
        new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
        mask_rgba.putdata(new_data)
        # Draw the mask prompt text on the mask
        draw = ImageDraw.Draw(mask_rgba)
        mask_bbox = mask.getbbox()  # Get the bounding box of the mask
        if mask_bbox is None:
            continue
        text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)  # Adjust text position based on mask position
        draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
        # Alpha composite the overlay with this mask
        overlay = Image.alpha_composite(overlay, mask_rgba)
    # Composite the overlay onto the original image
    result = Image.alpha_composite(image.convert('RGBA'), overlay)
    return result
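# visualize_masks expects white-on-black masks: white pixels are recolored to a
# translucent overlay color, the entity prompt is drawn near the mask's bounding
# box, and all overlays are alpha-composited onto the generated image.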
config = {
|
||||
"max_num_painter_layers": 8,
|
||||
"max_num_model_cache": 1,
|
||||
}
|
||||
|
||||
model_dict = {}
|
||||
|
||||
def load_model(model_type='qwen-image'):
|
||||
global model_dict
|
||||
model_key = f"{model_type}"
|
||||
if model_key in model_dict:
|
||||
return model_dict[model_key]
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen/model.safetensors")
|
||||
model_dict[model_key] = pipe
|
||||
return pipe
|
||||
|
||||
load_model('qwen-image')
|
||||
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown(
|
||||
"""## EliGen: Entity-Level Controllable Text-to-Image Model
|
||||
1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
|
||||
2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
|
||||
3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
|
||||
4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
|
||||
"""
|
||||
)
|
||||
|
||||
loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
|
||||
main_interface = gr.Column(visible=False)
|
||||
|
||||
def initialize_model():
|
||||
try:
|
||||
load_model('qwen-image')
|
||||
return {
|
||||
loading_status: gr.update(value="Model loaded successfully!", visible=False),
|
||||
main_interface: gr.update(visible=True),
|
||||
}
|
||||
except Exception as e:
|
||||
print(f'Failed to load model with error: {e}')
|
||||
return {
|
||||
loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
|
||||
main_interface: gr.update(visible=True),
|
||||
}
|
||||
|
||||
app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])
|
||||
|
||||
with main_interface:
|
||||
with gr.Row():
|
||||
local_prompt_list = []
|
||||
canvas_list = []
|
||||
random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
|
||||
with gr.Column(scale=382, min_width=100):
|
||||
model_type = gr.State('qwen-image')
|
||||
with gr.Accordion(label="Global prompt"):
|
||||
prompt = gr.Textbox(label="Global Prompt", lines=3)
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", value="", lines=3)
|
||||
with gr.Accordion(label="Inference Options", open=True):
|
||||
seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
|
||||
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
|
||||
cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=4.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
|
||||
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
|
||||
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
|
||||
with gr.Accordion(label="Inpaint Input Image", open=False, visible=False):
|
||||
input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
|
||||
background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)
|
||||
|
||||
with gr.Column():
|
||||
reset_input_button = gr.Button(value="Reset Inpaint Input")
|
||||
send_input_to_painter = gr.Button(value="Set as painter's background")
|
||||
@gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
|
||||
def reset_input_image(input_image):
|
||||
return None
|
||||
|
||||
with gr.Column(scale=618, min_width=100):
|
||||
with gr.Accordion(label="Entity Painter"):
|
||||
for painter_layer_id in range(config["max_num_painter_layers"]):
|
||||
with gr.Tab(label=f"Entity {painter_layer_id}"):
|
||||
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
|
||||
canvas = gr.ImageEditor(
|
||||
canvas_size=(1024, 1024),
|
||||
sources=None,
|
||||
layers=False,
|
||||
interactive=True,
|
||||
image_mode="RGBA",
|
||||
brush=gr.Brush(
|
||||
default_size=50,
|
||||
default_color="#000000",
|
||||
colors=["#000000"],
|
||||
),
|
||||
label="Entity Mask Painter",
|
||||
key=f"canvas_{painter_layer_id}",
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
|
||||
def resize_canvas(height, width, canvas):
|
||||
if canvas is None or canvas["background"] is None:
|
||||
return np.ones((height, width, 3), dtype=np.uint8) * 255
|
||||
h, w = canvas["background"].shape[:2]
|
||||
if h != height or width != w:
|
||||
return np.ones((height, width, 3), dtype=np.uint8) * 255
|
||||
else:
|
||||
return canvas
|
||||
local_prompt_list.append(local_prompt)
|
||||
canvas_list.append(canvas)
|
||||
with gr.Accordion(label="Results"):
|
||||
run_button = gr.Button(value="Generate", variant="primary")
|
||||
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
output_to_painter_button = gr.Button(value="Set as painter's background")
|
||||
with gr.Column():
|
||||
return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
|
||||
output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
|
||||
real_output = gr.State(None)
|
||||
mask_out = gr.State(None)
|
||||
|
||||
@gr.on(
|
||||
inputs=[model_type, prompt, negative_prompt, cfg_scale, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
|
||||
outputs=[output_image, real_output, mask_out],
|
||||
triggers=run_button.click
|
||||
)
|
||||
def generate_image(model_type, prompt, negative_prompt, cfg_scale, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
|
||||
pipe = load_model(model_type)
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"cfg_scale": cfg_scale,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"progress_bar_cmd": progress.tqdm,
|
||||
}
|
||||
# if input_image is not None:
|
||||
# input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
|
||||
# input_params["enable_eligen_inpaint"] = True
|
||||
|
||||
local_prompt_list, canvas_list = (
|
||||
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
|
||||
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
|
||||
)
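# Note: *args flattens all painter widgets in order: the first config["max_num_painter_layers"] entries are the per-entity local prompts, the next config["max_num_painter_layers"] are the matching mask canvases (unpacked above).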
|
||||
local_prompts, masks = [], []
|
||||
for local_prompt, canvas in zip(local_prompt_list, canvas_list):
|
||||
if isinstance(local_prompt, str) and len(local_prompt) > 0:
|
||||
local_prompts.append(local_prompt)
|
||||
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
|
||||
entity_prompts = None if len(local_prompts) == 0 else local_prompts
|
||||
entity_masks = None if len(masks) == 0 or entity_prompts is None else masks
|
||||
input_params.update({
|
||||
"eligen_entity_prompts": entity_prompts,
|
||||
"eligen_entity_masks": entity_masks,
|
||||
})
|
||||
torch.manual_seed(seed)
|
||||
save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir)
|
||||
image = pipe(**input_params)
|
||||
masks = [mask.resize(image.size) for mask in masks]
|
||||
image_with_mask = visualize_masks(image, masks, local_prompts)
|
||||
|
||||
real_output = gr.State(image)
|
||||
mask_out = gr.State(image_with_mask)
|
||||
|
||||
if return_with_mask:
|
||||
return image_with_mask, real_output, mask_out
|
||||
return image, real_output, mask_out
|
||||
|
||||
@gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
|
||||
def send_input_to_painter_background(input_image, *canvas_list):
|
||||
if input_image is None:
|
||||
return tuple(canvas_list)
|
||||
for canvas in canvas_list:
|
||||
h, w = canvas["background"].shape[:2]
|
||||
canvas["background"] = input_image.resize((w, h))
|
||||
return tuple(canvas_list)
|
||||
@gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
|
||||
def send_output_to_painter_background(real_output, *canvas_list):
|
||||
if real_output is None:
|
||||
return tuple(canvas_list)
|
||||
for canvas in canvas_list:
|
||||
h, w = canvas["background"].shape[:2]
|
||||
canvas["background"] = real_output.value.resize((w, h))
|
||||
return tuple(canvas_list)
|
||||
@gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
|
||||
def show_output(return_with_mask, real_output, mask_out):
|
||||
if return_with_mask:
|
||||
return mask_out.value
|
||||
else:
|
||||
return real_output.value
|
||||
@gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
|
||||
def send_output_to_pipe_input(real_output):
|
||||
return real_output.value
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("## Examples")
|
||||
for i in range(0, len(examples), 2):
|
||||
with gr.Row():
|
||||
if i < len(examples):
|
||||
example = examples[i]
|
||||
with gr.Column():
|
||||
example_image = gr.Image(
|
||||
value=f"data/examples/eligen/qwen-image/example_{example['example_id']}/example_image.png",
|
||||
label=example["description"],
|
||||
interactive=False,
|
||||
width=1024,
|
||||
height=512
|
||||
)
|
||||
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
|
||||
load_example_button.click(
|
||||
load_example,
|
||||
inputs=[load_example_button],
|
||||
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
|
||||
)
|
||||
|
||||
if i + 1 < len(examples):
|
||||
example = examples[i + 1]
|
||||
with gr.Column():
|
||||
example_image = gr.Image(
|
||||
value=f"data/examples/eligen/qwen-image/example_{example['example_id']}/example_image.png",
|
||||
label=example["description"],
|
||||
interactive=False,
|
||||
width=1024,
|
||||
height=512
|
||||
)
|
||||
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
|
||||
load_example_button.click(
|
||||
load_example,
|
||||
inputs=[load_example_button],
|
||||
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
|
||||
)
|
||||
app.config["show_progress"] = "hidden"
|
||||
app.launch(share=False)
|
||||
@@ -56,11 +56,14 @@ from ..models.stepvideo_vae import StepVideoVAE
|
||||
from ..models.stepvideo_dit import StepVideoModel
|
||||
|
||||
from ..models.wan_video_dit import WanModel
|
||||
from ..models.wan_video_dit_s2v import WanS2VModel
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
from ..models.wav2vec import WanS2VAudioEncoder
|
||||
from ..models.wan_video_animate_adapter import WanAnimateAdapter
|
||||
|
||||
from ..models.step1x_connector import Qwen2Connector
|
||||
|
||||
@@ -75,6 +78,7 @@ from ..models.nexus_gen import NexusGenAutoregressiveModel
|
||||
from ..models.qwen_image_dit import QwenImageDiT
|
||||
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from ..models.qwen_image_vae import QwenImageVAE
|
||||
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||
|
||||
model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -139,7 +143,6 @@ model_loader_configs = [
|
||||
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
@@ -149,9 +152,12 @@ model_loader_configs = [
|
||||
(None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "2267d489f0ceb9f21836532952852ee5", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "47dbeab5e560db3180adf51dc0232fb1", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||
(None, "966cffdcc52f9c46c391768b27637614", ["wan_video_dit"], [WanS2VModel], "civitai"),
|
||||
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||
@@ -167,6 +173,10 @@ model_loader_configs = [
|
||||
(None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"),
|
||||
(None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
|
||||
(None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
|
||||
(None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
|
||||
(None, "a9e54e480a628f0b956a688a81c33bab", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
|
||||
(None, "06be60f3a4526586d8431cd038a71486", ["wans2v_audio_encoder"], [WanS2VAudioEncoder], "civitai"),
|
||||
(None, "31fa352acb8a1b1d33cd8764273d80a2", ["wan_video_dit", "wan_video_animate_adapter"], [WanModel, WanAnimateAdapter], "civitai"),
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
|
||||
@@ -1 +1 @@
|
||||
from .video import VideoData, save_video, save_frames
|
||||
from .video import VideoData, save_video, save_frames, merge_video_audio, save_video_with_audio
|
||||
|
||||
@@ -2,6 +2,8 @@ import imageio, os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import subprocess
|
||||
import shutil
|
||||
|
||||
|
||||
class LowMemoryVideo:
|
||||
@@ -146,3 +148,70 @@ def save_frames(frames, save_path):
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
for i, frame in enumerate(tqdm(frames, desc="Saving images")):
|
||||
frame.save(os.path.join(save_path, f"{i}.png"))
|
||||
|
||||
|
||||
def merge_video_audio(video_path: str, audio_path: str):
|
||||
# TODO: may need an in-Python implementation to avoid the subprocess dependency
|
||||
"""
|
||||
Merge the video and audio into a new video, with the duration set to the shorter of the two,
|
||||
and overwrite the original video file.
|
||||
|
||||
Parameters:
|
||||
video_path (str): Path to the original video file
|
||||
audio_path (str): Path to the audio file
|
||||
"""
|
||||
|
||||
# check
|
||||
if not os.path.exists(video_path):
|
||||
raise FileNotFoundError(f"video file {video_path} does not exist")
|
||||
if not os.path.exists(audio_path):
|
||||
raise FileNotFoundError(f"audio file {audio_path} does not exist")
|
||||
|
||||
base, ext = os.path.splitext(video_path)
|
||||
temp_output = f"{base}_temp{ext}"
|
||||
|
||||
try:
|
||||
# create ffmpeg command
|
||||
command = [
|
||||
'ffmpeg',
|
||||
'-y', # overwrite
|
||||
'-i',
|
||||
video_path,
|
||||
'-i',
|
||||
audio_path,
|
||||
'-c:v',
|
||||
'copy', # copy video stream
|
||||
'-c:a',
|
||||
'aac', # use AAC audio encoder
|
||||
'-b:a',
|
||||
'192k', # set audio bitrate (optional)
|
||||
'-map',
|
||||
'0:v:0', # select the first video stream
|
||||
'-map',
|
||||
'1:a:0', # select the first audio stream
|
||||
'-shortest', # choose the shortest duration
|
||||
temp_output
|
||||
]
|
||||
|
||||
# execute the command
|
||||
result = subprocess.run(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
|
||||
# check result
|
||||
if result.returncode != 0:
|
||||
error_msg = f"FFmpeg execute failed: {result.stderr}"
|
||||
print(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
shutil.move(temp_output, video_path)
|
||||
print(f"Merge completed, saved to {video_path}")
|
||||
|
||||
except Exception as e:
|
||||
if os.path.exists(temp_output):
|
||||
os.remove(temp_output)
|
||||
print(f"merge_video_audio failed with error: {e}")
|
||||
|
||||
|
||||
def save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None):
|
||||
save_video(frames, save_path, fps, quality, ffmpeg_params)
|
||||
merge_video_audio(save_path, audio_path)
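# Illustrative usage (file names are placeholders and ffmpeg must be on PATH; not part of the diff):
#     frames = [Image.new("RGB", (512, 512), "white") for _ in range(32)]
#     save_video_with_audio(frames, "demo.mp4", "narration.wav", fps=16)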
|
||||
|
||||
@@ -2,7 +2,8 @@ from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_e
|
||||
import numpy as np
|
||||
import cupy as cp
|
||||
import cv2
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class PatchMatcher:
|
||||
def __init__(
|
||||
@@ -233,13 +234,11 @@ class PyramidPatchMatcher:
|
||||
|
||||
def resample_image(self, images, level):
|
||||
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
|
||||
images = images.get()
|
||||
images_resample = []
|
||||
for image in images:
|
||||
image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
|
||||
images_resample.append(image_resample)
|
||||
images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
|
||||
return images_resample
|
||||
images_torch = torch.as_tensor(images, device='cuda', dtype=torch.float32)
|
||||
images_torch = images_torch.permute(0, 3, 1, 2)
|
||||
images_resample = F.interpolate(images_torch, size=(height, width), mode='area', align_corners=None)
|
||||
images_resample = images_resample.permute(0, 2, 3, 1).contiguous()
|
||||
return cp.asarray(images_resample)
|
||||
|
||||
def initialize_nnf(self, batch_size):
|
||||
if self.initialize == "random":
|
||||
@@ -262,14 +261,16 @@ class PyramidPatchMatcher:
|
||||
def update_nnf(self, nnf, level):
|
||||
# upscale
|
||||
nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
|
||||
nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
|
||||
nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
|
||||
nnf[:, 1::2, :, 0] += 1
|
||||
nnf[:, :, 1::2, 1] += 1
|
||||
# check if scale is 2
|
||||
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
|
||||
if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
|
||||
nnf = nnf.get().astype(np.float32)
|
||||
nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
|
||||
nnf = cp.array(np.stack(nnf), dtype=cp.int32)
|
||||
nnf_torch = torch.as_tensor(nnf, device='cuda', dtype=torch.float32)
|
||||
nnf_torch = nnf_torch.permute(0, 3, 1, 2)
|
||||
nnf_resized = F.interpolate(nnf_torch, size=(height, width), mode='bilinear', align_corners=False)
|
||||
nnf_resized = nnf_resized.permute(0, 2, 3, 1)
|
||||
nnf = cp.asarray(nnf_resized).astype(cp.int32)
|
||||
nnf = self.patch_matchers[level].clamp_bound(nnf)
|
||||
return nnf
|
||||
|
||||
|
||||
@@ -375,8 +375,7 @@ class FluxDiT(torch.nn.Module):
|
||||
return attention_mask
|
||||
|
||||
|
||||
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
|
||||
repeat_dim = hidden_states.shape[1]
|
||||
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
|
||||
max_masks = 0
|
||||
attention_mask = None
|
||||
prompt_embs = [prompt_emb]
|
||||
|
||||
diffsynth/models/qwen_image_controlnet.py (new file, 74 lines)
@@ -0,0 +1,74 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .sd3_dit import RMSNorm
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
|
||||
class BlockWiseControlBlock(torch.nn.Module):
|
||||
# [linear, gelu, linear]
|
||||
def __init__(self, dim: int = 3072):
|
||||
super().__init__()
|
||||
self.x_rms = RMSNorm(dim, eps=1e-6)
|
||||
self.y_rms = RMSNorm(dim, eps=1e-6)
|
||||
self.input_proj = nn.Linear(dim, dim)
|
||||
self.act = nn.GELU()
|
||||
self.output_proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x, y):
|
||||
x, y = self.x_rms(x), self.y_rms(y)
|
||||
x = self.input_proj(x + y)
|
||||
x = self.act(x)
|
||||
x = self.output_proj(x)
|
||||
return x
|
||||
|
||||
def init_weights(self):
|
||||
# zero initialize output_proj
|
||||
nn.init.zeros_(self.output_proj.weight)
|
||||
nn.init.zeros_(self.output_proj.bias)
|
||||
|
||||
|
||||
class QwenImageBlockWiseControlNet(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 60,
|
||||
in_dim: int = 64,
|
||||
additional_in_dim: int = 0,
|
||||
dim: int = 3072,
|
||||
):
|
||||
super().__init__()
|
||||
self.img_in = nn.Linear(in_dim + additional_in_dim, dim)
|
||||
self.controlnet_blocks = nn.ModuleList(
|
||||
[
|
||||
BlockWiseControlBlock(dim)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def init_weight(self):
|
||||
nn.init.zeros_(self.img_in.weight)
|
||||
nn.init.zeros_(self.img_in.bias)
|
||||
for block in self.controlnet_blocks:
|
||||
block.init_weights()
|
||||
|
||||
def process_controlnet_conditioning(self, controlnet_conditioning):
|
||||
return self.img_in(controlnet_conditioning)
|
||||
|
||||
def blockwise_forward(self, img, controlnet_conditioning, block_id):
|
||||
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return QwenImageBlockWiseControlNetStateDictConverter()
|
||||
|
||||
|
||||
class QwenImageBlockWiseControlNetStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
hash_value = hash_state_dict_keys(state_dict)
|
||||
extra_kwargs = {}
|
||||
if hash_value == "a9e54e480a628f0b956a688a81c33bab":
|
||||
# inpaint controlnet
|
||||
extra_kwargs = {"additional_in_dim": 4}
|
||||
return state_dict, extra_kwargs
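# Sketch of the intended block-wise wiring inside the DiT forward pass (an assumption for illustration only; the real integration lives in the DiT/pipeline code):
#     cond = controlnet.process_controlnet_conditioning(controlnet_conditioning)
#     for block_id, block in enumerate(transformer_blocks):
#         image, text = block(image=image, text=text, temb=temb, image_rotary_emb=image_rotary_emb)
#         image = image + controlnet.blockwise_forward(image, cond, block_id)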
|
||||
@@ -1,10 +1,44 @@
|
||||
import torch
|
||||
import torch, math
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, Optional, Union, List
|
||||
from einops import rearrange
|
||||
from .sd3_dit import TimestepEmbeddings, RMSNorm
|
||||
from .flux_dit import AdaLayerNorm
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
FLASH_ATTN_3_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
FLASH_ATTN_3_AVAILABLE = False
|
||||
|
||||
|
||||
def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask = None, enable_fp8_attention: bool = False):
|
||||
if FLASH_ATTN_3_AVAILABLE and attention_mask is None:
|
||||
if not enable_fp8_attention:
|
||||
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
|
||||
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
|
||||
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
|
||||
x = flash_attn_interface.flash_attn_func(q, k, v)
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
||||
else:
|
||||
origin_dtype = q.dtype
|
||||
q_std, k_std, v_std = q.std(), k.std(), v.std()
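# Rescaling q/k/v by their std keeps values inside fp8's narrow dynamic range; the scales are folded back in through softmax_scale and the final multiplication by v_std.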
|
||||
q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)
|
||||
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
|
||||
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
|
||||
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
|
||||
x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
x = x.to(origin_dtype) * v_std
|
||||
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
||||
else:
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
|
||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
return x
|
||||
|
||||
|
||||
class ApproximateGELU(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
|
||||
@@ -29,8 +63,8 @@ class QwenEmbedRope(nn.Module):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
pos_index = torch.arange(1024)
|
||||
neg_index = torch.arange(1024).flip(0) * -1 - 1
|
||||
pos_index = torch.arange(4096)
|
||||
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
||||
self.pos_freqs = torch.cat([
|
||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||
@@ -56,55 +90,139 @@ class QwenEmbedRope(nn.Module):
|
||||
)
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
def forward(self, video_fhw, txt_seq_lens, device):
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
|
||||
|
||||
def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
frame, height, width = video_fhw
|
||||
rope_key = f"{frame}_{height}_{width}"
|
||||
|
||||
if rope_key not in self.rope_cache:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat(
|
||||
[
|
||||
freqs_neg[1][-(height - height//2):],
|
||||
freqs_pos[1][:height//2]
|
||||
],
|
||||
dim=0
|
||||
)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat(
|
||||
[
|
||||
freqs_neg[2][-(width - width//2):],
|
||||
freqs_pos[2][:width//2]
|
||||
],
|
||||
dim=0
|
||||
)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
self.rope_cache[rope_key] = freqs.clone().contiguous()
|
||||
vid_freqs = self.rope_cache[rope_key]
|
||||
|
||||
video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3))
|
||||
_, height, width = video_fhw
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2)
|
||||
else:
|
||||
max_vid_index = max(height, width)
|
||||
required_len = max_vid_index + max(txt_seq_lens)
|
||||
cur_max_len = self.pos_freqs.shape[0]
|
||||
if required_len <= cur_max_len:
|
||||
return
|
||||
|
||||
new_max_len = math.ceil(required_len / 512) * 512
|
||||
pos_index = torch.arange(new_max_len)
|
||||
neg_index = torch.arange(new_max_len).flip(0) * -1 - 1
|
||||
self.pos_freqs = torch.cat([
|
||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||
], dim=1)
|
||||
self.neg_freqs = torch.cat([
|
||||
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||
], dim=1)
|
||||
return
|
||||
|
||||
|
||||
def forward(self, video_fhw, txt_seq_lens, device):
|
||||
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
|
||||
vid_freqs = []
|
||||
max_vid_index = 0
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
rope_key = f"{idx}_{height}_{width}"
|
||||
|
||||
if rope_key not in self.rope_cache:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat(
|
||||
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
|
||||
)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
self.rope_cache[rope_key] = freqs.clone().contiguous()
|
||||
vid_freqs.append(self.rope_cache[rope_key])
|
||||
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
||||
else:
|
||||
max_vid_index = max(height, width, max_vid_index)
|
||||
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index: max_vid_index + max_len, ...]
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
|
||||
def forward_sampling(self, video_fhw, txt_seq_lens, device):
|
||||
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
|
||||
vid_freqs = []
|
||||
max_vid_index = 0
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
rope_key = f"{idx}_{height}_{width}"
|
||||
if idx > 0 and f"{0}_{height}_{width}" not in self.rope_cache:
|
||||
frame_0, height_0, width_0 = video_fhw[0]
|
||||
|
||||
rope_key_0 = f"0_{height_0}_{width_0}"
|
||||
spatial_freqs_0 = self.rope_cache[rope_key_0].reshape(frame_0, height_0, width_0, -1)
|
||||
h_indices = torch.linspace(0, height_0 - 1, height).long()
|
||||
w_indices = torch.linspace(0, width_0 - 1, width).long()
|
||||
h_grid, w_grid = torch.meshgrid(h_indices, w_indices, indexing='ij')
|
||||
sampled_rope = spatial_freqs_0[:, h_grid, w_grid, :]
|
||||
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
sampled_rope[:, :, :, :freqs_frame.shape[-1]] = freqs_frame
|
||||
|
||||
seq_lens = frame * height * width
|
||||
self.rope_cache[rope_key] = sampled_rope.reshape(seq_lens, -1).clone()
|
||||
if rope_key not in self.rope_cache:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat(
|
||||
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
|
||||
)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
self.rope_cache[rope_key] = freqs.clone()
|
||||
vid_freqs.append(self.rope_cache[rope_key].contiguous())
|
||||
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
||||
else:
|
||||
max_vid_index = max(height, width, max_vid_index)
|
||||
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
|
||||
@@ -158,7 +276,9 @@ class QwenDoubleStreamAttention(nn.Module):
|
||||
self,
|
||||
image: torch.FloatTensor,
|
||||
text: torch.FloatTensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
enable_fp8_attention: bool = False,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
|
||||
txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
|
||||
@@ -186,9 +306,7 @@ class QwenDoubleStreamAttention(nn.Module):
|
||||
joint_k = torch.cat([txt_k, img_k], dim=2)
|
||||
joint_v = torch.cat([txt_v, img_v], dim=2)
|
||||
|
||||
joint_attn_out = torch.nn.functional.scaled_dot_product_attention(joint_q, joint_k, joint_v)
|
||||
|
||||
joint_attn_out = rearrange(joint_attn_out, 'b h s d -> b s (h d)').to(joint_q.dtype)
|
||||
joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
|
||||
|
||||
txt_attn_output = joint_attn_out[:, :seq_txt, :]
|
||||
img_attn_output = joint_attn_out[:, seq_txt:, :]
|
||||
@@ -245,6 +363,8 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
text: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
enable_fp8_attention = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
|
||||
@@ -260,6 +380,8 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
image=img_modulated,
|
||||
text=txt_modulated,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
enable_fp8_attention=enable_fp8_attention,
|
||||
)
|
||||
|
||||
image = image + img_gate * img_attn_out
|
||||
@@ -309,6 +431,74 @@ class QwenImageDiT(torch.nn.Module):
|
||||
self.proj_out = nn.Linear(3072, 64)
|
||||
|
||||
|
||||
def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes):
|
||||
# prompt_emb
|
||||
all_prompt_emb = entity_prompt_emb + [prompt_emb]
|
||||
all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
|
||||
all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
|
||||
|
||||
# image_rotary_emb
|
||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
entity_seq_lens = [emb_mask.sum(dim=1).tolist() for emb_mask in entity_prompt_emb_mask]
|
||||
entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens]
|
||||
txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
|
||||
image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
|
||||
|
||||
# attention_mask
|
||||
repeat_dim = latents.shape[1]
|
||||
max_masks = entity_masks.shape[1]
|
||||
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
|
||||
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||
global_mask = torch.ones_like(entity_masks[0]).to(device=latents.device, dtype=latents.dtype)
|
||||
entity_masks = entity_masks + [global_mask]
|
||||
|
||||
N = len(entity_masks)
|
||||
batch_size = entity_masks[0].shape[0]
|
||||
seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]
|
||||
total_seq_len = sum(seq_lens) + image.shape[1]
|
||||
patched_masks = []
|
||||
for i in range(N):
|
||||
patched_mask = rearrange(entity_masks[i], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
patched_masks.append(patched_mask)
|
||||
attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
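# Layout of the joint sequence: [entity prompt 1 | ... | entity prompt N-1 | global prompt | image tokens]; the mask starts fully visible and the disallowed regions are carved out below.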
|
||||
|
||||
# prompt-image attention mask
|
||||
image_start = sum(seq_lens)
|
||||
image_end = total_seq_len
|
||||
cumsum = [0]
|
||||
single_image_seq = image_end - image_start
|
||||
for length in seq_lens:
|
||||
cumsum.append(cumsum[-1] + length)
|
||||
for i in range(N):
|
||||
prompt_start = cumsum[i]
|
||||
prompt_end = cumsum[i+1]
|
||||
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
|
||||
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
|
||||
# repeat image mask to match the single image sequence length
|
||||
repeat_time = single_image_seq // image_mask.shape[-1]
|
||||
image_mask = image_mask.repeat(1, 1, repeat_time)
|
||||
# prompt update with image
|
||||
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
|
||||
# image update with prompt
|
||||
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
|
||||
# prompt-prompt attention mask: different prompt segments (entity prompts and the global prompt) must not attend to each other
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
if i == j:
|
||||
continue
|
||||
start_i, end_i = cumsum[i], cumsum[i+1]
|
||||
start_j, end_j = cumsum[j], cumsum[j+1]
|
||||
attention_mask[:, start_i:end_i, start_j:end_j] = False
|
||||
|
||||
attention_mask = attention_mask.float()
|
||||
attention_mask[attention_mask == 0] = float('-inf')
|
||||
attention_mask[attention_mask == 1] = 0
|
||||
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
|
||||
|
||||
return all_prompt_emb, image_rotary_emb, attention_mask
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents=None,
|
||||
@@ -321,7 +511,7 @@ class QwenImageDiT(torch.nn.Module):
|
||||
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
|
||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
||||
|
||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (P Q C)", H=height//16, W=width//16, P=2, Q=2)
|
||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
image = self.img_in(image)
|
||||
text = self.txt_in(self.txt_norm(prompt_emb))
|
||||
|
||||
@@ -340,7 +530,7 @@ class QwenImageDiT(torch.nn.Module):
|
||||
image = self.norm_out(image, conditioning)
|
||||
image = self.proj_out(image)
|
||||
|
||||
latents = rearrange(image, "B (H W) (P Q C) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
|
||||
diffsynth/models/wan_video_animate_adapter.py (new file, 670 lines)
@@ -0,0 +1,670 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import math
|
||||
from typing import Tuple, Optional, List
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
|
||||
MEMORY_LAYOUT = {
|
||||
"flash": (
|
||||
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
|
||||
lambda x: x,
|
||||
),
|
||||
"torch": (
|
||||
lambda x: x.transpose(1, 2),
|
||||
lambda x: x.transpose(1, 2),
|
||||
),
|
||||
"vanilla": (
|
||||
lambda x: x.transpose(1, 2),
|
||||
lambda x: x.transpose(1, 2),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
mode="torch",
|
||||
drop_rate=0,
|
||||
attn_mask=None,
|
||||
causal=False,
|
||||
max_seqlen_q=None,
|
||||
batch_size=1,
|
||||
):
|
||||
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
|
||||
|
||||
if mode == "torch":
|
||||
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
||||
attn_mask = attn_mask.to(q.dtype)
|
||||
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
|
||||
|
||||
x = post_attn_layout(x)
|
||||
b, s, a, d = x.shape
|
||||
out = x.reshape(b, s, -1)
|
||||
return out
|
||||
|
||||
|
||||
class CausalConv1d(nn.Module):
|
||||
|
||||
def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.pad_mode = pad_mode
|
||||
padding = (kernel_size - 1, 0) # T
|
||||
self.time_causal_padding = padding
|
||||
|
||||
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
|
||||
class FaceEncoder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
|
||||
factory_kwargs = {"dtype": dtype, "device": device}
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
|
||||
self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.act = nn.SiLU()
|
||||
self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
|
||||
self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)
|
||||
|
||||
self.out_proj = nn.Linear(1024, hidden_dim)
|
||||
self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = rearrange(x, "b t c -> b c t")
|
||||
b, c, t = x.shape
|
||||
|
||||
x = self.conv1_local(x)
|
||||
x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
|
||||
|
||||
x = self.norm1(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, "b t c -> b c t")
|
||||
x = self.conv2(x)
|
||||
x = rearrange(x, "b c t -> b t c")
|
||||
x = self.norm2(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, "b t c -> b c t")
|
||||
x = self.conv3(x)
|
||||
x = rearrange(x, "b c t -> b t c")
|
||||
x = self.norm3(x)
|
||||
x = self.act(x)
|
||||
x = self.out_proj(x)
|
||||
x = rearrange(x, "(b n) t c -> b t n c", b=b)
|
||||
padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
|
||||
x = torch.cat([x, padding], dim=-2)
|
||||
x_local = x.clone()
|
||||
|
||||
return x_local
|
||||
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
elementwise_affine=True,
|
||||
eps: float = 1e-6,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
Initialize the RMSNorm normalization layer.
|
||||
|
||||
Args:
|
||||
dim (int): The dimension of the input tensor.
|
||||
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||
|
||||
Attributes:
|
||||
eps (float): A small value added to the denominator for numerical stability.
|
||||
weight (nn.Parameter): Learnable scaling parameter.
|
||||
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
|
||||
"""
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
if hasattr(self, "weight"):
|
||||
output = output * self.weight
|
||||
return output
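# Example (illustrative): RMSNorm(1024)(torch.randn(2, 77, 1024)) returns the same shape, with each last-dimension vector normalized to unit RMS before the learned per-channel weight is applied.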
|
||||
|
||||
|
||||
def get_norm_layer(norm_layer):
|
||||
"""
|
||||
Get the normalization layer.
|
||||
|
||||
Args:
|
||||
norm_layer (str): The type of normalization layer.
|
||||
|
||||
Returns:
|
||||
norm_layer (nn.Module): The normalization layer.
|
||||
"""
|
||||
if norm_layer == "layer":
|
||||
return nn.LayerNorm
|
||||
elif norm_layer == "rms":
|
||||
return RMSNorm
|
||||
else:
|
||||
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
||||
|
||||
|
||||
class FaceAdapter(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
heads_num: int,
|
||||
qk_norm: bool = True,
|
||||
qk_norm_type: str = "rms",
|
||||
num_adapter_layers: int = 1,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
|
||||
factory_kwargs = {"dtype": dtype, "device": device}
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_dim
|
||||
self.heads_num = heads_num
|
||||
self.fuser_blocks = nn.ModuleList(
|
||||
[
|
||||
FaceBlock(
|
||||
self.hidden_size,
|
||||
self.heads_num,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_type=qk_norm_type,
|
||||
**factory_kwargs,
|
||||
)
|
||||
for _ in range(num_adapter_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
motion_embed: torch.Tensor,
|
||||
idx: int,
|
||||
freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
|
||||
freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
|
||||
|
||||
|
||||
|
||||
class FaceBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
heads_num: int,
|
||||
qk_norm: bool = True,
|
||||
qk_norm_type: str = "rms",
|
||||
qk_scale: float = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
|
||||
self.deterministic = False
|
||||
self.hidden_size = hidden_size
|
||||
self.heads_num = heads_num
|
||||
head_dim = hidden_size // heads_num
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
|
||||
self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
|
||||
|
||||
self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
|
||||
|
||||
qk_norm_layer = get_norm_layer(qk_norm_type)
|
||||
self.q_norm = (
|
||||
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
||||
)
|
||||
self.k_norm = (
|
||||
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
||||
)
|
||||
|
||||
self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
motion_vec: torch.Tensor,
|
||||
motion_mask: Optional[torch.Tensor] = None,
|
||||
use_context_parallel=False,
|
||||
) -> torch.Tensor:
|
||||
|
||||
B, T, N, C = motion_vec.shape
|
||||
T_comp = T
|
||||
|
||||
x_motion = self.pre_norm_motion(motion_vec)
|
||||
x_feat = self.pre_norm_feat(x)
|
||||
|
||||
kv = self.linear1_kv(x_motion)
|
||||
q = self.linear1_q(x_feat)
|
||||
|
||||
k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
|
||||
q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
|
||||
|
||||
# Apply QK-Norm if needed.
|
||||
q = self.q_norm(q).to(v)
|
||||
k = self.k_norm(k).to(v)
|
||||
|
||||
k = rearrange(k, "B L N H D -> (B L) H N D")
|
||||
v = rearrange(v, "B L N H D -> (B L) H N D")
|
||||
|
||||
q = rearrange(q, "B (L S) H D -> (B L) H S D", L=T_comp)
|
||||
# Compute attention.
|
||||
attn = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
attn = rearrange(attn, "(B L) H S D -> B (L S) (H D)", L=T_comp)
|
||||
|
||||
output = self.linear2(attn)
|
||||
|
||||
if motion_mask is not None:
|
||||
output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
||||
def custom_qr(input_tensor):
|
||||
original_dtype = input_tensor.dtype
|
||||
if original_dtype == torch.bfloat16:
|
||||
q, r = torch.linalg.qr(input_tensor.to(torch.float32))
|
||||
return q.to(original_dtype), r.to(original_dtype)
|
||||
return torch.linalg.qr(input_tensor)
|
||||
|
||||
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
|
||||
return F.leaky_relu(input + bias, negative_slope) * scale
|
||||
|
||||
|
||||
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
|
||||
_, minor, in_h, in_w = input.shape
|
||||
kernel_h, kernel_w = kernel.shape
|
||||
|
||||
out = input.view(-1, minor, in_h, 1, in_w, 1)
|
||||
out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
|
||||
out = out.view(-1, minor, in_h * up_y, in_w * up_x)
|
||||
|
||||
out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
||||
out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
|
||||
max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
|
||||
|
||||
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
||||
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
||||
out = F.conv2d(out, w)
|
||||
out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
||||
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
|
||||
return out[:, :, ::down_y, ::down_x]
|
||||
|
||||
|
||||
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
||||
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
|
||||
|
||||
|
||||
def make_kernel(k):
|
||||
k = torch.tensor(k, dtype=torch.float32)
|
||||
if k.ndim == 1:
|
||||
k = k[None, :] * k[:, None]
|
||||
k /= k.sum()
|
||||
return k
|
||||
|
||||
|
||||
class FusedLeakyReLU(nn.Module):
|
||||
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
|
||||
super().__init__()
|
||||
self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
|
||||
self.negative_slope = negative_slope
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, input):
|
||||
out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
||||
return out
|
||||
|
||||
|
||||
class Blur(nn.Module):
|
||||
def __init__(self, kernel, pad, upsample_factor=1):
|
||||
super().__init__()
|
||||
|
||||
kernel = make_kernel(kernel)
|
||||
|
||||
if upsample_factor > 1:
|
||||
kernel = kernel * (upsample_factor ** 2)
|
||||
|
||||
self.register_buffer('kernel', kernel)
|
||||
|
||||
self.pad = pad
|
||||
|
||||
def forward(self, input):
|
||||
return upfirdn2d(input, self.kernel, pad=self.pad)
|
||||
|
||||
|
||||
class ScaledLeakyReLU(nn.Module):
|
||||
def __init__(self, negative_slope=0.2):
|
||||
super().__init__()
|
||||
|
||||
self.negative_slope = negative_slope
|
||||
|
||||
def forward(self, input):
|
||||
return F.leaky_relu(input, negative_slope=self.negative_slope)
|
||||
|
||||
|
||||
class EqualConv2d(nn.Module):
|
||||
def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
|
||||
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
||||
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_channel))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
|
||||
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
|
||||
)
|
||||
|
||||
|
||||
class EqualLinear(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.activation = activation
|
||||
|
||||
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
||||
self.lr_mul = lr_mul
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
if self.activation:
|
||||
out = F.linear(input, self.weight * self.scale)
|
||||
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
||||
else:
|
||||
out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
|
||||
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
|
||||
|
||||
|
||||
class ConvLayer(nn.Sequential):
|
||||
def __init__(
|
||||
self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
downsample=False,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
bias=True,
|
||||
activate=True,
|
||||
):
|
||||
layers = []
|
||||
|
||||
if downsample:
|
||||
factor = 2
|
||||
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
||||
pad0 = (p + 1) // 2
|
||||
pad1 = p // 2
|
||||
|
||||
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
||||
|
||||
stride = 2
|
||||
self.padding = 0
|
||||
|
||||
else:
|
||||
stride = 1
|
||||
self.padding = kernel_size // 2
|
||||
|
||||
layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
|
||||
bias=bias and not activate))
|
||||
|
||||
if activate:
|
||||
if bias:
|
||||
layers.append(FusedLeakyReLU(out_channel))
|
||||
else:
|
||||
layers.append(ScaledLeakyReLU(0.2))
|
||||
|
||||
super().__init__(*layers)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
||||
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
|
||||
|
||||
self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
|
||||
|
||||
def forward(self, input):
|
||||
out = self.conv1(input)
|
||||
out = self.conv2(out)
|
||||
|
||||
skip = self.skip(input)
|
||||
out = (out + skip) / math.sqrt(2)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class EncoderApp(nn.Module):
|
||||
def __init__(self, size, w_dim=512):
|
||||
super(EncoderApp, self).__init__()
|
||||
|
||||
channels = {
|
||||
4: 512,
|
||||
8: 512,
|
||||
16: 512,
|
||||
32: 512,
|
||||
64: 256,
|
||||
128: 128,
|
||||
256: 64,
|
||||
512: 32,
|
||||
1024: 16
|
||||
}
|
||||
|
||||
self.w_dim = w_dim
|
||||
log_size = int(math.log(size, 2))
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
self.convs.append(ConvLayer(3, channels[size], 1))
|
||||
|
||||
in_channel = channels[size]
|
||||
for i in range(log_size, 2, -1):
|
||||
out_channel = channels[2 ** (i - 1)]
|
||||
self.convs.append(ResBlock(in_channel, out_channel))
|
||||
in_channel = out_channel
|
||||
|
||||
self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
res = []
|
||||
h = x
|
||||
for conv in self.convs:
|
||||
h = conv(h)
|
||||
res.append(h)
|
||||
|
||||
return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, size, dim=512, dim_motion=20):
|
||||
super(Encoder, self).__init__()
|
||||
|
||||
# appearance network
|
||||
self.net_app = EncoderApp(size, dim)
|
||||
|
||||
# motion network
|
||||
fc = [EqualLinear(dim, dim)]
|
||||
for i in range(3):
|
||||
fc.append(EqualLinear(dim, dim))
|
||||
|
||||
fc.append(EqualLinear(dim, dim_motion))
|
||||
self.fc = nn.Sequential(*fc)
|
||||
|
||||
def enc_app(self, x):
|
||||
h_source = self.net_app(x)
|
||||
return h_source
|
||||
|
||||
def enc_motion(self, x):
|
||||
h, _ = self.net_app(x)
|
||||
h_motion = self.fc(h)
|
||||
return h_motion
|
||||
|
||||
|
||||
class Direction(nn.Module):
|
||||
def __init__(self, motion_dim):
|
||||
super(Direction, self).__init__()
|
||||
self.weight = nn.Parameter(torch.randn(512, motion_dim))
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
weight = self.weight + 1e-8
|
||||
Q, R = custom_qr(weight)
|
||||
if input is None:
|
||||
return Q
|
||||
else:
|
||||
input_diag = torch.diag_embed(input) # alpha, diagonal matrix
|
||||
out = torch.matmul(input_diag, Q.T)
|
||||
out = torch.sum(out, dim=1)
|
||||
return out
|
||||
|
||||
|
||||
class Synthesis(nn.Module):
|
||||
def __init__(self, motion_dim):
|
||||
super(Synthesis, self).__init__()
|
||||
self.direction = Direction(motion_dim)
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, size, style_dim=512, motion_dim=20):
|
||||
super().__init__()
|
||||
|
||||
self.enc = Encoder(size, style_dim, motion_dim)
|
||||
self.dec = Synthesis(motion_dim)
|
||||
|
||||
def get_motion(self, img):
|
||||
#motion_feat = self.enc.enc_motion(img)
|
||||
motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
|
||||
motion = self.dec.direction(motion_feat)
|
||||
return motion
|
||||
|
||||
|
||||
class WanAnimateAdapter(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pose_patch_embedding = torch.nn.Conv3d(16, 5120, kernel_size=(1, 2, 2), stride=(1, 2, 2))
|
||||
self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
|
||||
self.face_adapter = FaceAdapter(heads_num=40, hidden_dim=5120, num_adapter_layers=40 // 5)
|
||||
self.face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4)
|
||||
|
||||
def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
|
||||
pose_latents = self.pose_patch_embedding(pose_latents)
|
||||
x[:, :, 1:] += pose_latents
|
||||
|
||||
b,c,T,h,w = face_pixel_values.shape
|
||||
face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
|
||||
|
||||
encode_bs = 8
|
||||
face_pixel_values_tmp = []
|
||||
for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):
|
||||
face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))
|
||||
|
||||
motion_vec = torch.cat(face_pixel_values_tmp)
|
||||
|
||||
motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
|
||||
motion_vec = self.face_encoder(motion_vec)
|
||||
|
||||
B, L, H, C = motion_vec.shape
|
||||
pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
|
||||
motion_vec = torch.cat([pad_face, motion_vec], dim=1)
|
||||
return x, motion_vec
|
||||
|
||||
def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):
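# The face-motion residual is only injected at every 5th transformer block; fuser block block_idx // 5 produces that residual.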
|
||||
if block_idx % 5 == 0:
|
||||
adapter_args = [x, motion_vec, motion_masks, False]
|
||||
residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
|
||||
x = residual_out + x
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanAnimateAdapterStateDictConverter()
|
||||
|
||||
|
||||
class WanAnimateAdapterStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith("pose_patch_embedding.") or name.startswith("face_adapter") or name.startswith("face_encoder") or name.startswith("motion_encoder"):
|
||||
state_dict_[name] = param
|
||||
return state_dict_
|
||||
|
||||
@@ -182,7 +182,7 @@ def process_pose_file(cam_params, width=672, height=384, original_pose_width=128
|
||||
|
||||
|
||||
def generate_camera_coordinates(
|
||||
direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
|
||||
direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown", "In", "Out"],
|
||||
length: int,
|
||||
speed: float = 1/54,
|
||||
origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
|
||||
@@ -198,5 +198,9 @@ def generate_camera_coordinates(
|
||||
coor[13] += speed
|
||||
if "Down" in direction:
|
||||
coor[13] -= speed
|
||||
if "In" in direction:
|
||||
coor[18] -= speed
|
||||
if "Out" in direction:
|
||||
coor[18] += speed
|
||||
coordinates.append(coor)
|
||||
return coordinates
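# Illustrative call (values are placeholders, not from the diff):
#     coords = generate_camera_coordinates("In", length=81)  # a slow dolly-in trajectory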
|
||||
|
||||
@@ -294,6 +294,7 @@ class WanModel(torch.nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.in_dim = in_dim
|
||||
self.freq_dim = freq_dim
|
||||
self.has_image_input = has_image_input
|
||||
self.patch_size = patch_size
|
||||
@@ -341,9 +342,7 @@ class WanModel(torch.nn.Module):
|
||||
y_camera = self.control_adapter(control_camera_latents_input)
|
||||
x = [u + v for u, v in zip(x, y_camera)]
|
||||
x = x[0].unsqueeze(0)
|
||||
grid_size = x.shape[2:]
|
||||
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
||||
return x, grid_size # x, grid_size: (f, h, w)
|
||||
return x
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
|
||||
return rearrange(
|
||||
@@ -495,6 +494,7 @@ class WanModelStateDictConverter:
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
|
||||
state_dict = {name: param for name, param in state_dict.items() if name.split(".")[0] not in ["pose_patch_embedding", "face_adapter", "face_encoder", "motion_encoder"]}
|
||||
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
@@ -551,20 +551,6 @@ class WanModelStateDictConverter:
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
|
||||
# 1.3B PAI control
|
||||
config = {
|
||||
@@ -713,6 +699,42 @@ class WanModelStateDictConverter:
|
||||
"eps": 1e-6,
|
||||
"require_clip_embedding": False,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "2267d489f0ceb9f21836532952852ee5":
|
||||
# Wan2.2-Fun-A14B-Control
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 52,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": True,
|
||||
"require_clip_embedding": False,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "47dbeab5e560db3180adf51dc0232fb1":
|
||||
# Wan2.2-Fun-A14B-Control-Camera
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": False,
|
||||
"add_control_adapter": True,
|
||||
"in_dim_control_adapter": 24,
|
||||
"require_clip_embedding": False,
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict, config
|
||||
|
||||
diffsynth/models/wan_video_dit_s2v.py (new file, 625 lines)
@@ -0,0 +1,625 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
from .utils import hash_state_dict_keys
|
||||
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
|
||||
|
||||
|
||||
def torch_dfs(model: nn.Module, parent_name='root'):
|
||||
module_names, modules = [], []
|
||||
current_name = parent_name if parent_name else 'root'
|
||||
module_names.append(current_name)
|
||||
modules.append(model)
|
||||
|
||||
for name, child in model.named_children():
|
||||
if parent_name:
|
||||
child_name = f'{parent_name}.{name}'
|
||||
else:
|
||||
child_name = name
|
||||
child_modules, child_names = torch_dfs(child, child_name)
|
||||
module_names += child_names
|
||||
modules += child_modules
|
||||
return modules, module_names
|
||||
|
||||
|
||||
def rope_precompute(x, grid_sizes, freqs, start=None):
|
||||
b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
|
||||
|
||||
# split freqs
|
||||
if type(freqs) is list:
|
||||
trainable_freqs = freqs[1]
|
||||
freqs = freqs[0]
|
||||
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
||||
|
||||
# loop over samples
|
||||
output = torch.view_as_complex(x.detach().reshape(b, s, n, -1, 2).to(torch.float64))
|
||||
seq_bucket = [0]
|
||||
if not type(grid_sizes) is list:
|
||||
grid_sizes = [grid_sizes]
|
||||
for g in grid_sizes:
|
||||
if not type(g) is list:
|
||||
g = [torch.zeros_like(g), g]
|
||||
batch_size = g[0].shape[0]
|
||||
for i in range(batch_size):
|
||||
if start is None:
|
||||
f_o, h_o, w_o = g[0][i]
|
||||
else:
|
||||
f_o, h_o, w_o = start[i]
|
||||
|
||||
f, h, w = g[1][i]
|
||||
t_f, t_h, t_w = g[2][i]
|
||||
seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
|
||||
seq_len = int(seq_f * seq_h * seq_w)
|
||||
if seq_len > 0:
|
||||
if t_f > 0:
|
||||
factor_f, factor_h, factor_w = (t_f / seq_f).item(), (t_h / seq_h).item(), (t_w / seq_w).item()
|
||||
# Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item())
|
||||
if f_o >= 0:
|
||||
f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f).astype(int).tolist()
|
||||
else:
|
||||
f_sam = np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f).astype(int).tolist()
|
||||
h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h).astype(int).tolist()
|
||||
w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w).astype(int).tolist()
|
||||
|
||||
assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
|
||||
freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj()
|
||||
freqs_0 = freqs_0.view(seq_f, 1, 1, -1)
|
||||
|
||||
freqs_i = torch.cat(
|
||||
[
|
||||
freqs_0.expand(seq_f, seq_h, seq_w, -1),
|
||||
freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1),
|
||||
freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1),
|
||||
],
|
||||
dim=-1
|
||||
).reshape(seq_len, 1, -1)
|
||||
elif t_f < 0:
|
||||
freqs_i = trainable_freqs.unsqueeze(1)
|
||||
# apply rotary embedding
|
||||
output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i
|
||||
seq_bucket.append(seq_bucket[-1] + seq_len)
|
||||
return output
|
||||
|
||||
|
||||
class CausalConv1d(nn.Module):
|
||||
|
||||
def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode='replicate', **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.pad_mode = pad_mode
|
||||
padding = (kernel_size - 1, 0) # T
|
||||
self.time_causal_padding = padding
|
||||
|
||||
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
||||
return self.conv(x)
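# Standalone sanity check (hedged, toy shapes): the left-only padding above preserves the sequence
# length and keeps the convolution causal.
_causal_conv = CausalConv1d(chan_in=4, chan_out=8, kernel_size=3)
_y = _causal_conv(torch.randn(1, 4, 10))      # (batch, channels, time)
assert _y.shape == (1, 8, 10)                 # same length; _y[..., t] depends only on inputs up to t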
|
||||
|
||||
|
||||
class MotionEncoder_tc(nn.Module):
|
||||
|
||||
def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, need_global=True, dtype=None, device=None):
|
||||
factory_kwargs = {"dtype": dtype, "device": device}
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.need_global = need_global
|
||||
self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1)
|
||||
if need_global:
|
||||
self.conv1_global = CausalConv1d(in_dim, hidden_dim // 4, 3, stride=1)
|
||||
self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.act = nn.SiLU()
|
||||
self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
|
||||
self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)
|
||||
|
||||
if need_global:
|
||||
self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs)
|
||||
|
||||
self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
|
||||
|
||||
def forward(self, x):
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x_ori = x.clone()
|
||||
b, c, t = x.shape
|
||||
x = self.conv1_local(x)
|
||||
x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
|
||||
x = self.norm1(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x = self.conv2(x)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm2(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x = self.conv3(x)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm3(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, '(b n) t c -> b t n c', b=b)
|
||||
padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1).to(device=x.device, dtype=x.dtype)
|
||||
x = torch.cat([x, padding], dim=-2)
|
||||
x_local = x.clone()
|
||||
|
||||
if not self.need_global:
|
||||
return x_local
|
||||
|
||||
x = self.conv1_global(x_ori)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm1(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x = self.conv2(x)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm2(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x = self.conv3(x)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm3(x)
|
||||
x = self.act(x)
|
||||
x = self.final_linear(x)
|
||||
x = rearrange(x, '(b n) t c -> b t n c', b=b)
|
||||
|
||||
return x, x_local
|
||||
|
||||
|
||||
class FramePackMotioner(nn.Module):
|
||||
|
||||
def __init__(self, inner_dim=1024, num_heads=16, zip_frame_buckets=[1, 2, 16], drop_mode="drop", *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
|
||||
self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
|
||||
self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
|
||||
self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long)
|
||||
|
||||
self.inner_dim = inner_dim
|
||||
self.num_heads = num_heads
|
||||
self.freqs = torch.cat(precompute_freqs_cis_3d(inner_dim // num_heads), dim=1)
|
||||
self.drop_mode = drop_mode
|
||||
|
||||
def forward(self, motion_latents, add_last_motion=2):
|
||||
motion_frames = motion_latents[0].shape[1]
|
||||
mot = []
|
||||
mot_remb = []
|
||||
for m in motion_latents:
|
||||
lat_height, lat_width = m.shape[2], m.shape[3]
|
||||
padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height, lat_width).to(device=m.device, dtype=m.dtype)
|
||||
overlap_frame = min(padd_lat.shape[1], m.shape[1])
|
||||
if overlap_frame > 0:
|
||||
padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:]
|
||||
|
||||
if add_last_motion < 2 and self.drop_mode != "drop":
|
||||
zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets.__len__() - add_last_motion - 1].sum()
|
||||
padd_lat[:, -zero_end_frame:] = 0
|
||||
|
||||
padd_lat = padd_lat.unsqueeze(0)
|
||||
clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum():, :, :].split(
|
||||
list(self.zip_frame_buckets)[::-1], dim=2
|
||||
) # 16, 2 ,1
|
||||
|
||||
# patchfy
|
||||
clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2)
|
||||
clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2)
|
||||
clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2)
|
||||
|
||||
if add_last_motion < 2 and self.drop_mode == "drop":
|
||||
clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post
|
||||
clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x
|
||||
|
||||
motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
|
||||
|
||||
# rope
|
||||
start_time_id = -(self.zip_frame_buckets[:1].sum())
|
||||
end_time_id = start_time_id + self.zip_frame_buckets[0]
|
||||
grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \
|
||||
[
|
||||
[torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
|
||||
torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
|
||||
torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
|
||||
]
|
||||
|
||||
start_time_id = -(self.zip_frame_buckets[:2].sum())
|
||||
end_time_id = start_time_id + self.zip_frame_buckets[1] // 2
|
||||
grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \
|
||||
[
|
||||
[torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
|
||||
torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1),
|
||||
torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
|
||||
]
|
||||
|
||||
start_time_id = -(self.zip_frame_buckets[:3].sum())
|
||||
end_time_id = start_time_id + self.zip_frame_buckets[2] // 4
|
||||
grid_sizes_4x = [
|
||||
[
|
||||
torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
|
||||
torch.tensor([end_time_id, lat_height // 8, lat_width // 8]).unsqueeze(0).repeat(1, 1),
|
||||
torch.tensor([self.zip_frame_buckets[2], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
|
||||
]
|
||||
]
|
||||
|
||||
grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x
|
||||
|
||||
motion_rope_emb = rope_precompute(
|
||||
motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads, self.inner_dim // self.num_heads),
|
||||
grid_sizes,
|
||||
self.freqs,
|
||||
start=None
|
||||
)
|
||||
|
||||
mot.append(motion_lat)
|
||||
mot_remb.append(motion_rope_emb)
|
||||
return mot, mot_remb
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
output_dim: int,
|
||||
norm_eps: float = 1e-5,
|
||||
):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, elementwise_affine=False)
|
||||
|
||||
def forward(self, x, temb):
|
||||
temb = self.linear(F.silu(temb))
|
||||
shift, scale = temb.chunk(2, dim=1)
|
||||
shift = shift[:, None, :]
|
||||
scale = scale[:, None, :]
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
|
||||
|
||||
|
||||
class AudioInjector_WAN(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
all_modules,
|
||||
all_modules_names,
|
||||
dim=2048,
|
||||
num_heads=32,
|
||||
inject_layer=[0, 27],
|
||||
enable_adain=False,
|
||||
adain_dim=2048,
|
||||
):
|
||||
super().__init__()
|
||||
self.injected_block_id = {}
|
||||
audio_injector_id = 0
|
||||
for mod_name, mod in zip(all_modules_names, all_modules):
|
||||
if isinstance(mod, DiTBlock):
|
||||
for inject_id in inject_layer:
|
||||
if f'transformer_blocks.{inject_id}' in mod_name:
|
||||
self.injected_block_id[inject_id] = audio_injector_id
|
||||
audio_injector_id += 1
|
||||
|
||||
self.injector = nn.ModuleList([CrossAttention(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
) for _ in range(audio_injector_id)])
|
||||
self.injector_pre_norm_feat = nn.ModuleList([nn.LayerNorm(
|
||||
dim,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
) for _ in range(audio_injector_id)])
|
||||
self.injector_pre_norm_vec = nn.ModuleList([nn.LayerNorm(
|
||||
dim,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
) for _ in range(audio_injector_id)])
|
||||
if enable_adain:
|
||||
self.injector_adain_layers = nn.ModuleList([AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim) for _ in range(audio_injector_id)])
|
||||
|
||||
|
||||
class CausalAudioEncoder(nn.Module):
|
||||
|
||||
def __init__(self, dim=5120, num_layers=25, out_dim=2048, num_token=4, need_global=False):
|
||||
super().__init__()
|
||||
self.encoder = MotionEncoder_tc(in_dim=dim, hidden_dim=out_dim, num_heads=num_token, need_global=need_global)
|
||||
weight = torch.ones((1, num_layers, 1, 1)) * 0.01
|
||||
|
||||
self.weights = torch.nn.Parameter(weight)
|
||||
self.act = torch.nn.SiLU()
|
||||
|
||||
def forward(self, features):
|
||||
# features B * num_layers * dim * video_length
|
||||
weights = self.act(self.weights.to(device=features.device, dtype=features.dtype))
|
||||
weights_sum = weights.sum(dim=1, keepdims=True)
|
||||
weighted_feat = ((features * weights) / weights_sum).sum(dim=1) # b dim f
|
||||
weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
|
||||
res = self.encoder(weighted_feat) # b f n dim
|
||||
return res # b f n dim
|
||||
|
||||
|
||||
class WanS2VDiTBlock(DiTBlock):
|
||||
|
||||
def forward(self, x, context, t_mod, seq_len_x, freqs):
|
||||
t_mod = (self.modulation.unsqueeze(2).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
||||
# t_mod[:, :, 0] for x, t_mod[:, :, 1] for other like ref, motion, etc.
|
||||
t_mod = [
|
||||
torch.cat([element[:, :, 0].expand(1, seq_len_x, x.shape[-1]), element[:, :, 1].expand(1, x.shape[1] - seq_len_x, x.shape[-1])], dim=1)
|
||||
for element in t_mod
|
||||
]
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod
|
||||
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
||||
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
|
||||
x = x + self.cross_attn(self.norm3(x), context)
|
||||
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
x = self.gate(x, gate_mlp, self.ffn(input_x))
|
||||
return x
|
||||
|
||||
|
||||
class WanS2VModel(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
in_dim: int,
|
||||
ffn_dim: int,
|
||||
out_dim: int,
|
||||
text_dim: int,
|
||||
freq_dim: int,
|
||||
eps: float,
|
||||
patch_size: Tuple[int, int, int],
|
||||
num_heads: int,
|
||||
num_layers: int,
|
||||
cond_dim: int,
|
||||
audio_dim: int,
|
||||
num_audio_token: int,
|
||||
enable_adain: bool = True,
|
||||
audio_inject_layers: list = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39],
|
||||
zero_timestep: bool = True,
|
||||
add_last_motion: bool = True,
|
||||
framepack_drop_mode: str = "padd",
|
||||
fuse_vae_embedding_in_latents: bool = True,
|
||||
require_vae_embedding: bool = False,
|
||||
seperated_timestep: bool = False,
|
||||
require_clip_embedding: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.in_dim = in_dim
|
||||
self.freq_dim = freq_dim
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.enbale_adain = enable_adain
|
||||
self.add_last_motion = add_last_motion
|
||||
self.zero_timestep = zero_timestep
|
||||
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
|
||||
self.require_vae_embedding = require_vae_embedding
|
||||
self.seperated_timestep = seperated_timestep
|
||||
self.require_clip_embedding = require_clip_embedding
|
||||
|
||||
self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), nn.Linear(dim, dim))
|
||||
self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
||||
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
||||
|
||||
self.blocks = nn.ModuleList([WanS2VDiTBlock(False, dim, num_heads, ffn_dim, eps) for _ in range(num_layers)])
|
||||
self.head = Head(dim, out_dim, patch_size, eps)
|
||||
self.freqs = torch.cat(precompute_freqs_cis_3d(dim // num_heads), dim=1)
|
||||
|
||||
self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain)
|
||||
all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks")
|
||||
self.audio_injector = AudioInjector_WAN(
|
||||
all_modules,
|
||||
all_modules_names,
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
inject_layer=audio_inject_layers,
|
||||
enable_adain=enable_adain,
|
||||
adain_dim=dim,
|
||||
)
|
||||
self.trainable_cond_mask = nn.Embedding(3, dim)
|
||||
self.frame_packer = FramePackMotioner(inner_dim=dim, num_heads=num_heads, zip_frame_buckets=[1, 2, 16], drop_mode=framepack_drop_mode)
|
||||
|
||||
def patchify(self, x: torch.Tensor):
|
||||
grid_size = x.shape[2:]
|
||||
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
||||
return x, grid_size # x, grid_size: (f, h, w)
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
|
||||
return rearrange(
|
||||
x,
|
||||
'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
|
||||
f=grid_size[0],
|
||||
h=grid_size[1],
|
||||
w=grid_size[2],
|
||||
x=self.patch_size[0],
|
||||
y=self.patch_size[1],
|
||||
z=self.patch_size[2]
|
||||
)
|
||||
|
||||
def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, add_last_motion=2):
|
||||
flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion)
|
||||
if drop_motion_frames:
|
||||
return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb]
|
||||
else:
|
||||
return flattern_mot, mot_remb
|
||||
|
||||
def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
|
||||
# inject the motion frames token to the hidden states
|
||||
mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion)
|
||||
if len(mot) > 0:
|
||||
x = torch.cat([x, mot[0]], dim=1)
|
||||
rope_embs = torch.cat([rope_embs, mot_remb[0]], dim=1)
|
||||
mask_input = torch.cat(
|
||||
[mask_input, 2 * torch.ones([1, x.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype)], dim=1
|
||||
)
|
||||
return x, rope_embs, mask_input
|
||||
|
||||
def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len, use_unified_sequence_parallel=False):
|
||||
if block_idx in self.audio_injector.injected_block_id.keys():
|
||||
audio_attn_id = self.audio_injector.injected_block_id[block_idx]
|
||||
num_frames = audio_emb.shape[1]
|
||||
if use_unified_sequence_parallel:
|
||||
from xfuser.core.distributed import get_sp_group
|
||||
hidden_states = get_sp_group().all_gather(hidden_states, dim=1)
|
||||
|
||||
input_hidden_states = hidden_states[:, :original_seq_len].clone() # b (f h w) c
|
||||
input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)
|
||||
|
||||
audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c")
|
||||
adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0])
|
||||
attn_hidden_states = adain_hidden_states
|
||||
|
||||
audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
|
||||
attn_audio_emb = audio_emb
|
||||
residual_out = self.audio_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb)
|
||||
residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
|
||||
hidden_states[:, :original_seq_len] = hidden_states[:, :original_seq_len] + residual_out
|
||||
if use_unified_sequence_parallel:
|
||||
from xfuser.core.distributed import get_sequence_parallel_world_size, get_sequence_parallel_rank
|
||||
hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||
return hidden_states
|
||||
|
||||
def cal_audio_emb(self, audio_input, motion_frames=[73, 19]):
|
||||
audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)
|
||||
audio_emb_global, audio_emb = self.casual_audio_encoder(audio_input)
|
||||
audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
|
||||
merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
|
||||
return audio_emb_global, merged_audio_emb
|
||||
|
||||
def get_grid_sizes(self, grid_size_x, grid_size_ref):
|
||||
f, h, w = grid_size_x
|
||||
rf, rh, rw = grid_size_ref
|
||||
grid_sizes_x = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0)
|
||||
grid_sizes_x = [[torch.zeros_like(grid_sizes_x), grid_sizes_x, grid_sizes_x]]
|
||||
grid_sizes_ref = [[
|
||||
torch.tensor([30, 0, 0]).unsqueeze(0),
|
||||
torch.tensor([31, rh, rw]).unsqueeze(0),
|
||||
torch.tensor([1, rh, rw]).unsqueeze(0),
|
||||
]]
|
||||
return grid_sizes_x + grid_sizes_ref
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents,
|
||||
timestep,
|
||||
context,
|
||||
audio_input,
|
||||
motion_latents,
|
||||
pose_cond,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
use_gradient_checkpointing=False
|
||||
):
|
||||
origin_ref_latents = latents[:, :, 0:1]
|
||||
x = latents[:, :, 1:]
|
||||
|
||||
# context embedding
|
||||
context = self.text_embedding(context)
|
||||
|
||||
# audio encode
|
||||
audio_emb_global, merged_audio_emb = self.cal_audio_emb(audio_input)
|
||||
|
||||
# x and pose_cond
|
||||
pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
|
||||
x, (f, h, w) = self.patchify(self.patch_embedding(x) + self.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120])
|
||||
seq_len_x = x.shape[1]
|
||||
|
||||
# reference image
|
||||
ref_latents, (rf, rh, rw) = self.patchify(self.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120])
|
||||
grid_sizes = self.get_grid_sizes((f, h, w), (rf, rh, rw))
|
||||
x = torch.cat([x, ref_latents], dim=1)
|
||||
# mask
|
||||
mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)
|
||||
# freqs
|
||||
pre_compute_freqs = rope_precompute(
|
||||
x.detach().view(1, x.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None
|
||||
)
|
||||
# motion
|
||||
x, pre_compute_freqs, mask = self.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
|
||||
|
||||
x = x + self.trainable_cond_mask(mask).to(x.dtype)
|
||||
|
||||
# t_mod
|
||||
timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
|
||||
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||
t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x,
|
||||
context,
|
||||
t_mod,
|
||||
seq_len_x,
|
||||
pre_compute_freqs[0],
|
||||
use_reentrant=False,
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
|
||||
x,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x,
|
||||
context,
|
||||
t_mod,
|
||||
seq_len_x,
|
||||
pre_compute_freqs[0],
|
||||
use_reentrant=False,
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
|
||||
x,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
|
||||
x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
|
||||
|
||||
x = x[:, :seq_len_x]
|
||||
x = self.head(x, t[:-1])
|
||||
x = self.unpatchify(x, (f, h, w))
|
||||
# make compatible with wan video
|
||||
x = torch.cat([origin_ref_latents, x], dim=2)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanS2VModelStateDictConverter()
|
||||
|
||||
|
||||
class WanS2VModelStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
config = {}
|
||||
if hash_state_dict_keys(state_dict) == "966cffdcc52f9c46c391768b27637614":
|
||||
config = {
|
||||
"dim": 5120,
|
||||
"in_dim": 16,
|
||||
"ffn_dim": 13824,
|
||||
"out_dim": 16,
|
||||
"text_dim": 4096,
|
||||
"freq_dim": 256,
|
||||
"eps": 1e-06,
|
||||
"patch_size": (1, 2, 2),
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"cond_dim": 16,
|
||||
"audio_dim": 1024,
|
||||
"num_audio_token": 4,
|
||||
}
|
||||
return state_dict, config
|
||||
@@ -1216,7 +1216,6 @@ class WanVideoVAE(nn.Module):
|
||||
|
||||
|
||||
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
|
||||
videos = [video.to("cpu") for video in videos]
|
||||
hidden_states = []
|
||||
for video in videos:
|
||||
@@ -1234,11 +1233,18 @@ class WanVideoVAE(nn.Module):
|
||||
|
||||
|
||||
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
if tiled:
|
||||
video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
|
||||
else:
|
||||
video = self.single_decode(hidden_states, device)
|
||||
return video
|
||||
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
||||
videos = []
|
||||
for hidden_state in hidden_states:
|
||||
hidden_state = hidden_state.unsqueeze(0)
|
||||
if tiled:
|
||||
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
||||
else:
|
||||
video = self.single_decode(hidden_state, device)
|
||||
video = video.squeeze(0)
|
||||
videos.append(video)
|
||||
videos = torch.stack(videos)
|
||||
return videos
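# Hedged usage sketch (not part of the diff): the reworked decode keeps latents on CPU and decodes
# one sample at a time, so a batched latent tensor goes in and a stacked batch of videos comes out:
#   latents = torch.randn(2, 16, 21, 60, 104)    # toy (batch, c, f, h, w) shape
#   videos = vae.decode(latents, device="cuda", tiled=True)
#   videos.shape -> (batch, 3, frames, height, width)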
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
||||
diffsynth/models/wav2vec.py (new file, 204 lines)
@@ -0,0 +1,204 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None):
|
||||
required_duration = num_sample / target_fps
|
||||
required_origin_frames = int(np.ceil(required_duration * original_fps))
|
||||
if required_duration > total_frames / original_fps:
|
||||
raise ValueError("required_duration must be less than video length")
|
||||
|
||||
if not fixed_start is None and fixed_start >= 0:
|
||||
start_frame = fixed_start
|
||||
else:
|
||||
max_start = total_frames - required_origin_frames
|
||||
if max_start < 0:
|
||||
raise ValueError("video length is too short")
|
||||
start_frame = np.random.randint(0, max_start + 1)
|
||||
start_time = start_frame / original_fps
|
||||
|
||||
end_time = start_time + required_duration
|
||||
time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)
|
||||
|
||||
frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
|
||||
frame_indices = np.clip(frame_indices, 0, total_frames - 1)
|
||||
return frame_indices
|
||||
|
||||
|
||||
def linear_interpolation(features, input_fps, output_fps, output_len=None):
|
||||
"""
|
||||
features: shape=[1, T, 512]
|
||||
input_fps: fps for audio, f_a
|
||||
output_fps: fps for video, f_m
|
||||
output_len: video length
|
||||
"""
|
||||
features = features.transpose(1, 2)
|
||||
seq_len = features.shape[2] / float(input_fps)
|
||||
if output_len is None:
|
||||
output_len = int(seq_len * output_fps)
|
||||
output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear') # [1, 512, output_len]
|
||||
return output_features.transpose(1, 2)
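# Standalone sketch (hedged, toy shapes): resample 50 Hz wav2vec2 features to the 30 fps video rate.
feat_50hz = torch.randn(1, 100, 1024)                        # 2 seconds of features at 50 Hz
feat_30fps = linear_interpolation(feat_50hz, input_fps=50, output_fps=30)
print(feat_30fps.shape)                                       # torch.Size([1, 60, 1024])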
|
||||
|
||||
|
||||
class WanS2VAudioEncoder(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
from transformers import Wav2Vec2ForCTC, Wav2Vec2Config
|
||||
config = {
|
||||
"_name_or_path": "facebook/wav2vec2-large-xlsr-53",
|
||||
"activation_dropout": 0.05,
|
||||
"apply_spec_augment": True,
|
||||
"architectures": ["Wav2Vec2ForCTC"],
|
||||
"attention_dropout": 0.1,
|
||||
"bos_token_id": 1,
|
||||
"conv_bias": True,
|
||||
"conv_dim": [512, 512, 512, 512, 512, 512, 512],
|
||||
"conv_kernel": [10, 3, 3, 3, 3, 2, 2],
|
||||
"conv_stride": [5, 2, 2, 2, 2, 2, 2],
|
||||
"ctc_loss_reduction": "mean",
|
||||
"ctc_zero_infinity": True,
|
||||
"do_stable_layer_norm": True,
|
||||
"eos_token_id": 2,
|
||||
"feat_extract_activation": "gelu",
|
||||
"feat_extract_dropout": 0.0,
|
||||
"feat_extract_norm": "layer",
|
||||
"feat_proj_dropout": 0.05,
|
||||
"final_dropout": 0.0,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout": 0.05,
|
||||
"hidden_size": 1024,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"layerdrop": 0.05,
|
||||
"mask_channel_length": 10,
|
||||
"mask_channel_min_space": 1,
|
||||
"mask_channel_other": 0.0,
|
||||
"mask_channel_prob": 0.0,
|
||||
"mask_channel_selection": "static",
|
||||
"mask_feature_length": 10,
|
||||
"mask_feature_prob": 0.0,
|
||||
"mask_time_length": 10,
|
||||
"mask_time_min_space": 1,
|
||||
"mask_time_other": 0.0,
|
||||
"mask_time_prob": 0.05,
|
||||
"mask_time_selection": "static",
|
||||
"model_type": "wav2vec2",
|
||||
"num_attention_heads": 16,
|
||||
"num_conv_pos_embedding_groups": 16,
|
||||
"num_conv_pos_embeddings": 128,
|
||||
"num_feat_extract_layers": 7,
|
||||
"num_hidden_layers": 24,
|
||||
"pad_token_id": 0,
|
||||
"transformers_version": "4.7.0.dev0",
|
||||
"vocab_size": 33
|
||||
}
|
||||
self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config))
|
||||
self.video_rate = 30
|
||||
|
||||
def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32, device='cpu'):
|
||||
input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(dtype=dtype, device=device)
|
||||
|
||||
# retrieve logits & take argmax
|
||||
res = self.model(input_values, output_hidden_states=True)
|
||||
if return_all_layers:
|
||||
feat = torch.cat(res.hidden_states)
|
||||
else:
|
||||
feat = res.hidden_states[-1]
|
||||
feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate)
|
||||
return feat
|
||||
|
||||
def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2):
|
||||
num_layers, audio_frame_num, audio_dim = audio_embed.shape
|
||||
|
||||
if num_layers > 1:
|
||||
return_all_layers = True
|
||||
else:
|
||||
return_all_layers = False
|
||||
|
||||
min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1
|
||||
|
||||
bucket_num = min_batch_num * batch_frames
|
||||
batch_idx = [stride * i for i in range(bucket_num)]
|
||||
batch_audio_eb = []
|
||||
for bi in batch_idx:
|
||||
if bi < audio_frame_num:
|
||||
audio_sample_stride = 2
|
||||
chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
|
||||
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
|
||||
chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx]
|
||||
|
||||
if return_all_layers:
|
||||
frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)
|
||||
else:
|
||||
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
|
||||
else:
|
||||
frame_audio_embed = \
|
||||
torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
|
||||
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
|
||||
batch_audio_eb.append(frame_audio_embed)
|
||||
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)
|
||||
|
||||
return batch_audio_eb, min_batch_num
|
||||
|
||||
def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0):
|
||||
num_layers, audio_frame_num, audio_dim = audio_embed.shape
|
||||
|
||||
if num_layers > 1:
|
||||
return_all_layers = True
|
||||
else:
|
||||
return_all_layers = False
|
||||
|
||||
scale = self.video_rate / fps
|
||||
|
||||
min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
|
||||
|
||||
bucket_num = min_batch_num * batch_frames
|
||||
padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num
|
||||
batch_idx = get_sample_indices(
|
||||
original_fps=self.video_rate, total_frames=audio_frame_num + padd_audio_num, target_fps=fps, num_sample=bucket_num, fixed_start=0
|
||||
)
|
||||
batch_audio_eb = []
|
||||
audio_sample_stride = int(self.video_rate / fps)
|
||||
for bi in batch_idx:
|
||||
if bi < audio_frame_num:
|
||||
|
||||
chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
|
||||
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
|
||||
chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx]
|
||||
|
||||
if return_all_layers:
|
||||
frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)
|
||||
else:
|
||||
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
|
||||
else:
|
||||
frame_audio_embed = \
|
||||
torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
|
||||
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
|
||||
batch_audio_eb.append(frame_audio_embed)
|
||||
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)
|
||||
|
||||
return batch_audio_eb, min_batch_num
|
||||
|
||||
def get_audio_feats_per_inference(self, input_audio, sample_rate, processor, fps=16, batch_frames=80, m=0, dtype=torch.float32, device='cpu'):
|
||||
audio_feat = self.extract_audio_feat(input_audio, sample_rate, processor, return_all_layers=True, dtype=dtype, device=device)
|
||||
audio_embed_bucket, min_batch_num = self.get_audio_embed_bucket_fps(audio_feat, fps=fps, batch_frames=batch_frames, m=m)
|
||||
audio_embed_bucket = audio_embed_bucket.unsqueeze(0).permute(0, 2, 3, 1).to(device, dtype)
|
||||
audio_embeds = [audio_embed_bucket[..., i * batch_frames:(i + 1) * batch_frames] for i in range(min_batch_num)]
|
||||
return audio_embeds
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanS2VAudioEncoderStateDictConverter()
|
||||
|
||||
|
||||
class WanS2VAudioEncoderStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = {'model.' + k: v for k, v in state_dict.items()}
|
||||
return state_dict
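# Hedged end-to-end sketch of the audio feature path; the feature-extractor id and the raw audio
# array are placeholders, and a 16 kHz mono waveform is assumed.
if __name__ == "__main__":
    from transformers import Wav2Vec2FeatureExtractor
    extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-xlsr-53")
    encoder = WanS2VAudioEncoder().eval()
    waveform = np.zeros(16000 * 4, dtype=np.float32)  # placeholder: 4 seconds of silence
    chunks = encoder.get_audio_feats_per_inference(waveform, 16000, extractor, fps=16, batch_frames=80)
    # `chunks` is a list of audio-embedding tensors, one per 80-frame video segment.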
|
||||
@@ -762,7 +762,7 @@ def lets_dance_flux(
     hidden_states = dit.x_embedder(hidden_states)

     if entity_prompt_emb is not None and entity_masks is not None:
-        prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
+        prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, 16)
     else:
         prompt_emb = dit.context_embedder(prompt_emb)
         image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

@@ -1233,7 +1233,7 @@ def model_fn_flux_image(
     # EliGen
     if entity_prompt_emb is not None and entity_masks is not None:
-        prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
+        prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, latents.shape[1])
     else:
         prompt_emb = dit.context_embedder(prompt_emb)
         image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
@@ -4,18 +4,46 @@ from typing import Union
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
|
||||
from ..models import ModelManager, load_state_dict
|
||||
from ..models.qwen_image_dit import QwenImageDiT
|
||||
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from ..models.qwen_image_vae import QwenImageVAE
|
||||
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora import GeneralLoRALoader
|
||||
from .flux_image_new import ControlNetInput
|
||||
|
||||
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
|
||||
|
||||
class QwenImageBlockwiseMultiControlNet(torch.nn.Module):
|
||||
def __init__(self, models: list[QwenImageBlockWiseControlNet]):
|
||||
super().__init__()
|
||||
if not isinstance(models, list):
|
||||
models = [models]
|
||||
self.models = torch.nn.ModuleList(models)
|
||||
|
||||
def preprocess(self, controlnet_inputs: list[ControlNetInput], conditionings: list[torch.Tensor], **kwargs):
|
||||
processed_conditionings = []
|
||||
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
|
||||
conditioning = rearrange(conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
||||
model_output = self.models[controlnet_input.controlnet_id].process_controlnet_conditioning(conditioning)
|
||||
processed_conditionings.append(model_output)
|
||||
return processed_conditionings
|
||||
|
||||
def blockwise_forward(self, image, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, block_id, **kwargs):
|
||||
res = 0
|
||||
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
|
||||
progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
|
||||
if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4):
|
||||
continue
|
||||
model_output = self.models[controlnet_input.controlnet_id].blockwise_forward(image, conditioning, block_id)
|
||||
res = res + model_output * controlnet_input.scale
|
||||
return res
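# Hedged call-time sketch (not part of the diff): ControlNetInput's scale/start/end fields are read
# above; the `image` field name and the file path are assumptions for illustration only.
#   canny = Image.open("canny.png")
#   image = pipe(
#       prompt="a city street at night",
#       blockwise_controlnet_inputs=[ControlNetInput(image=canny, scale=0.8, start=0.0, end=1.0)],
#       seed=0,
#   )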
|
||||
|
||||
|
||||
class QwenImagePipeline(BasePipeline):
|
||||
|
||||
@@ -24,36 +52,97 @@ class QwenImagePipeline(BasePipeline):
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
-        from transformers import Qwen2Tokenizer
+        from transformers import Qwen2Tokenizer, Qwen2VLProcessor
|
||||
|
||||
self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
|
||||
self.text_encoder: QwenImageTextEncoder = None
|
||||
self.dit: QwenImageDiT = None
|
||||
self.vae: QwenImageVAE = None
|
||||
self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None
|
||||
self.tokenizer: Qwen2Tokenizer = None
|
||||
self.processor: Qwen2VLProcessor = None
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.in_iteration_models = ("dit", "blockwise_controlnet")
|
||||
self.units = [
|
||||
QwenImageUnit_ShapeChecker(),
|
||||
QwenImageUnit_NoiseInitializer(),
|
||||
QwenImageUnit_InputImageEmbedder(),
|
||||
QwenImageUnit_Inpaint(),
|
||||
QwenImageUnit_EditImageEmbedder(),
|
||||
QwenImageUnit_ContextImageEmbedder(),
|
||||
QwenImageUnit_PromptEmbedder(),
|
||||
QwenImageUnit_EntityControl(),
|
||||
QwenImageUnit_BlockwiseControlNet(),
|
||||
]
|
||||
self.model_fn = model_fn_qwen_image
|
||||
|
||||
|
||||
def load_lora(self, module, path, alpha=1):
|
||||
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
|
||||
lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
loader.load(module, lora, alpha=alpha)
|
||||
def load_lora(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
lora_config: Union[ModelConfig, str] = None,
|
||||
alpha=1,
|
||||
hotload=False,
|
||||
state_dict=None,
|
||||
):
|
||||
if state_dict is None:
|
||||
if isinstance(lora_config, str):
|
||||
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
lora_config.download_if_necessary()
|
||||
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
lora = state_dict
|
||||
if hotload:
|
||||
for name, module in module.named_modules():
|
||||
if isinstance(module, AutoWrappedLinear):
|
||||
lora_a_name = f'{name}.lora_A.default.weight'
|
||||
lora_b_name = f'{name}.lora_B.default.weight'
|
||||
if lora_a_name in lora and lora_b_name in lora:
|
||||
module.lora_A_weights.append(lora[lora_a_name] * alpha)
|
||||
module.lora_B_weights.append(lora[lora_b_name])
|
||||
else:
|
||||
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
|
||||
loader.load(module, lora, alpha=alpha)
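# Hedged usage sketch (the file name is a placeholder): with hotload=True the LoRA weights are
# attached to the AutoWrappedLinear layers prepared by enable_lora_magic(), so they can later be
# dropped with clear_lora() without reloading the base model.
#   pipe.enable_lora_magic()
#   pipe.load_lora(pipe.dit, "my_qwen_image_lora.safetensors", alpha=1.0, hotload=True)
#   image = pipe(prompt="a cat in watercolor style", seed=0)
#   pipe.clear_lora()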
|
||||
|
||||
|
||||
def clear_lora(self):
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, AutoWrappedLinear):
|
||||
if hasattr(module, "lora_A_weights"):
|
||||
module.lora_A_weights.clear()
|
||||
if hasattr(module, "lora_B_weights"):
|
||||
module.lora_B_weights.clear()
|
||||
|
||||
|
||||
def enable_lora_magic(self):
|
||||
if self.dit is not None:
|
||||
if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled):
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device=self.device,
|
||||
onload_dtype=dtype,
|
||||
onload_device=self.device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=None,
|
||||
)
|
||||
|
||||
|
||||
def training_loss(self, **inputs):
|
||||
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
||||
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
|
||||
training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
|
||||
noise = torch.randn_like(inputs["input_latents"])
|
||||
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||
training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||
|
||||
noise_pred = self.model_fn(**inputs, timestep=timestep)
|
||||
|
||||
@@ -62,16 +151,58 @@ class QwenImagePipeline(BasePipeline):
|
||||
return loss
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
|
||||
self.vram_management_enabled = True
|
||||
if num_persistent_param_in_dit is not None:
|
||||
vram_limit = None
|
||||
else:
|
||||
if vram_limit is None:
|
||||
vram_limit = self.get_vram()
|
||||
vram_limit = vram_limit - vram_buffer
|
||||
def direct_distill_loss(self, **inputs):
|
||||
self.scheduler.set_timesteps(inputs["num_inference_steps"])
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(self.scheduler.timesteps):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
|
||||
inputs["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
|
||||
loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
|
||||
return loss
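# Hedged training-step sketch (the batch-preparation helper is hypothetical): direct_distill_loss
# runs the full sampling loop and regresses the final latents onto the reference latents.
#   inputs = prepare_distill_inputs(batch)   # hypothetical helper yielding latents, prompt_emb, input_latents, ...
#   loss = pipe.direct_distill_loss(**inputs, num_inference_steps=4)
#   loss.backward(); optimizer.step(); optimizer.zero_grad()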
|
||||
|
||||
|
||||
def _enable_fp8_lora_training(self, dtype):
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding
|
||||
from ..models.qwen_image_dit import RMSNorm
|
||||
from ..models.qwen_image_vae import QwenImageRMS_norm
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
|
||||
Qwen2RMSNorm: AutoWrappedModule,
|
||||
Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
|
||||
Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
|
||||
QwenImageRMS_norm: AutoWrappedModule,
|
||||
}
|
||||
model_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cuda",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cuda",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device="cuda",
|
||||
)
|
||||
if self.text_encoder is not None:
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
|
||||
enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
|
||||
if self.dit is not None:
|
||||
enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
|
||||
if self.vae is not None:
|
||||
enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, auto_offload=True, enable_dit_fp8_computation=False):
|
||||
self.vram_management_enabled = True
|
||||
if vram_limit is None and auto_offload:
|
||||
vram_limit = self.get_vram()
|
||||
if vram_limit is not None:
|
||||
vram_limit = vram_limit - vram_buffer
|
||||
|
||||
if self.text_encoder is not None:
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding
|
||||
dtype = next(iter(self.text_encoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder,
|
||||
@@ -80,6 +211,8 @@ class QwenImagePipeline(BasePipeline):
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
|
||||
Qwen2RMSNorm: AutoWrappedModule,
|
||||
Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
|
||||
Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -95,31 +228,54 @@ class QwenImagePipeline(BasePipeline):
|
||||
from ..models.qwen_image_dit import RMSNorm
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
device = "cpu" if vram_limit is not None else self.device
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
max_num_param=num_persistent_param_in_dit,
|
||||
overflow_module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if not enable_dit_fp8_computation:
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
else:
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=device,
|
||||
computation_dtype=dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.vae is not None:
|
||||
from ..models.qwen_image_vae import QwenImageRMS_norm
|
||||
dtype = next(iter(self.vae.parameters())).dtype
|
||||
@@ -141,6 +297,23 @@ class QwenImagePipeline(BasePipeline):
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.blockwise_controlnet is not None:
|
||||
enable_vram_management(
|
||||
self.blockwise_controlnet,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -149,6 +322,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
processor_config: ModelConfig = None,
|
||||
):
|
||||
# Download and load models
|
||||
model_manager = ModelManager()
|
||||
@@ -165,10 +339,15 @@ class QwenImagePipeline(BasePipeline):
|
||||
pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
|
||||
pipe.dit = model_manager.fetch_model("qwen_image_dit")
|
||||
pipe.vae = model_manager.fetch_model("qwen_image_vae")
|
||||
pipe.blockwise_controlnet = QwenImageBlockwiseMultiControlNet(model_manager.fetch_model("qwen_image_blockwise_controlnet", index="all"))
|
||||
if tokenizer_config is not None and pipe.text_encoder is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
from transformers import Qwen2Tokenizer
|
||||
pipe.tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_config.path)
|
||||
if processor_config is not None:
|
||||
processor_config.download_if_necessary()
|
||||
from transformers import Qwen2VLProcessor
|
||||
pipe.processor = Qwen2VLProcessor.from_pretrained(processor_config.path)
|
||||
return pipe
|
||||
|
||||
|
||||
@@ -182,6 +361,10 @@ class QwenImagePipeline(BasePipeline):
|
||||
# Image
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Inpaint
|
||||
inpaint_mask: Image.Image = None,
|
||||
inpaint_blur_size: int = None,
|
||||
inpaint_blur_sigma: float = None,
|
||||
# Shape
|
||||
height: int = 1328,
|
||||
width: int = 1328,
|
||||
@@ -190,6 +373,21 @@ class QwenImagePipeline(BasePipeline):
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 30,
|
||||
exponential_shift_mu: float = None,
|
||||
# Blockwise ControlNet
|
||||
blockwise_controlnet_inputs: list[ControlNetInput] = None,
|
||||
# EliGen
|
||||
eligen_entity_prompts: list[str] = None,
|
||||
eligen_entity_masks: list[Image.Image] = None,
|
||||
eligen_enable_on_negative: bool = False,
|
||||
# Qwen-Image-Edit
|
||||
edit_image: Image.Image = None,
|
||||
edit_image_auto_resize: bool = True,
|
||||
edit_rope_interpolation: bool = False,
|
||||
# In-context control
|
||||
context_image: Image.Image = None,
|
||||
# FP8
|
||||
enable_fp8_attention: bool = False,
|
||||
# Tile
|
||||
tiled: bool = False,
|
||||
tile_size: int = 128,
|
||||
@@ -198,7 +396,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
# Scheduler
|
||||
-        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16))
+        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {
|
||||
@@ -210,9 +408,16 @@ class QwenImagePipeline(BasePipeline):
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||
"inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"enable_fp8_attention": enable_fp8_attention,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"blockwise_controlnet_inputs": blockwise_controlnet_inputs,
|
||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
|
||||
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation,
|
||||
"context_image": context_image,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -232,7 +437,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# Scheduler
|
||||
inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
@@ -281,16 +486,35 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents, "input_latents": None}
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
|
||||
class QwenImageUnit_Inpaint(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("inpaint_mask", "height", "width", "inpaint_blur_size", "inpaint_blur_sigma"),
        )

    def process(self, pipe: QwenImagePipeline, inpaint_mask, height, width, inpaint_blur_size, inpaint_blur_sigma):
        if inpaint_mask is None:
            return {}
        inpaint_mask = pipe.preprocess_image(inpaint_mask.convert("RGB").resize((width // 8, height // 8)), min_value=0, max_value=1)
        inpaint_mask = inpaint_mask.mean(dim=1, keepdim=True)
        if inpaint_blur_size is not None and inpaint_blur_sigma is not None:
            from torchvision.transforms import GaussianBlur
            blur = GaussianBlur(kernel_size=inpaint_blur_size * 2 + 1, sigma=inpaint_blur_sigma)
            inpaint_mask = blur(inpaint_mask)
        return {"inpaint_mask": inpaint_mask}

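# Minimal standalone sketch of the soft inpaint-mask preprocessing above, outside the
# pipeline: resize the mask to the latent grid (1/8 resolution), average it to one
# channel, then optionally feather it with a Gaussian blur so the inpainted region
# blends smoothly. make_soft_mask is a hypothetical helper, not part of the pipeline API.
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import GaussianBlur

def make_soft_mask(mask: Image.Image, height: int, width: int, blur_size: int = 8, blur_sigma: float = 4.0) -> torch.Tensor:
    mask = mask.convert("RGB").resize((width // 8, height // 8))
    mask = torch.from_numpy(np.array(mask)).float().permute(2, 0, 1).unsqueeze(0) / 255.0  # (1, 3, H/8, W/8) in [0, 1]
    mask = mask.mean(dim=1, keepdim=True)                                                  # (1, 1, H/8, W/8)
    if blur_size is not None and blur_sigma is not None:
        mask = GaussianBlur(kernel_size=blur_size * 2 + 1, sigma=blur_sigma)(mask)
    return mask

soft_mask = make_soft_mask(Image.new("L", (1328, 1328), 255), height=1328, width=1328)
print(soft_mask.shape)  # torch.Size([1, 1, 166, 166])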
class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
input_params=("edit_image",),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
@@ -300,8 +524,88 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
        selected = hidden_states[bool_mask]
        split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
        return split_result

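# Tiny illustration of extract_masked_hidden above: padded hidden states plus an
# attention mask become one variable-length tensor per sample, with padding removed.
# The numbers are arbitrary and only show the shapes involved.
import torch

hidden = torch.arange(2 * 4 * 3, dtype=torch.float32).view(2, 4, 3)  # (batch=2, seq=4, dim=3)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
bool_mask = mask.bool()
valid_lengths = bool_mask.sum(dim=1)                                  # tensor([3, 2])
pieces = torch.split(hidden[bool_mask], valid_lengths.tolist(), dim=0)
print([p.shape for p in pieces])                                      # [torch.Size([3, 3]), torch.Size([2, 3])]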
    def calculate_dimensions(self, target_area, ratio):
        import math
        width = math.sqrt(target_area * ratio)
        height = width / ratio
        width = round(width / 32) * 32
        height = round(height / 32) * 32
        return width, height

    def resize_image(self, image, target_area=384*384):
        width, height = self.calculate_dimensions(target_area, image.size[0] / image.size[1])
        return image.resize((width, height))

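# Worked example for the target-area resize above: the aspect ratio is preserved while
# both sides are rounded to multiples of 32. For a 16:9 input and a 384*384 target area,
# sqrt(147456 * 16/9) = 512 and 512 / (16/9) = 288, both already multiples of 32.
import math

target_area, ratio = 384 * 384, 16 / 9
width = round(math.sqrt(target_area * ratio) / 32) * 32
height = round((math.sqrt(target_area * ratio) / ratio) / 32) * 32
print(width, height)  # 512 288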
    def encode_prompt(self, pipe: QwenImagePipeline, prompt):
        template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
        drop_idx = 34
        txt = [template.format(e) for e in prompt]
        model_inputs = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
        if model_inputs.input_ids.shape[1] >= 1024:
            print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {model_inputs['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.")
        hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, output_hidden_states=True,)[-1]
        split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
        return split_hidden_states

|
||||
def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image):
|
||||
template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 64
|
||||
txt = [template.format(e) for e in prompt]
|
||||
model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
|
||||
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
|
||||
split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
return split_hidden_states
|
||||
|
||||
def encode_prompt_edit_multi(self, pipe: QwenImagePipeline, prompt, edit_image):
|
||||
template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 64
|
||||
img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
|
||||
base_img_prompt = "".join([img_prompt_template.format(i + 1) for i in range(len(edit_image))])
|
||||
txt = [template.format(base_img_prompt + e) for e in prompt]
|
||||
edit_image = [self.resize_image(image) for image in edit_image]
|
||||
model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
|
||||
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
|
||||
split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
return split_hidden_states
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, prompt) -> dict:
|
||||
def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
|
||||
if pipe.text_encoder is not None:
|
||||
prompt = [prompt]
|
||||
if edit_image is None:
|
||||
split_hidden_states = self.encode_prompt(pipe, prompt)
|
||||
elif isinstance(edit_image, Image.Image):
|
||||
split_hidden_states = self.encode_prompt_edit(pipe, prompt, edit_image)
|
||||
else:
|
||||
split_hidden_states = self.encode_prompt_edit_multi(pipe, prompt, edit_image)
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
|
||||
encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
|
||||
prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask}
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
class QwenImageUnit_EntityControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
||||
bool_mask = mask.bool()
|
||||
valid_lengths = bool_mask.sum(dim=1)
|
||||
selected = hidden_states[bool_mask]
|
||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||
return split_result
|
||||
|
||||
def get_prompt_emb(self, pipe: QwenImagePipeline, prompt) -> dict:
|
||||
if pipe.text_encoder is not None:
|
||||
prompt = [prompt]
|
||||
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
@@ -321,18 +625,174 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
||||
else:
|
||||
return {}
|
||||
|
||||
def preprocess_masks(self, pipe, masks, height, width, dim):
|
||||
out_masks = []
|
||||
for mask in masks:
|
||||
mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0
|
||||
mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
out_masks.append(mask)
|
||||
return out_masks
|
||||
|
||||
def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height):
|
||||
entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1)
|
||||
entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w
|
||||
prompt_embs, prompt_emb_masks = [], []
|
||||
for entity_prompt in entity_prompts:
|
||||
prompt_emb_dict = self.get_prompt_emb(pipe, entity_prompt)
|
||||
prompt_embs.append(prompt_emb_dict['prompt_emb'])
|
||||
prompt_emb_masks.append(prompt_emb_dict['prompt_emb_mask'])
|
||||
return prompt_embs, prompt_emb_masks, entity_masks
|
||||
|
||||
def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, enable_eligen_on_negative, cfg_scale):
|
||||
entity_prompt_emb_posi, entity_prompt_emb_posi_mask, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height)
|
||||
if enable_eligen_on_negative and cfg_scale != 1.0:
|
||||
entity_prompt_emb_nega = [prompt_emb_nega['prompt_emb']] * len(entity_prompt_emb_posi)
|
||||
entity_prompt_emb_nega_mask = [prompt_emb_nega['prompt_emb_mask']] * len(entity_prompt_emb_posi)
|
||||
entity_masks_nega = entity_masks_posi
|
||||
else:
|
||||
entity_prompt_emb_nega, entity_prompt_emb_nega_mask, entity_masks_nega = None, None, None
|
||||
eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi, "entity_prompt_emb_mask": entity_prompt_emb_posi_mask}
|
||||
eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega, "entity_prompt_emb_mask": entity_prompt_emb_nega_mask}
|
||||
return eligen_kwargs_posi, eligen_kwargs_nega
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None)
|
||||
if eligen_entity_prompts is None or eligen_entity_masks is None or len(eligen_entity_prompts) == 0 or len(eligen_entity_masks) == 0:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False)
|
||||
eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega,
|
||||
eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"],
|
||||
eligen_enable_on_negative, inputs_shared["cfg_scale"])
|
||||
inputs_posi.update(eligen_kwargs_posi)
|
||||
if inputs_shared.get("cfg_scale", 1.0) != 1.0:
|
||||
inputs_nega.update(eligen_kwargs_nega)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
|
||||
class QwenImageUnit_BlockwiseControlNet(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("blockwise_controlnet_inputs", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def apply_controlnet_mask_on_latents(self, pipe, latents, mask):
|
||||
mask = (pipe.preprocess_image(mask) + 1) / 2
|
||||
mask = mask.mean(dim=1, keepdim=True)
|
||||
mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
|
||||
latents = torch.concat([latents, mask], dim=1)
|
||||
return latents
|
||||
|
||||
def apply_controlnet_mask_on_image(self, pipe, image, mask):
|
||||
mask = mask.resize(image.size)
|
||||
mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()
|
||||
image = np.array(image)
|
||||
image[mask > 0] = 0
|
||||
image = Image.fromarray(image)
|
||||
return image
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, blockwise_controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
|
||||
if blockwise_controlnet_inputs is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
conditionings = []
|
||||
for controlnet_input in blockwise_controlnet_inputs:
|
||||
image = controlnet_input.image
|
||||
if controlnet_input.inpaint_mask is not None:
|
||||
image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)
|
||||
|
||||
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
image = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
if controlnet_input.inpaint_mask is not None:
|
||||
image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
|
||||
conditionings.append(image)
|
||||
|
||||
return {"blockwise_controlnet_conditioning": conditionings}
|
||||
|
||||
|
||||
class QwenImageUnit_EditImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("edit_image", "tiled", "tile_size", "tile_stride", "edit_image_auto_resize"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
|
||||
def calculate_dimensions(self, target_area, ratio):
|
||||
import math
|
||||
width = math.sqrt(target_area * ratio)
|
||||
height = width / ratio
|
||||
width = round(width / 32) * 32
|
||||
height = round(height / 32) * 32
|
||||
return width, height
|
||||
|
||||
|
||||
def edit_image_auto_resize(self, edit_image):
|
||||
calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
|
||||
return edit_image.resize((calculated_width, calculated_height))
|
||||
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, edit_image, tiled, tile_size, tile_stride, edit_image_auto_resize=False):
|
||||
if edit_image is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
if isinstance(edit_image, Image.Image):
|
||||
resized_edit_image = self.edit_image_auto_resize(edit_image) if edit_image_auto_resize else edit_image
|
||||
edit_image = pipe.preprocess_image(resized_edit_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
else:
|
||||
resized_edit_image, edit_latents = [], []
|
||||
for image in edit_image:
|
||||
if edit_image_auto_resize:
|
||||
image = self.edit_image_auto_resize(image)
|
||||
resized_edit_image.append(image)
|
||||
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
edit_latents.append(latents)
|
||||
return {"edit_latents": edit_latents, "edit_image": resized_edit_image}
|
||||
|
||||
|
||||
class QwenImageUnit_ContextImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride):
|
||||
if context_image is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
context_image = pipe.preprocess_image(context_image.resize((width, height))).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
context_latents = pipe.vae.encode(context_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return {"context_latents": context_latents}
|
||||
|
||||
|
||||
def model_fn_qwen_image(
|
||||
dit: QwenImageDiT = None,
|
||||
blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_emb=None,
|
||||
prompt_emb_mask=None,
|
||||
height=None,
|
||||
width=None,
|
||||
blockwise_controlnet_conditioning=None,
|
||||
blockwise_controlnet_inputs=None,
|
||||
progress_id=0,
|
||||
num_inference_steps=1,
|
||||
entity_prompt_emb=None,
|
||||
entity_prompt_emb_mask=None,
|
||||
entity_masks=None,
|
||||
edit_latents=None,
|
||||
context_latents=None,
|
||||
enable_fp8_attention=False,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
edit_rope_interpolation=False,
|
||||
**kwargs
|
||||
):
|
||||
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
|
||||
@@ -340,13 +800,39 @@ def model_fn_qwen_image(
|
||||
timestep = timestep / 1000
|
||||
|
||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
|
||||
image = dit.img_in(image)
|
||||
text = dit.txt_in(dit.txt_norm(prompt_emb))
|
||||
conditioning = dit.time_text_embed(timestep, image.dtype)
|
||||
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
image_seq_len = image.shape[1]
|
||||
|
||||
for block in dit.transformer_blocks:
|
||||
if context_latents is not None:
|
||||
img_shapes += [(context_latents.shape[0], context_latents.shape[2]//2, context_latents.shape[3]//2)]
|
||||
context_image = rearrange(context_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=context_latents.shape[2]//2, W=context_latents.shape[3]//2, P=2, Q=2)
|
||||
image = torch.cat([image, context_image], dim=1)
|
||||
if edit_latents is not None:
|
||||
edit_latents_list = edit_latents if isinstance(edit_latents, list) else [edit_latents]
|
||||
img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list]
|
||||
edit_image = [rearrange(e, "B C (H P) (W Q) -> B (H W) (C P Q)", H=e.shape[2]//2, W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list]
|
||||
image = torch.cat([image] + edit_image, dim=1)
|
||||
|
||||
image = dit.img_in(image)
|
||||
conditioning = dit.time_text_embed(timestep, image.dtype)
|
||||
|
||||
if entity_prompt_emb is not None:
|
||||
text, image_rotary_emb, attention_mask = dit.process_entity_masks(
|
||||
latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask,
|
||||
entity_masks, height, width, image, img_shapes,
|
||||
)
|
||||
else:
|
||||
text = dit.txt_in(dit.txt_norm(prompt_emb))
|
||||
if edit_rope_interpolation:
|
||||
image_rotary_emb = dit.pos_embed.forward_sampling(img_shapes, txt_seq_lens, device=latents.device)
|
||||
else:
|
||||
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
attention_mask = None
|
||||
|
||||
if blockwise_controlnet_conditioning is not None:
|
||||
blockwise_controlnet_conditioning = blockwise_controlnet.preprocess(
|
||||
blockwise_controlnet_inputs, blockwise_controlnet_conditioning)
|
||||
|
||||
for block_id, block in enumerate(dit.transformer_blocks):
|
||||
text, image = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
@@ -355,10 +841,21 @@ def model_fn_qwen_image(
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
enable_fp8_attention=enable_fp8_attention,
|
||||
)
|
||||
        if blockwise_controlnet_conditioning is not None:
            image_slice = image[:, :image_seq_len].clone()
            controlnet_output = blockwise_controlnet.blockwise_forward(
                image=image_slice, conditionings=blockwise_controlnet_conditioning,
                controlnet_inputs=blockwise_controlnet_inputs, block_id=block_id,
                progress_id=progress_id, num_inference_steps=num_inference_steps,
            )
            image[:, :image_seq_len] = image_slice + controlnet_output

    image = dit.norm_out(image, conditioning)
    image = dit.proj_out(image)
    image = image[:, :image_seq_len]

    latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
    return latents

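# Shape sanity check for the 2x2 latent packing used in model_fn_qwen_image above
# (illustrative sizes; the latent channel count of 16 is an assumption): a (B, C, H/8, W/8)
# latent becomes (B, H/16 * W/16, C*4) tokens and unpacks back exactly.
import torch
from einops import rearrange

height, width = 1024, 1024
latents = torch.randn(1, 16, height // 8, width // 8)
tokens = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height // 16, W=width // 16, P=2, Q=2)
print(tokens.shape)                                  # torch.Size([1, 4096, 64])
restored = rearrange(tokens, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height // 16, W=width // 16, P=2, Q=2)
print(torch.equal(latents, restored))                # True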
|
||||
@@ -15,11 +15,13 @@ from typing_extensions import Literal
|
||||
from ..utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
|
||||
from ..models import ModelManager, load_state_dict
|
||||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_dit_s2v import rope_precompute
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
|
||||
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
from ..models.wan_video_animate_adapter import WanAnimateAdapter
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from ..prompters import WanPrompter
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
|
||||
@@ -43,14 +45,17 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.vae: WanVideoVAE = None
|
||||
self.motion_controller: WanMotionControllerModel = None
|
||||
self.vace: VaceWanModel = None
|
||||
self.in_iteration_models = ("dit", "motion_controller", "vace")
|
||||
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
|
||||
self.vace2: VaceWanModel = None
|
||||
self.animate_adapter: WanAnimateAdapter = None
|
||||
self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter")
|
||||
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter")
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.units = [
|
||||
WanVideoUnit_ShapeChecker(),
|
||||
WanVideoUnit_NoiseInitializer(),
|
||||
WanVideoUnit_InputVideoEmbedder(),
|
||||
WanVideoUnit_PromptEmbedder(),
|
||||
WanVideoUnit_S2V(),
|
||||
WanVideoUnit_InputVideoEmbedder(),
|
||||
WanVideoUnit_ImageEmbedderVAE(),
|
||||
WanVideoUnit_ImageEmbedderCLIP(),
|
||||
WanVideoUnit_ImageEmbedderFused(),
|
||||
@@ -59,18 +64,46 @@ class WanVideoPipeline(BasePipeline):
|
||||
WanVideoUnit_FunCameraControl(),
|
||||
WanVideoUnit_SpeedControl(),
|
||||
WanVideoUnit_VACE(),
|
||||
WanVideoPostUnit_AnimateVideoSplit(),
|
||||
WanVideoPostUnit_AnimatePoseLatents(),
|
||||
WanVideoPostUnit_AnimateFacePixelValues(),
|
||||
WanVideoPostUnit_AnimateInpaint(),
|
||||
WanVideoUnit_UnifiedSequenceParallel(),
|
||||
WanVideoUnit_TeaCache(),
|
||||
WanVideoUnit_CfgMerger(),
|
||||
]
|
||||
self.post_units = [
|
||||
WanVideoPostUnit_S2V(),
|
||||
]
|
||||
self.model_fn = model_fn_wan_video
|
||||
|
||||
|
||||
    def load_lora(self, module, path, alpha=1):
        loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
        lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
        loader.load(module, lora, alpha=alpha)

    def load_lora(
        self,
        module: torch.nn.Module,
        lora_config: Union[ModelConfig, str] = None,
        alpha=1,
        hotload=False,
        state_dict=None,
    ):
        if state_dict is None:
            if isinstance(lora_config, str):
                lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
            else:
                lora_config.download_if_necessary()
                lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
        else:
            lora = state_dict
        if hotload:
            for name, module in module.named_modules():
                if isinstance(module, AutoWrappedLinear):
                    lora_a_name = f'{name}.lora_A.default.weight'
                    lora_b_name = f'{name}.lora_B.default.weight'
                    if lora_a_name in lora and lora_b_name in lora:
                        module.lora_A_weights.append(lora[lora_a_name] * alpha)
                        module.lora_B_weights.append(lora[lora_b_name])
        else:
            loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
            loader.load(module, lora, alpha=alpha)

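# Hedged usage sketch of the extended load_lora above: a LoRA can either be merged into
# the base weights (default) or hot-loaded as detachable A/B pairs. The file path is an
# illustrative assumption, and hotload=True assumes VRAM management has already wrapped
# the Linear layers as AutoWrappedLinear.
lora_path = "models/lora/my_wan_lora.safetensors"                         # assumption
pipe.load_lora(pipe.dit, lora_config=lora_path, alpha=1.0)                # merge into weights
pipe.load_lora(pipe.dit, lora_config=lora_path, alpha=1.0, hotload=True)  # keep as swappable A/B pairs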
|
||||
def training_loss(self, **inputs):
|
||||
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
|
||||
@@ -127,6 +160,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
torch.nn.LayerNorm: WanAutoCastLayerNorm,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.Conv1d: AutoWrappedModule,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -254,6 +289,25 @@ class WanVideoPipeline(BasePipeline):
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.audio_encoder is not None:
|
||||
# TODO: need check
|
||||
dtype = next(iter(self.audio_encoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.audio_encoder,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
torch.nn.Conv1d: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def initialize_usp(self):
|
||||
@@ -290,6 +344,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
|
||||
audio_processor_config: ModelConfig = None,
|
||||
redirect_common_files: bool = True,
|
||||
use_usp=False,
|
||||
):
|
||||
@@ -331,8 +386,14 @@ class WanVideoPipeline(BasePipeline):
|
||||
pipe.vae = model_manager.fetch_model("wan_video_vae")
|
||||
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||
pipe.vace = model_manager.fetch_model("wan_video_vace")
|
||||
|
||||
vace = model_manager.fetch_model("wan_video_vace", index=2)
|
||||
if isinstance(vace, list):
|
||||
pipe.vace, pipe.vace2 = vace
|
||||
else:
|
||||
pipe.vace = vace
|
||||
pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder")
|
||||
pipe.animate_adapter = model_manager.fetch_model("wan_video_animate_adapter")
|
||||
|
||||
# Size division factor
|
||||
if pipe.vae is not None:
|
||||
pipe.height_division_factor = pipe.vae.upsampling_factor * 2
|
||||
@@ -342,7 +403,11 @@ class WanVideoPipeline(BasePipeline):
|
||||
tokenizer_config.download_if_necessary(use_usp=use_usp)
|
||||
pipe.prompter.fetch_models(pipe.text_encoder)
|
||||
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
|
||||
|
||||
|
||||
if audio_processor_config is not None:
|
||||
audio_processor_config.download_if_necessary(use_usp=use_usp)
|
||||
from transformers import Wav2Vec2Processor
|
||||
pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path)
|
||||
# Unified Sequence Parallel
|
||||
if use_usp: pipe.enable_usp()
|
||||
return pipe
|
||||
@@ -361,6 +426,13 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Video-to-video
|
||||
input_video: Optional[list[Image.Image]] = None,
|
||||
denoising_strength: Optional[float] = 1.0,
|
||||
# Speech-to-video
|
||||
input_audio: Optional[np.array] = None,
|
||||
audio_embeds: Optional[torch.Tensor] = None,
|
||||
audio_sample_rate: Optional[int] = 16000,
|
||||
s2v_pose_video: Optional[list[Image.Image]] = None,
|
||||
s2v_pose_latents: Optional[torch.Tensor] = None,
|
||||
motion_video: Optional[list[Image.Image]] = None,
|
||||
# ControlNet
|
||||
control_video: Optional[list[Image.Image]] = None,
|
||||
reference_image: Optional[Image.Image] = None,
|
||||
@@ -373,6 +445,11 @@ class WanVideoPipeline(BasePipeline):
|
||||
vace_video_mask: Optional[Image.Image] = None,
|
||||
vace_reference_image: Optional[Image.Image] = None,
|
||||
vace_scale: Optional[float] = 1.0,
|
||||
# Animate
|
||||
animate_pose_video: Optional[list[Image.Image]] = None,
|
||||
animate_face_video: Optional[list[Image.Image]] = None,
|
||||
animate_inpaint_video: Optional[list[Image.Image]] = None,
|
||||
animate_mask_video: Optional[list[Image.Image]] = None,
|
||||
# Randomness
|
||||
seed: Optional[int] = None,
|
||||
rand_device: Optional[str] = "cpu",
|
||||
@@ -429,6 +506,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
"motion_bucket_id": motion_bucket_id,
|
||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
|
||||
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
|
||||
"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -441,6 +520,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2:
|
||||
self.load_models_to_device(self.in_iteration_models_2)
|
||||
models["dit"] = self.dit2
|
||||
models["vace"] = self.vace2
|
||||
|
||||
# Timestep
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
@@ -462,9 +542,15 @@ class WanVideoPipeline(BasePipeline):
|
||||
inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"]
|
||||
|
||||
# VACE (TODO: remove it)
|
||||
if vace_reference_image is not None:
|
||||
inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
|
||||
|
||||
if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None):
|
||||
if vace_reference_image is not None and isinstance(vace_reference_image, list):
|
||||
f = len(vace_reference_image)
|
||||
else:
|
||||
f = 1
|
||||
inputs_shared["latents"] = inputs_shared["latents"][:, :, f:]
|
||||
# post-denoising, pre-decoding processing logic
|
||||
for unit in self.post_units:
|
||||
inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
@@ -492,11 +578,12 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
    def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image):
        length = (num_frames - 1) // 4 + 1
        if vace_reference_image is not None:
            length += 1
            f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1
            length += f
        shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)
        noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
        if vace_reference_image is not None:
            noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
            noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2)
        return {"noise": noise}

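# Illustration of the reference-frame reordering above: with f reference frames appended
# to the noise, the last f frames are moved to the front so they line up with the
# reference latents that are concatenated at the start of the sequence.
import torch

noise = torch.arange(6).view(1, 1, 6, 1, 1)                    # six "frames" labelled 0..5
f = 2
reordered = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2)
print(reordered.flatten().tolist())                            # [4, 5, 0, 1, 2, 3]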
|
||||
|
||||
@@ -515,7 +602,9 @@ class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
|
||||
input_video = pipe.preprocess_video(input_video)
|
||||
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
if vace_reference_image is not None:
|
||||
vace_reference_image = pipe.preprocess_video([vace_reference_image])
|
||||
if not isinstance(vace_reference_image, list):
|
||||
vace_reference_image = [vace_reference_image]
|
||||
vace_reference_image = pipe.preprocess_video(vace_reference_image)
|
||||
vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
input_latents = torch.concat([vace_reference_latents, input_latents], dim=2)
|
||||
if pipe.scheduler.training:
|
||||
@@ -663,22 +752,23 @@ class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
class WanVideoUnit_FunControl(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
            input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
    def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents):
        if control_video is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        control_video = pipe.preprocess_video(control_video)
        control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
        control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
        y_dim = pipe.dit.in_dim-control_latents.shape[1]-latents.shape[1]
        if clip_feature is None or y is None:
            clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
            y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
            y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
        else:
            y = y[:, -16:]
            y = y[:, -y_dim:]
        y = torch.concat([control_latents, y], dim=1)
        return {"clip_feature": clip_feature, "y": y}

||||
@@ -698,6 +788,8 @@ class WanVideoUnit_FunReference(PipelineUnit):
|
||||
reference_image = reference_image.resize((width, height))
|
||||
reference_latents = pipe.preprocess_video([reference_image])
|
||||
reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
|
||||
if pipe.image_encoder is None:
|
||||
return {"reference_latents": reference_latents}
|
||||
clip_feature = pipe.preprocess_image(reference_image)
|
||||
clip_feature = pipe.image_encoder.encode_image([clip_feature])
|
||||
return {"reference_latents": reference_latents, "clip_feature": clip_feature}
|
||||
@@ -707,13 +799,14 @@ class WanVideoUnit_FunReference(PipelineUnit):
|
||||
class WanVideoUnit_FunCameraControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image"),
|
||||
input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image):
|
||||
def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride):
|
||||
if camera_control_direction is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates(
|
||||
camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin)
|
||||
|
||||
@@ -728,14 +821,27 @@ class WanVideoUnit_FunCameraControl(PipelineUnit):
|
||||
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
|
||||
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
|
||||
control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
|
||||
|
||||
input_image = input_image.resize((width, height))
|
||||
input_latents = pipe.preprocess_video([input_image])
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_latents = pipe.vae.encode(input_latents, device=pipe.device)
|
||||
y = torch.zeros_like(latents).to(pipe.device)
|
||||
y[:, :, :1] = input_latents
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
if y.shape[1] != pipe.dit.in_dim - latents.shape[1]:
|
||||
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
|
||||
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
||||
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
|
||||
msk[:, 1:] = 0
|
||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
y = torch.cat([msk,y])
|
||||
y = y.unsqueeze(0)
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"control_camera_latents_input": control_camera_latents_input, "y": y}
|
||||
|
||||
|
||||
@@ -790,11 +896,23 @@ class WanVideoUnit_VACE(PipelineUnit):
|
||||
if vace_reference_image is None:
|
||||
pass
|
||||
else:
|
||||
vace_reference_image = pipe.preprocess_video([vace_reference_image])
|
||||
if not isinstance(vace_reference_image,list):
|
||||
vace_reference_image = [vace_reference_image]
|
||||
|
||||
vace_reference_image = pipe.preprocess_video(vace_reference_image)
|
||||
|
||||
bs, c, f, h, w = vace_reference_image.shape
|
||||
new_vace_ref_images = []
|
||||
for j in range(f):
|
||||
new_vace_ref_images.append(vace_reference_image[0, :, j:j+1])
|
||||
vace_reference_image = new_vace_ref_images
|
||||
|
||||
vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
|
||||
vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
|
||||
vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
|
||||
vace_reference_latents = [u.unsqueeze(0) for u in vace_reference_latents]
|
||||
|
||||
vace_video_latents = torch.concat((*vace_reference_latents, vace_video_latents), dim=2)
|
||||
vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2)
|
||||
|
||||
vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
|
||||
return {"vace_context": vace_context, "vace_scale": vace_scale}
|
||||
@@ -851,6 +969,187 @@ class WanVideoUnit_CfgMerger(PipelineUnit):
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class WanVideoUnit_S2V(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
onload_model_names=("audio_encoder", "vae",)
|
||||
)
|
||||
|
||||
def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False):
|
||||
if audio_embeds is not None:
|
||||
return {"audio_embeds": audio_embeds}
|
||||
pipe.load_models_to_device(["audio_encoder"])
|
||||
audio_embeds = pipe.audio_encoder.get_audio_feats_per_inference(input_audio, audio_sample_rate, pipe.audio_processor, fps=fps, batch_frames=num_frames-1, dtype=pipe.torch_dtype, device=pipe.device)
|
||||
if return_all:
|
||||
return audio_embeds
|
||||
else:
|
||||
return {"audio_embeds": audio_embeds[0]}
|
||||
|
||||
def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride, motion_video=None):
|
||||
pipe.load_models_to_device(["vae"])
|
||||
motion_frames = 73
|
||||
kwargs = {}
|
||||
if motion_video is not None and len(motion_video) > 0:
|
||||
assert len(motion_video) == motion_frames, f"motion video must have {motion_frames} frames, but got {len(motion_video)}"
|
||||
motion_latents = pipe.preprocess_video(motion_video)
|
||||
kwargs["drop_motion_frames"] = False
|
||||
else:
|
||||
motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
|
||||
kwargs["drop_motion_frames"] = True
|
||||
motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
kwargs.update({"motion_latents": motion_latents})
|
||||
return kwargs
|
||||
|
||||
def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=None, num_repeats=1, return_all=False):
|
||||
if s2v_pose_latents is not None:
|
||||
return {"s2v_pose_latents": s2v_pose_latents}
|
||||
if s2v_pose_video is None:
|
||||
return {"s2v_pose_latents": None}
|
||||
pipe.load_models_to_device(["vae"])
|
||||
infer_frames = num_frames - 1
|
||||
input_video = pipe.preprocess_video(s2v_pose_video)[:, :, :infer_frames * num_repeats]
|
||||
# pad if not enough frames
|
||||
padding_frames = infer_frames * num_repeats - input_video.shape[2]
|
||||
input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
|
||||
input_videos = input_video.chunk(num_repeats, dim=2)
|
||||
pose_conds = []
|
||||
for r in range(num_repeats):
|
||||
cond = input_videos[r]
|
||||
cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2)
|
||||
cond_latents = pipe.vae.encode(cond, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
pose_conds.append(cond_latents[:,:,1:])
|
||||
if return_all:
|
||||
return pose_conds
|
||||
else:
|
||||
return {"s2v_pose_latents": pose_conds[0]}
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
if (inputs_shared.get("input_audio") is None and inputs_shared.get("audio_embeds") is None) or pipe.audio_encoder is None or pipe.audio_processor is None:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
|
||||
input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio"), inputs_shared.pop("audio_embeds"), inputs_shared.get("audio_sample_rate")
|
||||
s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video"), inputs_shared.pop("s2v_pose_latents"), inputs_shared.pop("motion_video")
|
||||
|
||||
audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds)
|
||||
inputs_posi.update(audio_input_positive)
|
||||
inputs_nega.update({"audio_embeds": 0.0 * audio_input_positive["audio_embeds"]})
|
||||
|
||||
inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride, motion_video))
|
||||
inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=s2v_pose_latents))
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
@staticmethod
|
||||
def pre_calculate_audio_pose(pipe: WanVideoPipeline, input_audio=None, audio_sample_rate=16000, s2v_pose_video=None, num_frames=81, height=448, width=832, fps=16, tiled=True, tile_size=(30, 52), tile_stride=(15, 26)):
|
||||
assert pipe.audio_encoder is not None and pipe.audio_processor is not None, "Please load audio encoder and audio processor first."
|
||||
shapes = WanVideoUnit_ShapeChecker().process(pipe, height, width, num_frames)
|
||||
height, width, num_frames = shapes["height"], shapes["width"], shapes["num_frames"]
|
||||
unit = WanVideoUnit_S2V()
|
||||
audio_embeds = unit.process_audio(pipe, input_audio, audio_sample_rate, num_frames, fps, return_all=True)
|
||||
pose_latents = unit.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, num_repeats=len(audio_embeds), return_all=True, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
pose_latents = None if s2v_pose_video is None else pose_latents
|
||||
return audio_embeds, pose_latents, len(audio_embeds)
|
||||
|
||||
|
||||
class WanVideoPostUnit_S2V(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames"))
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_frames):
|
||||
if pipe.audio_encoder is None or motion_latents is None or drop_motion_frames:
|
||||
return {}
|
||||
latents = torch.cat([motion_latents, latents[:,:,1:]], dim=2)
|
||||
return {"latents": latents}
|
||||
|
||||
|
||||
class WanVideoPostUnit_AnimateVideoSplit(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("input_video", "animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video"))
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, input_video, animate_pose_video, animate_face_video, animate_inpaint_video, animate_mask_video):
|
||||
if input_video is None:
|
||||
return {}
|
||||
if animate_pose_video is not None:
|
||||
animate_pose_video = animate_pose_video[:len(input_video) - 4]
|
||||
if animate_face_video is not None:
|
||||
animate_face_video = animate_face_video[:len(input_video) - 4]
|
||||
if animate_inpaint_video is not None:
|
||||
animate_inpaint_video = animate_inpaint_video[:len(input_video) - 4]
|
||||
if animate_mask_video is not None:
|
||||
animate_mask_video = animate_mask_video[:len(input_video) - 4]
|
||||
return {"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video}
|
||||
|
||||
|
||||
class WanVideoPostUnit_AnimatePoseLatents(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("animate_pose_video", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, animate_pose_video, tiled, tile_size, tile_stride):
|
||||
if animate_pose_video is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
animate_pose_video = pipe.preprocess_video(animate_pose_video)
|
||||
pose_latents = pipe.vae.encode(animate_pose_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"pose_latents": pose_latents}
|
||||
|
||||
|
||||
class WanVideoPostUnit_AnimateFacePixelValues(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(take_over=True)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
if inputs_shared.get("animate_face_video", None) is None:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
inputs_posi["face_pixel_values"] = pipe.preprocess_video(inputs_shared["animate_face_video"])
|
||||
inputs_nega["face_pixel_values"] = torch.zeros_like(inputs_posi["face_pixel_values"]) - 1
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class WanVideoPostUnit_AnimateInpaint(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("animate_inpaint_video", "animate_mask_video", "input_image", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
|
||||
if mask_pixel_values is None:
|
||||
msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
|
||||
else:
|
||||
msk = mask_pixel_values.clone()
|
||||
msk[:, :mask_len] = 1
|
||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
return msk
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, animate_inpaint_video, animate_mask_video, input_image, tiled, tile_size, tile_stride):
|
||||
if animate_inpaint_video is None or animate_mask_video is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
|
||||
bg_pixel_values = pipe.preprocess_video(animate_inpaint_video)
|
||||
y_reft = pipe.vae.encode(bg_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0].to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
_, lat_t, lat_h, lat_w = y_reft.shape
|
||||
|
||||
ref_pixel_values = pipe.preprocess_video([input_image])
|
||||
ref_latents = pipe.vae.encode(ref_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=pipe.device)
|
||||
y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=pipe.device)
|
||||
|
||||
mask_pixel_values = 1 - pipe.preprocess_video(animate_mask_video, max_value=1, min_value=0)
|
||||
mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w")
|
||||
mask_pixel_values = torch.nn.functional.interpolate(mask_pixel_values, size=(lat_h, lat_w), mode='nearest')
|
||||
mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0]
|
||||
msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, 0, mask_pixel_values=mask_pixel_values, device=pipe.device)
|
||||
|
||||
y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=pipe.device)
|
||||
y = torch.concat([y_ref, y_reft], dim=1).unsqueeze(0)
|
||||
return {"y": y}
|
||||
|
||||
|
||||
class TeaCache:
|
||||
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
||||
@@ -962,6 +1261,7 @@ def model_fn_wan_video(
|
||||
dit: WanModel,
|
||||
motion_controller: WanMotionControllerModel = None,
|
||||
vace: VaceWanModel = None,
|
||||
animate_adapter: WanAnimateAdapter = None,
|
||||
latents: torch.Tensor = None,
|
||||
timestep: torch.Tensor = None,
|
||||
context: torch.Tensor = None,
|
||||
@@ -970,9 +1270,15 @@ def model_fn_wan_video(
|
||||
reference_latents = None,
|
||||
vace_context = None,
|
||||
vace_scale = 1.0,
|
||||
audio_embeds: Optional[torch.Tensor] = None,
|
||||
motion_latents: Optional[torch.Tensor] = None,
|
||||
s2v_pose_latents: Optional[torch.Tensor] = None,
|
||||
drop_motion_frames: bool = True,
|
||||
tea_cache: TeaCache = None,
|
||||
use_unified_sequence_parallel: bool = False,
|
||||
motion_bucket_id: Optional[torch.Tensor] = None,
|
||||
pose_latents=None,
|
||||
face_pixel_values=None,
|
||||
sliding_window_size: Optional[int] = None,
|
||||
sliding_window_stride: Optional[int] = None,
|
||||
cfg_merge: bool = False,
|
||||
@@ -1007,7 +1313,22 @@ def model_fn_wan_video(
|
||||
tensor_names=["latents", "y"],
|
||||
batch_size=2 if cfg_merge else 1
|
||||
)
|
||||
|
||||
# wan2.2 s2v
|
||||
if audio_embeds is not None:
|
||||
return model_fn_wans2v(
|
||||
dit=dit,
|
||||
latents=latents,
|
||||
timestep=timestep,
|
||||
context=context,
|
||||
audio_embeds=audio_embeds,
|
||||
motion_latents=motion_latents,
|
||||
s2v_pose_latents=s2v_pose_latents,
|
||||
drop_motion_frames=drop_motion_frames,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_unified_sequence_parallel=use_unified_sequence_parallel,
|
||||
)
|
||||
|
||||
if use_unified_sequence_parallel:
|
||||
import torch.distributed as dist
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
@@ -1021,6 +1342,10 @@ def model_fn_wan_video(
|
||||
torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep
|
||||
]).flatten()
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0))
|
||||
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
||||
t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1)
|
||||
t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks]
|
||||
t = t_chunks[get_sequence_parallel_rank()]
|
||||
t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
|
||||
else:
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||
@@ -1045,8 +1370,16 @@ def model_fn_wan_video(
|
||||
clip_embdding = dit.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
|
||||
# Add camera control
|
||||
x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
|
||||
# Camera control
|
||||
x = dit.patchify(x, control_camera_latents_input)
|
||||
|
||||
# Animate
|
||||
if pose_latents is not None and face_pixel_values is not None:
|
||||
x, motion_vec = animate_adapter.after_patch_embedding(x, pose_latents, face_pixel_values)
|
||||
|
||||
# Patchify
|
||||
f, h, w = x.shape[2:]
|
||||
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
||||
|
||||
# Reference image
|
||||
if reference_latents is not None:
|
||||
@@ -1069,7 +1402,11 @@ def model_fn_wan_video(
|
||||
tea_cache_update = False
|
||||
|
||||
if vace_context is not None:
|
||||
vace_hints = vace(x, vace_context, context, t_mod, freqs)
|
||||
vace_hints = vace(
|
||||
x, vace_context, context, t_mod, freqs,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
|
||||
)
|
||||
|
||||
# blocks
|
||||
if use_unified_sequence_parallel:
|
||||
@@ -1087,6 +1424,7 @@ def model_fn_wan_video(
|
||||
return custom_forward
|
||||
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
# Block
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
@@ -1102,12 +1440,18 @@ def model_fn_wan_video(
            )
        else:
            x = block(x, context, t_mod, freqs)

        # VACE
        if vace_context is not None and block_id in vace.vace_layers_mapping:
            current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
            if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
                current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
                current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
            x = x + current_vace_hint * vace_scale

        # Animate
        if pose_latents is not None and face_pixel_values is not None:
            x = animate_adapter.after_transformer_block(block_id, x, motion_vec)
        if tea_cache is not None:
            tea_cache.store(x)
@@ -1122,3 +1466,105 @@ def model_fn_wan_video(
        f -= 1
    x = dit.unpatchify(x, (f, h, w))
    return x


def model_fn_wans2v(
    dit,
    latents,
    timestep,
    context,
    audio_embeds,
    motion_latents,
    s2v_pose_latents,
    drop_motion_frames=True,
    use_gradient_checkpointing_offload=False,
    use_gradient_checkpointing=False,
    use_unified_sequence_parallel=False,
):
    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                             get_sequence_parallel_world_size,
                                             get_sp_group)
    origin_ref_latents = latents[:, :, 0:1]
    x = latents[:, :, 1:]

    # context embedding
    context = dit.text_embedding(context)

    # audio encode
    audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_embeds)

    # x and s2v_pose_latents
    s2v_pose_latents = torch.zeros_like(x) if s2v_pose_latents is None else s2v_pose_latents
    x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents))
    seq_len_x = seq_len_x_global = x.shape[1]  # global used for unified sequence parallel

    # reference image
    ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents))
    grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw))
    x = torch.cat([x, ref_latents], dim=1)
    # mask
    mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)
    # freqs
    pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None)
    # motion
    x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=2)

    x = x + dit.trainable_cond_mask(mask).to(x.dtype)

    # tmod
    timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
    t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
    t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2)

    if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
        world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank()
        assert x.shape[1] % world_size == 0, f"the dimension after chunk must be divisible by world size, but got {x.shape[1]} and {get_sequence_parallel_world_size()}"
        x = torch.chunk(x, world_size, dim=1)[sp_rank]
        seg_idxs = [0] + list(torch.cumsum(torch.tensor([x.shape[1]] * world_size), dim=0).cpu().numpy())
        seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)]
        seq_len_x = seq_len_x_list[sp_rank]

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    for block_id, block in enumerate(dit.blocks):
        if use_gradient_checkpointing_offload:
            with torch.autograd.graph.save_on_cpu():
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, context, t_mod, seq_len_x, pre_compute_freqs[0],
                    use_reentrant=False,
                )
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
                    x,
                    use_reentrant=False,
                )
        elif use_gradient_checkpointing:
            x = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                x, context, t_mod, seq_len_x, pre_compute_freqs[0],
                use_reentrant=False,
            )
            x = torch.utils.checkpoint.checkpoint(
                create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
                x,
                use_reentrant=False,
            )
        else:
            x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
            x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel)

    if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
        x = get_sp_group().all_gather(x, dim=1)

    x = x[:, :seq_len_x_global]
    x = dit.head(x, t[:-1])
    x = dit.unpatchify(x, (f, h, w))
    # make compatible with wan video
    x = torch.cat([origin_ref_latents, x], dim=2)
    return x

@@ -31,7 +31,7 @@ class FlowMatchScheduler():
        self.set_timesteps(num_inference_steps)


    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None):
    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None, exponential_shift_mu=None):
        if shift is not None:
            self.shift = shift
        sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
@@ -42,7 +42,12 @@ class FlowMatchScheduler():
        if self.inverse_timesteps:
            self.sigmas = torch.flip(self.sigmas, dims=[0])
        if self.exponential_shift:
            mu = self.calculate_shift(dynamic_shift_len) if dynamic_shift_len is not None else self.exponential_shift_mu
            if exponential_shift_mu is not None:
                mu = exponential_shift_mu
            elif dynamic_shift_len is not None:
                mu = self.calculate_shift(dynamic_shift_len)
            else:
                mu = self.exponential_shift_mu
            self.sigmas = math.exp(mu) / (math.exp(mu) + (1 / self.sigmas - 1))
        else:
            self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)

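For reference, the two shift branches above map the base sigmas as follows. This is a minimal standalone sketch of just the arithmetic (the `sigmas`, `mu`, and `shift` values here are made up), not the scheduler itself:

```python
import math
import torch

sigmas = torch.linspace(1.0, 0.001, 10)  # base sigmas in (0, 1], hypothetical
mu, shift = 1.0, 3.0

# exponential shift branch: sigma' = e^mu / (e^mu + (1/sigma - 1))
sigmas_exp = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))

# linear shift branch: sigma' = shift * sigma / (1 + (shift - 1) * sigma)
sigmas_lin = shift * sigmas / (1 + (shift - 1) * sigmas)
```

Both mappings keep sigma = 0 and sigma = 1 fixed while bending the schedule toward higher noise levels, which is what the new `exponential_shift_mu` argument lets callers control directly.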
diffsynth/trainers/unified_dataset.py (new file, 337 lines)
@@ -0,0 +1,337 @@
import torch, torchvision, imageio, os, json, pandas
import imageio.v3 as iio
from PIL import Image


class DataProcessingPipeline:
    def __init__(self, operators=None):
        self.operators: list[DataProcessingOperator] = [] if operators is None else operators

    def __call__(self, data):
        for operator in self.operators:
            data = operator(data)
        return data

    def __rshift__(self, pipe):
        if isinstance(pipe, DataProcessingOperator):
            pipe = DataProcessingPipeline([pipe])
        return DataProcessingPipeline(self.operators + pipe.operators)


class DataProcessingOperator:
    def __call__(self, data):
        raise NotImplementedError("DataProcessingOperator cannot be called directly.")

    def __rshift__(self, pipe):
        if isinstance(pipe, DataProcessingOperator):
            pipe = DataProcessingPipeline([pipe])
        return DataProcessingPipeline([self]).__rshift__(pipe)


class DataProcessingOperatorRaw(DataProcessingOperator):
    def __call__(self, data):
        return data

class ToInt(DataProcessingOperator):
    def __call__(self, data):
        return int(data)


class ToFloat(DataProcessingOperator):
    def __call__(self, data):
        return float(data)


class ToStr(DataProcessingOperator):
    def __init__(self, none_value=""):
        self.none_value = none_value

    def __call__(self, data):
        if data is None: data = self.none_value
        return str(data)

class LoadImage(DataProcessingOperator):
    def __init__(self, convert_RGB=True):
        self.convert_RGB = convert_RGB

    def __call__(self, data: str):
        image = Image.open(data)
        if self.convert_RGB: image = image.convert("RGB")
        return image


class ImageCropAndResize(DataProcessingOperator):
    def __init__(self, height, width, max_pixels, height_division_factor, width_division_factor):
        self.height = height
        self.width = width
        self.max_pixels = max_pixels
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor

    def crop_and_resize(self, image, target_height, target_width):
        width, height = image.size
        scale = max(target_width / width, target_height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
        return image

    def get_height_width(self, image):
        if self.height is None or self.width is None:
            width, height = image.size
            if width * height > self.max_pixels:
                scale = (width * height / self.max_pixels) ** 0.5
                height, width = int(height / scale), int(width / scale)
            height = height // self.height_division_factor * self.height_division_factor
            width = width // self.width_division_factor * self.width_division_factor
        else:
            height, width = self.height, self.width
        return height, width

    def __call__(self, data: Image.Image):
        image = self.crop_and_resize(data, *self.get_height_width(data))
        return image

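A quick illustration of how the `__rshift__` overload above lets operators be chained into a `DataProcessingPipeline`. This is a hypothetical usage sketch (the file name is made up); only operators already defined above are used:

```python
# Compose operators with ">>": each operator's output feeds the next one.
op = LoadImage() >> ImageCropAndResize(
    height=512, width=512, max_pixels=1920 * 1080,
    height_division_factor=16, width_division_factor=16,
)
image = op("cat.jpg")  # load the file, then center-crop and resize to 512x512
```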
class ToList(DataProcessingOperator):
    def __call__(self, data):
        return [data]


class LoadVideo(DataProcessingOperator):
    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # frame_processor is built into the video loader for efficiency.
        self.frame_processor = frame_processor

    def get_num_frames(self, reader):
        num_frames = self.num_frames
        if int(reader.count_frames()) < num_frames:
            num_frames = int(reader.count_frames())
            while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
                num_frames -= 1
        return num_frames

    def __call__(self, data: str):
        reader = imageio.get_reader(data)
        num_frames = self.get_num_frames(reader)
        frames = []
        for frame_id in range(num_frames):
            frame = reader.get_data(frame_id)
            frame = Image.fromarray(frame)
            frame = self.frame_processor(frame)
            frames.append(frame)
        reader.close()
        return frames


class SequencialProcess(DataProcessingOperator):
    def __init__(self, operator=lambda x: x):
        self.operator = operator

    def __call__(self, data):
        return [self.operator(i) for i in data]

class LoadGIF(DataProcessingOperator):
    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # frame_processor is built into the GIF loader for efficiency.
        self.frame_processor = frame_processor

    def get_num_frames(self, path):
        num_frames = self.num_frames
        images = iio.imread(path, mode="RGB")
        if len(images) < num_frames:
            num_frames = len(images)
            while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
                num_frames -= 1
        return num_frames

    def __call__(self, data: str):
        num_frames = self.get_num_frames(data)
        frames = []
        images = iio.imread(data, mode="RGB")
        for img in images:
            frame = Image.fromarray(img)
            frame = self.frame_processor(frame)
            frames.append(frame)
            if len(frames) >= num_frames:
                break
        return frames

class RouteByExtensionName(DataProcessingOperator):
    def __init__(self, operator_map):
        self.operator_map = operator_map

    def __call__(self, data: str):
        file_ext_name = data.split(".")[-1].lower()
        for ext_names, operator in self.operator_map:
            if ext_names is None or file_ext_name in ext_names:
                return operator(data)
        raise ValueError(f"Unsupported file: {data}")


class RouteByType(DataProcessingOperator):
    def __init__(self, operator_map):
        self.operator_map = operator_map

    def __call__(self, data):
        for dtype, operator in self.operator_map:
            if dtype is None or isinstance(data, dtype):
                return operator(data)
        raise ValueError(f"Unsupported data: {data}")

class LoadTorchPickle(DataProcessingOperator):
    def __init__(self, map_location="cpu"):
        self.map_location = map_location

    def __call__(self, data):
        return torch.load(data, map_location=self.map_location, weights_only=False)


class ToAbsolutePath(DataProcessingOperator):
    def __init__(self, base_path=""):
        self.base_path = base_path

    def __call__(self, data):
        return os.path.join(self.base_path, data)

class UnifiedDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        base_path=None, metadata_path=None,
        repeat=1,
        data_file_keys=tuple(),
        main_data_operator=lambda x: x,
        special_operator_map=None,
    ):
        self.base_path = base_path
        self.metadata_path = metadata_path
        self.repeat = repeat
        self.data_file_keys = data_file_keys
        self.main_data_operator = main_data_operator
        self.cached_data_operator = LoadTorchPickle()
        self.special_operator_map = {} if special_operator_map is None else special_operator_map
        self.data = []
        self.cached_data = []
        self.load_from_cache = metadata_path is None
        self.load_metadata(metadata_path)

    @staticmethod
    def default_image_operator(
        base_path="",
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
    ):
        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
            (list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
        ])

    @staticmethod
    def default_video_operator(
        base_path="",
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
        num_frames=81, time_division_factor=4, time_division_remainder=1,
    ):
        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
                (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
                (("gif",), LoadGIF(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
                )),
                (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
                )),
            ])),
        ])

    def search_for_cached_data_files(self, path):
        for file_name in os.listdir(path):
            subpath = os.path.join(path, file_name)
            if os.path.isdir(subpath):
                self.search_for_cached_data_files(subpath)
            elif subpath.endswith(".pth"):
                self.cached_data.append(subpath)

    def load_metadata(self, metadata_path):
        if metadata_path is None:
            print("No metadata_path. Searching for cached data files.")
            self.search_for_cached_data_files(self.base_path)
            print(f"{len(self.cached_data)} cached data files found.")
        elif metadata_path.endswith(".json"):
            with open(metadata_path, "r") as f:
                metadata = json.load(f)
            self.data = metadata
        elif metadata_path.endswith(".jsonl"):
            metadata = []
            with open(metadata_path, 'r') as f:
                for line in f:
                    metadata.append(json.loads(line.strip()))
            self.data = metadata
        else:
            metadata = pandas.read_csv(metadata_path)
            self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]

    def __getitem__(self, data_id):
        if self.load_from_cache:
            data = self.cached_data[data_id % len(self.cached_data)]
            data = self.cached_data_operator(data)
        else:
            data = self.data[data_id % len(self.data)].copy()
            for key in self.data_file_keys:
                if key in data:
                    if key in self.special_operator_map:
                        data[key] = self.special_operator_map[key](data[key])
                    elif key in self.data_file_keys:
                        data[key] = self.main_data_operator(data[key])
        return data

    def __len__(self):
        if self.load_from_cache:
            return len(self.cached_data) * self.repeat
        else:
            return len(self.data) * self.repeat

    def check_data_equal(self, data1, data2):
        # Debug only
        if len(data1) != len(data2):
            return False
        for k in data1:
            if data1[k] != data2[k]:
                return False
        return True
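As a rough usage sketch (the paths, column names, and resolution here are placeholders), a video dataset can be wired up with the default operator defined above:

```python
# Hypothetical example: metadata.csv has "video" and "prompt" columns.
dataset = UnifiedDataset(
    base_path="data/example_video_dataset",
    metadata_path="data/example_video_dataset/metadata.csv",
    data_file_keys=("video",),
    main_data_operator=UnifiedDataset.default_video_operator(
        base_path="data/example_video_dataset",
        num_frames=81, height=480, width=832,
    ),
)
sample = dataset[0]  # {"video": [PIL.Image, ...], "prompt": "..."}
```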
@@ -1,9 +1,12 @@
import imageio, os, torch, warnings, torchvision, argparse, json
from ..utils import ModelConfig
from ..models.utils import load_state_dict
from peft import LoraConfig, inject_adapter_in_model
from PIL import Image
import pandas as pd
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

@@ -153,7 +156,7 @@ class VideoDataset(torch.utils.data.Dataset):
        height_division_factor=16, width_division_factor=16,
        data_file_keys=("video",),
        image_file_extension=("jpg", "jpeg", "png", "webp"),
        video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"),
        video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm", "gif"),
        repeat=1,
        args=None,
    ):
@@ -258,8 +261,53 @@ class VideoDataset(torch.utils.data.Dataset):
            num_frames -= 1
        return num_frames

    def _load_gif(self, file_path):
        gif_img = Image.open(file_path)
        frame_count = 0
        delays, frames = [], []
        while True:
            delay = gif_img.info.get('duration', 100)  # ms
            delays.append(delay)
            rgb_frame = gif_img.convert("RGB")
            cropped_frame = self.crop_and_resize(rgb_frame, *self.get_height_width(rgb_frame))
            frames.append(cropped_frame)
            frame_count += 1
            try:
                gif_img.seek(frame_count)
            except:
                break
        # The per-frame delays determine the effective frame rate. Frames are
        # resampled at a fixed interval (the minimal positive delay), so the
        # resulting frame rate is 1000 / minimal_interval.
        if any((delays[0] != i) for i in delays):
            minimal_interval = min([i for i in delays if i > 0])
            # build a ((start, end), frame_id) lookup table over the timeline
            start_end_idx_map = [((sum(delays[:i]), sum(delays[:i+1])), i) for i in range(len(delays))]
            _frames = []
            # remember the last matched segment so that locating the source
            # frame for the next sample time stays efficient
            last_match = 0
            for i in range(sum(delays) // minimal_interval):
                current_time = minimal_interval * i
                for idx, ((start, end), frame_idx) in enumerate(start_end_idx_map[last_match:]):
                    if start <= current_time < end:
                        _frames.append(frames[frame_idx])
                        last_match = idx + last_match
                        break
            frames = _frames
        num_frames = len(frames)
        if num_frames > self.num_frames:
            num_frames = self.num_frames
        else:
            while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
                num_frames -= 1
        frames = frames[:num_frames]
        return frames

    def load_video(self, file_path):
        if file_path.lower().endswith(".gif"):
            return self._load_gif(file_path)
        reader = imageio.get_reader(file_path)
        num_frames = self.get_num_frames(reader)
        frames = []
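To make the resampling above concrete, here is the interval arithmetic on a hypothetical GIF whose frames have uneven durations (the delay values are made up):

```python
delays = [100, 50, 100]                               # per-frame durations in ms
minimal_interval = min(d for d in delays if d > 0)    # 50 ms
num_samples = sum(delays) // minimal_interval         # 250 // 50 = 5 samples
# Sample times 0, 50, 100, 150, 200 ms fall into source frames 0, 0, 1, 2, 2,
# i.e. longer-lasting frames are repeated so playback speed is preserved.
```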
@@ -337,14 +385,29 @@ class DiffusionTrainingModule(torch.nn.Module):
        return trainable_param_names


    def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
    def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
        if lora_alpha is None:
            lora_alpha = lora_rank
        lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
        model = inject_adapter_in_model(lora_config, model)
        if upcast_dtype is not None:
            for param in model.parameters():
                if param.requires_grad:
                    param.data = param.to(upcast_dtype)
        return model


    def mapping_lora_state_dict(self, state_dict):
        new_state_dict = {}
        for key, value in state_dict.items():
            if "lora_A.weight" in key or "lora_B.weight" in key:
                new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
                new_state_dict[new_key] = value
            elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
                new_state_dict[key] = value
        return new_state_dict

    def export_trainable_state_dict(self, state_dict, remove_prefix=None):
        trainable_param_names = self.trainable_param_names()
        state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
@@ -356,7 +419,120 @@ class DiffusionTrainingModule(torch.nn.Module):
            state_dict_[name] = param
        state_dict = state_dict_
        return state_dict

    def transfer_data_to_device(self, data, device, torch_float_dtype=None):
        for key in data:
            if isinstance(data[key], torch.Tensor):
                data[key] = data[key].to(device)
                if torch_float_dtype is not None and data[key].dtype in [torch.float, torch.float16, torch.bfloat16]:
                    data[key] = data[key].to(torch_float_dtype)
        return data

    def parse_model_configs(self, model_paths, model_id_with_origin_paths, enable_fp8_training=False):
        offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
        model_configs = []
        if model_paths is not None:
            model_paths = json.loads(model_paths)
            model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
        if model_id_with_origin_paths is not None:
            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
        return model_configs

    def switch_pipe_to_training_mode(
        self,
        pipe,
        trainable_models,
        lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=None,
        enable_fp8_training=False,
    ):
        # Scheduler
        pipe.scheduler.set_timesteps(1000, training=True)

        # Freeze untrainable models
        pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))

        # Enable FP8 if pipeline supports
        if enable_fp8_training and hasattr(pipe, "_enable_fp8_lora_training"):
            pipe._enable_fp8_lora_training(torch.float8_e4m3fn)

        # Add LoRA to the base models
        if lora_base_model is not None:
            model = self.add_lora_to_model(
                getattr(pipe, lora_base_model),
                target_modules=lora_target_modules.split(","),
                lora_rank=lora_rank,
                upcast_dtype=pipe.torch_dtype,
            )
            if lora_checkpoint is not None:
                state_dict = load_state_dict(lora_checkpoint)
                state_dict = self.mapping_lora_state_dict(state_dict)
                load_result = model.load_state_dict(state_dict, strict=False)
                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
                if len(load_result[1]) > 0:
                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
            setattr(pipe, lora_base_model, model)

    def disable_all_lora_layers(self, model):
        for name, module in model.named_modules():
            if hasattr(module, 'enable_adapters'):
                module.enable_adapters(False)

    def enable_all_lora_layers(self, model):
        for name, module in model.named_modules():
            if hasattr(module, 'enable_adapters'):
                module.enable_adapters(True)


class DPOLoss:
    def __init__(self, beta=2500):
        self.beta = beta

    def sample_timestep(self, pipe):
        timestep_id = torch.randint(0, pipe.scheduler.num_train_timesteps, (1,))
        timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
        return timestep

    def training_loss_minimum(self, pipe, noise, timestep, **inputs):
        inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
        training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
        noise_pred = pipe.model_fn(**inputs, timestep=timestep)
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * pipe.scheduler.training_weight(timestep)
        return loss

    def loss(self, model, data):
        # DPO loss: -log sigma(-beta * (diff_policy - diff_ref))
        # Prepare inputs
        win_data = {key: data[key] for key in ["prompt", "image"]}
        lose_data = {"prompt": data["prompt"], "image": data["lose_image"]}
        inputs_win = model.forward_preprocess(win_data)
        inputs_lose = model.forward_preprocess(lose_data)
        inputs_win.pop('noise')
        inputs_lose.pop('noise')
        models = {name: getattr(model.pipe, name) for name in model.pipe.in_iteration_models}
        # sample timestep and noise
        timestep = self.sample_timestep(model.pipe)
        noise = torch.rand_like(inputs_win["latents"])
        # compute diff_policy = loss_win - loss_lose
        loss_win = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_win)
        loss_lose = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_lose)
        diff_policy = loss_win - loss_lose
        # compute diff_ref
        # TODO: may support full model training
        model.disable_all_lora_layers(model.pipe.dit)
        # with LoRA disabled, the forward pass uses the original (reference) weights
        with torch.no_grad():
            loss_win_ref = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_win)
            loss_lose_ref = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_lose)
            diff_ref = loss_win_ref - loss_lose_ref
        model.enable_all_lora_layers(model.pipe.dit)
        # compute loss
        loss = -1. * torch.nn.functional.logsigmoid(self.beta * (diff_ref - diff_policy)).mean()
        return loss

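The objective above is a Diffusion-DPO style preference loss: diffusion losses are computed for the preferred (`win`) and rejected (`lose`) images under the trainable policy and under the frozen reference (the same model with LoRA disabled). The sketch below only shows how the final scalar is formed, assuming the four loss values are already available (the numbers are made up):

```python
import torch
import torch.nn.functional as F

beta = 2500.0
# Hypothetical per-pair diffusion losses (policy = LoRA enabled, ref = LoRA disabled).
loss_win, loss_lose = torch.tensor(0.41), torch.tensor(0.47)
loss_win_ref, loss_lose_ref = torch.tensor(0.44), torch.tensor(0.45)

diff_policy = loss_win - loss_lose
diff_ref = loss_win_ref - loss_lose_ref
# -log sigma(-beta * (diff_policy - diff_ref)); the loss decreases when the policy
# fits the preferred image better, relative to the reference, than the rejected one.
dpo_loss = -F.logsigmoid(beta * (diff_ref - diff_policy))
```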
class ModelLogger:
@@ -364,12 +540,15 @@ class ModelLogger:
        self.output_path = output_path
        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
        self.state_dict_converter = state_dict_converter

    def on_step_end(self, loss):
        pass

        self.num_steps = 0

    def on_step_end(self, accelerator, model, save_steps=None):
        self.num_steps += 1
        if save_steps is not None and self.num_steps % save_steps == 0:
            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")

    def on_epoch_end(self, accelerator, model, epoch_id):
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
@@ -381,43 +560,92 @@ class ModelLogger:
            accelerator.save(state_dict, path, safe_serialization=True)


    def on_training_end(self, accelerator, model, save_steps=None):
        if save_steps is not None and self.num_steps % save_steps != 0:
            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")

    def save_model(self, accelerator, model, file_name):
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            state_dict = accelerator.get_state_dict(model)
            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
            state_dict = self.state_dict_converter(state_dict)
            os.makedirs(self.output_path, exist_ok=True)
            path = os.path.join(self.output_path, file_name)
            accelerator.save(state_dict, path, safe_serialization=True)


def launch_training_task(
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 8,
    save_steps: int = None,
    num_epochs: int = 1,
    gradient_accumulation_steps: int = 1,
    find_unused_parameters: bool = False,
    args = None,
):
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
    accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs
        gradient_accumulation_steps = args.gradient_accumulation_steps
        find_unused_parameters = args.find_unused_parameters

    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters)],
    )
    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)

    for epoch_id in range(num_epochs):
        for data in tqdm(dataloader):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                loss = model(data)
                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(loss)
                model_logger.on_step_end(accelerator, model, save_steps)
                scheduler.step()
        model_logger.on_epoch_end(accelerator, model, epoch_id)
        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
    model_logger.on_training_end(accelerator, model, save_steps)


def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0])
def launch_data_process_task(
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    num_workers: int = 8,
    args = None,
):
    if args is not None:
        num_workers = args.dataset_num_workers

    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
    accelerator = Accelerator()
    model, dataloader = accelerator.prepare(model, dataloader)
    os.makedirs(os.path.join(output_path, "data_cache"), exist_ok=True)
    for data_id, data in enumerate(tqdm(dataloader)):
        with torch.no_grad():
            inputs = model.forward_preprocess(data)
            inputs = {key: inputs[key] for key in model.model_input_keys if key in inputs}
            torch.save(inputs, os.path.join(output_path, "data_cache", f"{data_id}.pth"))

    for data_id, data in tqdm(enumerate(dataloader)):
        with accelerator.accumulate(model):
            with torch.no_grad():
                folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
                os.makedirs(folder, exist_ok=True)
                save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
                data = model(data, return_inputs=True)
                torch.save(data, save_path)

@@ -441,11 +669,16 @@ def wan_parser():
    parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
    parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
    parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
    parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
    parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
    parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
    parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
    parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
    parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
    parser.add_argument("--save_steps", type=int, default=None, help="Interval (in steps) between checkpoint saves. If None, checkpoints will be saved every epoch.")
    parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
    return parser

@@ -469,11 +702,16 @@ def flux_parser():
    parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
    parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
    parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
    parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
    parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
    parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.")
    parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
    parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
    parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
    parser.add_argument("--save_steps", type=int, default=None, help="Interval (in steps) between checkpoint saves. If None, checkpoints will be saved every epoch.")
    parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
    return parser

@@ -498,9 +736,16 @@ def qwen_image_parser():
    parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
    parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
    parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
    parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
    parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
    parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.")
    parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
    parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
    parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
    parser.add_argument("--save_steps", type=int, default=None, help="Interval (in steps) between checkpoint saves. If None, checkpoints will be saved every epoch.")
    parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
    parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
    parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
    parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
    return parser

@@ -139,6 +139,20 @@ class BasePipeline(torch.nn.Module):
        else:
            model.eval()
            model.requires_grad_(False)


    def blend_with_mask(self, base, addition, mask):
        return base * (1 - mask) + addition * mask


    def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
        timestep = scheduler.timesteps[progress_id]
        if inpaint_mask is not None:
            noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
            noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
        latents_next = scheduler.step(noise_pred, timestep, latents)
        return latents_next


@dataclass

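The inpainting step above blends the expected prediction and the model prediction per element of the mask; a minimal sketch of just the blending rule, with toy tensors in place of real latents:

```python
import torch

base = torch.tensor([1.0, 1.0, 1.0])       # e.g. prediction that keeps the region unchanged
addition = torch.tensor([0.0, 0.0, 0.0])   # e.g. the model's new prediction
mask = torch.tensor([0.0, 0.5, 1.0])       # 0 = keep base, 1 = take addition

blended = base * (1 - mask) + addition * mask   # -> tensor([1.0, 0.5, 0.0])
```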
@@ -116,7 +116,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor | None = None,
        bias: torch.Tensor = None,
    ) -> torch.Tensor:
        device = input.device
        origin_dtype = input.dtype
@@ -136,6 +136,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
            input = input / (scale_a + 1e-8)
            input = input.to(self.computation_dtype)
            weight = weight.to(self.computation_dtype)
            bias = bias.to(torch.bfloat16)

            result = torch._scaled_mm(
                input,
@@ -249,19 +249,24 @@ The script includes the following parameters:
  * `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution.
  * `--data_file_keys`: Data file keys in the metadata. Separate with commas.
  * `--dataset_repeat`: Number of times the dataset repeats per epoch.
  * `--dataset_num_workers`: Number of workers for data loading.
* Model
  * `--model_paths`: Paths to load models. In JSON format.
  * `--model_id_with_origin_paths`: Model ID with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas.
* Training
  * `--learning_rate`: Learning rate.
  * `--weight_decay`: Weight decay.
  * `--num_epochs`: Number of epochs.
  * `--output_path`: Save path.
  * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
  * `--save_steps`: Interval (in steps) between checkpoint saves. If None, checkpoints will be saved every epoch.
  * `--find_unused_parameters`: Whether to find unused parameters in DDP.
* Trainable Modules
  * `--trainable_models`: Models that can be trained, e.g., dit, vae, text_encoder.
  * `--lora_base_model`: Which model to add LoRA to.
  * `--lora_target_modules`: Which layers to add LoRA to.
  * `--lora_rank`: Rank of LoRA.
  * `--lora_checkpoint`: Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.
* Extra Model Inputs
  * `--extra_inputs`: Extra model inputs, separated by commas.
* VRAM Management

@@ -249,19 +249,24 @@ FLUX-series model training is run through the unified [`./model_training/train.py`](./model_tra
  * `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution.
  * `--data_file_keys`: Data file keys in the metadata. Separate with commas.
  * `--dataset_repeat`: Number of times the dataset repeats per epoch.
  * `--dataset_num_workers`: Number of worker processes per dataloader.
* Model
  * `--model_paths`: Paths of the models to load, in JSON format.
  * `--model_id_with_origin_paths`: Model ID with original file paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas.
* Training
  * `--learning_rate`: Learning rate.
  * `--weight_decay`: Weight decay.
  * `--num_epochs`: Number of epochs.
  * `--output_path`: Save path.
  * `--remove_prefix_in_ckpt`: Remove prefix in the checkpoint.
  * `--save_steps`: Interval (in steps) between checkpoint saves. If set to None, a checkpoint is saved once per epoch.
  * `--find_unused_parameters`: Whether there are unused parameters in DDP training.
* Trainable Modules
  * `--trainable_models`: Models that can be trained, e.g., dit, vae, text_encoder.
  * `--lora_base_model`: Which model LoRA is added to.
  * `--lora_target_modules`: Which layers LoRA is added to.
  * `--lora_rank`: Rank of LoRA.
  * `--lora_checkpoint`: Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.
* Extra Model Inputs
  * `--extra_inputs`: Extra model inputs, separated by commas.
* VRAM Management

@@ -1,7 +1,9 @@
import torch, os, json
from diffsynth import load_state_dict
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, flux_parser
from diffsynth.models.lora import FluxLoRAConverter
from diffsynth.trainers.unified_dataset import UnifiedDataset
os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -11,37 +13,23 @@ class FluxTrainingModule(DiffusionTrainingModule):
        self,
        model_paths=None, model_id_with_origin_paths=None,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32,
        lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32, lora_checkpoint=None,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
    ):
        super().__init__()
        # Load models
        model_configs = []
        if model_paths is not None:
            model_paths = json.loads(model_paths)
            model_configs += [ModelConfig(path=path) for path in model_paths]
        if model_id_with_origin_paths is not None:
            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
        self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)

        # Reset training scheduler
        self.pipe.scheduler.set_timesteps(1000, training=True)
        # Training mode
        self.switch_pipe_to_training_mode(
            self.pipe, trainable_models,
            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
            enable_fp8_training=False,
        )

        # Freeze untrainable models
        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))

        # Add LoRA to the base models
        if lora_base_model is not None:
            model = self.add_lora_to_model(
                getattr(self.pipe, lora_base_model),
                target_modules=lora_target_modules.split(","),
                lora_rank=lora_rank
            )
            setattr(self.pipe, lora_base_model, model)

        # Store other configs
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
@@ -98,7 +86,20 @@ class FluxTrainingModule(DiffusionTrainingModule):
if __name__ == "__main__":
    parser = flux_parser()
    args = parser.parse_args()
    dataset = ImageDataset(args=args)
    dataset = UnifiedDataset(
        base_path=args.dataset_base_path,
        metadata_path=args.dataset_metadata_path,
        repeat=args.dataset_repeat,
        data_file_keys=args.data_file_keys.split(","),
        main_data_operator=UnifiedDataset.default_image_operator(
            base_path=args.dataset_base_path,
            max_pixels=args.max_pixels,
            height=args.height,
            width=args.width,
            height_division_factor=16,
            width_division_factor=16,
        )
    )
    model = FluxTrainingModule(
        model_paths=args.model_paths,
        model_id_with_origin_paths=args.model_id_with_origin_paths,
@@ -106,6 +107,7 @@ if __name__ == "__main__":
        lora_base_model=args.lora_base_model,
        lora_target_modules=args.lora_target_modules,
        lora_rank=args.lora_rank,
        lora_checkpoint=args.lora_checkpoint,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        extra_inputs=args.extra_inputs,
@@ -115,10 +117,4 @@ if __name__ == "__main__":
        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
        state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
    )
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    launch_training_task(
        dataset, model, model_logger, optimizer, scheduler,
        num_epochs=args.num_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
    )
    launch_training_task(dataset, model, model_logger, args=args)

@@ -20,9 +20,9 @@ Run the following code to quickly load the [Qwen/Qwen-Image](https://www.modelsc

```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from PIL import Image
import torch


pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
@@ -34,17 +34,30 @@ pipe = QwenImagePipeline.from_pretrained(
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
prompt = "A detailed portrait of a girl underwater, wearing a blue flowing dress, hair gently floating, clear light and shadow, surrounded by bubbles, calm expression, fine details, dreamy and beautiful."
image = pipe(prompt, seed=0, num_inference_steps=40)
image = pipe(
    prompt, seed=0, num_inference_steps=40,
    # edit_image=Image.open("xxx.jpg").resize((1328, 1328))  # For Qwen-Image-Edit
)
image.save("image.jpg")
```

## Model Overview

|Model ID|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|

|Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_inference_low_vram/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./model_inference/Qwen-Image-Edit.py)|[code](./model_inference_low_vram/Qwen-Image-Edit.py)|[code](./model_training/full/Qwen-Image-Edit.sh)|[code](./model_training/validate_full/Qwen-Image-Edit.py)|[code](./model_training/lora/Qwen-Image-Edit.sh)|[code](./model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](./model_inference/Qwen-Image-Edit-2509.py)|[code](./model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](./model_training/full/Qwen-Image-Edit-2509.sh)|[code](./model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](./model_training/lora/Qwen-Image-Edit-2509.sh)|[code](./model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./model_inference/Qwen-Image-Distill-LoRA.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](./model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](./model_inference/Qwen-Image-EliGen-V2.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](./model_inference/Qwen-Image-EliGen-Poster.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](./model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](./model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](./model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](./model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](./model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](./model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|

## Model Inference

@@ -164,6 +177,7 @@ After enabling VRAM management, the framework will automatically choose a memory
|
||||
* `vram_limit`: VRAM usage limit in GB. By default, it uses all free VRAM on the device. Note that this is not a strict limit. If the set limit is too low but actual free VRAM is enough, the model will run with minimal VRAM use. Set it to 0 for the smallest possible VRAM use.
|
||||
* `vram_buffer`: VRAM buffer size in GB. Default is 0.5GB. A buffer is needed because large network layers may use more VRAM than expected during loading. The best value is the VRAM size of the largest model layer.
|
||||
* `num_persistent_param_in_dit`: Number of parameters to keep in VRAM in the DiT model. Default is no limit. This option will be removed in the future. Do not rely on it.
|
||||
* `enable_dit_fp8_computation`: Whether to enable FP8 computation in the DiT model. This is only applicable to GPUs that support FP8 operations (e.g., H200, etc.). Disabled by default.
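
For a concrete illustration, here is a minimal sketch that turns VRAM management on with an explicit budget. It reuses the quickstart model configuration; treating `vram_limit` and `vram_buffer` as keyword arguments of `enable_vram_management` is an assumption based on the parameter list above (only the no-argument and `enable_dit_fp8_computation=True` calls appear verbatim in the example scripts).

```python
# Hedged sketch: enable VRAM management with an explicit budget.
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
# Assumption: vram_limit / vram_buffer are keyword arguments, per the parameter list above.
# Keep VRAM usage under roughly 12 GB, with the default 0.5 GB buffer.
pipe.enable_vram_management(vram_limit=12, vram_buffer=0.5)

image = pipe("a cat with sunglasses", seed=0, num_inference_steps=40)
image.save("image.jpg")
```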
|
||||
|
||||
</details>
|
||||
|
||||
@@ -172,7 +186,14 @@ After enabling VRAM management, the framework will automatically choose a memory
|
||||
|
||||
<summary>Inference Acceleration</summary>
|
||||
|
||||
Additional inference acceleration techniques for Qwen-Image are under development. The options already available are listed below.
|
||||
* FP8 Quantization: Choose the appropriate quantization method based on your hardware and requirements.
|
||||
* GPUs that do not support FP8 computation (e.g., A100, 4090, etc.): FP8 quantization will only reduce VRAM usage without speeding up inference. Code: [./model_inference_low_vram/Qwen-Image.py](./model_inference_low_vram/Qwen-Image.py)
|
||||
* GPUs that support FP8 operations (e.g., H200, etc.): Please install [Flash Attention 3](https://github.com/Dao-AILab/flash-attention). Otherwise, FP8 acceleration will only apply to Linear layers.
|
||||
* Faster inference but higher VRAM usage: Use [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
|
||||
* Slightly slower inference but lower VRAM usage: Use [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
|
||||
* Distillation acceleration: We trained two distillation models for fast inference at `cfg_scale=1` and `num_inference_steps=15`.
|
||||
* [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full): Full distillation version. Better image quality but lower LoRA compatibility. Use [./model_inference/Qwen-Image-Distill-Full.py](./model_inference/Qwen-Image-Distill-Full.py).
|
||||
* [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA): LoRA distillation version. Slightly lower image quality but better LoRA compatibility. Use [./model_inference/Qwen-Image-Distill-LoRA.py](./model_inference/Qwen-Image-Distill-LoRA.py).
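
For reference, a condensed sketch of the distilled-LoRA path is shown below; it mirrors the linked Qwen-Image-Distill-LoRA script (only the prompt differs), and the key settings are the ones the distilled models were trained for: `cfg_scale=1` and `num_inference_steps=15`.

```python
# Hedged sketch of distillation-accelerated inference (LoRA variant).
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import snapshot_download
import torch

snapshot_download("DiffSynth-Studio/Qwen-Image-Distill-LoRA", local_dir="models/DiffSynth-Studio/Qwen-Image-Distill-LoRA")
pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-Distill-LoRA/model.safetensors")

# cfg_scale=1 disables classifier-free guidance; 15 steps is the distilled target.
image = pipe("a cat with sunglasses", seed=0, num_inference_steps=15, cfg_scale=1)
image.save("image.jpg")
```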
|
||||
|
||||
</details>
|
||||
|
||||
@@ -219,28 +240,32 @@ The script includes the following parameters:
|
||||
* `--width`: Width of image or video. Leave `height` and `width` empty to enable dynamic resolution.
|
||||
* `--data_file_keys`: Data file keys in metadata. Separate with commas.
|
||||
* `--dataset_repeat`: Number of times the dataset repeats per epoch.
|
||||
* `--dataset_num_workers`: Number of workers for data loading.
|
||||
* Model
|
||||
* `--model_paths`: Model paths to load. In JSON format.
|
||||
* `--model_id_with_origin_paths`: Model ID with original paths, e.g., Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors. Separate with commas.
|
||||
* `--tokenizer_path`: Tokenizer path. Leave empty to auto-download.
|
||||
* `--processor_path`: Path to the processor of Qwen-Image-Edit. Leave empty to auto-download.
|
||||
* Training
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--weight_decay`: Weight decay.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--output_path`: Save path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
|
||||
* `--save_steps`: Interval (in steps) between checkpoint saves. If None, checkpoints will be saved every epoch.
|
||||
* `--find_unused_parameters`: Whether to find unused parameters in DDP.
|
||||
* Trainable Modules
|
||||
* `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
|
||||
* `--lora_base_model`: Which model to add LoRA to.
|
||||
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* `--lora_checkpoint`: Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.
|
||||
* Extra Model Inputs
|
||||
* `--extra_inputs`: Extra model inputs, separated by commas.
|
||||
* VRAM Management
|
||||
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
|
||||
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||
* Others
|
||||
* `--align_to_opensource_format`: Whether to align DiT LoRA format with open-source version. Only works for LoRA training.
|
||||
|
||||
In addition, the training framework is built on [`accelerate`](https://huggingface.co/docs/accelerate/index). Run `accelerate config` before training to set GPU-related settings. For some training tasks (e.g., full training of 20B model), we provide suggested `accelerate` config files. Check the corresponding training script for details.
|
||||
|
||||
|
||||
@@ -20,9 +20,9 @@ pip install -e .
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
@@ -34,17 +34,30 @@ pipe = QwenImagePipeline.from_pretrained(
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image = pipe(
|
||||
prompt, seed=0, num_inference_steps=40,
|
||||
# edit_image=Image.open("xxx.jpg").resize((1328, 1328)) # For Qwen-Image-Edit
|
||||
)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## Model Overview
|
||||
|
||||
|Model ID|Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|
||||
|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|
||||
|Model ID|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_inference_low_vram/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|
||||
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./model_inference/Qwen-Image-Edit.py)|[code](./model_inference_low_vram/Qwen-Image-Edit.py)|[code](./model_training/full/Qwen-Image-Edit.sh)|[code](./model_training/validate_full/Qwen-Image-Edit.py)|[code](./model_training/lora/Qwen-Image-Edit.sh)|[code](./model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](./model_inference/Qwen-Image-Edit-2509.py)|[code](./model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](./model_training/full/Qwen-Image-Edit-2509.sh)|[code](./model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](./model_training/lora/Qwen-Image-Edit-2509.sh)|[code](./model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./model_inference/Qwen-Image-Distill-LoRA.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](./model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](./model_inference/Qwen-Image-EliGen-V2.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](./model_inference/Qwen-Image-EliGen-Poster.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](./model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](./model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](./model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](./model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](./model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](./model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|
||||
|
||||
## Model Inference
|
||||
|
||||
@@ -164,6 +177,7 @@ FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在
|
||||
* `vram_limit`: VRAM usage limit in GB. By default, all free VRAM on the device is used. Note that this is not a strict limit: if the configured limit is too low to run inference but enough free VRAM is actually available, inference proceeds with minimal VRAM usage. Setting it to 0 yields the theoretical minimum VRAM footprint.
* `vram_buffer`: VRAM buffer size in GB, 0.5 GB by default. A buffer is necessary because some large network layers can uncontrollably occupy extra VRAM during onloading; the theoretical optimum is the VRAM taken by the largest layer in the model.
* `num_persistent_param_in_dit`: Number of DiT parameters kept resident in VRAM; unlimited by default. This parameter will be removed in the future, so do not rely on it.
* `enable_dit_fp8_computation`: Whether to enable FP8 computation in the DiT model. Only applicable to GPUs that support FP8 operations (e.g., H200). Disabled by default.
|
||||
|
||||
</details>
|
||||
|
||||
@@ -172,7 +186,14 @@ FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在
|
||||
|
||||
<summary>Inference Acceleration</summary>
|
||||
|
||||
Additional inference acceleration techniques for Qwen-Image are under development. The options already available are listed below.
* FP8 quantization: choose the quantization method that fits your hardware and requirements.
* GPUs without FP8 compute support (e.g., A100, 4090): FP8 quantization only reduces VRAM usage and does not speed up inference. Code: [./model_inference_low_vram/Qwen-Image.py](./model_inference_low_vram/Qwen-Image.py)
* GPUs with FP8 compute support (e.g., H200): please install [Flash Attention 3](https://github.com/Dao-AILab/flash-attention); otherwise FP8 acceleration only applies to Linear layers.
* Faster inference, higher VRAM usage: use [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
* Slightly slower inference, lower VRAM usage: use [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
* Distillation acceleration: we trained two distilled models for fast inference with `cfg_scale=1` and `num_inference_steps=15`.
* [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full): fully distilled version with better image quality but weaker LoRA compatibility; use [./model_inference/Qwen-Image-Distill-Full.py](./model_inference/Qwen-Image-Distill-Full.py)
* [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA): LoRA-distilled version with slightly lower image quality but better LoRA compatibility; use [./model_inference/Qwen-Image-Distill-LoRA.py](./model_inference/Qwen-Image-Distill-LoRA.py)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -219,28 +240,32 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod
|
||||
* `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution.
* `--data_file_keys`: Data file keys in the metadata. Separate with commas.
* `--dataset_repeat`: Number of times the dataset repeats per epoch.
* `--dataset_num_workers`: Number of worker processes per DataLoader.
* Model
* `--model_paths`: Model paths to load, in JSON format.
* `--model_id_with_origin_paths`: Model ID with original paths, e.g., Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors. Separate with commas.
* `--tokenizer_path`: Tokenizer path. Leave empty to download automatically.
* `--processor_path`: Path to the processor of Qwen-Image-Edit. Leave empty to download automatically.
* Training
* `--learning_rate`: Learning rate.
* `--weight_decay`: Weight decay.
* `--num_epochs`: Number of epochs.
* `--output_path`: Save path.
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
* `--save_steps`: Interval (in steps) between checkpoint saves. If None, a checkpoint is saved every epoch.
* `--find_unused_parameters`: Whether there are unused parameters in DDP training.
* Trainable Modules
* `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.
* Extra Model Inputs
* `--extra_inputs`: Extra model inputs, separated by commas.
* VRAM Management
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Others
* `--align_to_opensource_format`: Whether to align the DiT LoRA format with the open-source version. Only effective for LoRA training.
|
||||
|
||||
In addition, the training framework is built on [`accelerate`](https://huggingface.co/docs/accelerate/index). Run `accelerate config` before training to configure GPU-related settings. For some training tasks (e.g., full training of the 20B model), we provide recommended `accelerate` config files; see the corresponding training script for details.
|
||||
|
||||
|
||||
18
examples/qwen_image/accelerate/Qwen-Image-FP8-offload.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management(enable_dit_fp8_computation=True)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True)
|
||||
image.save("image.jpg")
|
||||
51
examples/qwen_image/accelerate/Qwen-Image-FP8.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth.models.qwen_image_dit import RMSNorm
|
||||
from diffsynth.vram_management.layers import enable_vram_management, AutoWrappedLinear, AutoWrappedModule
|
||||
import torch
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
|
||||
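# Pass 1: keep the RMSNorm layers wrapped in bfloat16 and resident on the GPU (normalization stays out of FP8).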
enable_vram_management(
|
||||
pipe.dit,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=torch.bfloat16,
|
||||
offload_device="cuda",
|
||||
onload_dtype=torch.bfloat16,
|
||||
onload_device="cuda",
|
||||
computation_dtype=torch.bfloat16,
|
||||
computation_device="cuda",
|
||||
),
|
||||
vram_limit=None,
|
||||
)
|
||||
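# Pass 2: wrap every torch.nn.Linear so its weights are stored and computed in float8_e4m3fn on the GPU.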
enable_vram_management(
|
||||
pipe.dit,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=torch.float8_e4m3fn,
|
||||
offload_device="cuda",
|
||||
onload_dtype=torch.float8_e4m3fn,
|
||||
onload_device="cuda",
|
||||
computation_dtype=torch.float8_e4m3fn,
|
||||
computation_device="cuda",
|
||||
),
|
||||
vram_limit=None,
|
||||
)
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,31 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
from PIL import Image
|
||||
import torch
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="canny/image_1.jpg"
|
||||
)
|
||||
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328))
|
||||
|
||||
prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。"
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,32 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
from PIL import Image
|
||||
import torch
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="depth/image_1.jpg"
|
||||
)
|
||||
|
||||
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1328, 1328))
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="inpaint/*.jpg"
|
||||
)
|
||||
prompt = "a cat with sunglasses"
|
||||
controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1328, 1328))
|
||||
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1328, 1328))
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
input_image=controlnet_image, inpaint_mask=inpaint_mask,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],
|
||||
num_inference_steps=40,
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,24 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, load_state_dict
|
||||
from modelscope import snapshot_download
|
||||
import torch, math
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
|
||||
snapshot_download("MusePublic/Qwen-Image-Distill", allow_file_pattern="qwen_image_distill_3step.safetensors", cache_dir="models")
|
||||
lora_state_dict = load_state_dict("models/MusePublic/Qwen-Image-Distill/qwen_image_distill_3step.safetensors")
|
||||
lora_state_dict = {i.replace("base_model.model.", ""): j for i, j in lora_state_dict.items()}
|
||||
pipe.load_lora(pipe.dit, state_dict=lora_state_dict)
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=3, cfg_scale=1, exponential_shift_mu=math.log(2.5))
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,20 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from modelscope import snapshot_download
|
||||
import torch
|
||||
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-Distill-LoRA", local_dir="models/DiffSynth-Studio/Qwen-Image-Distill-LoRA")
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-Distill-LoRA/model.safetensors")
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)
|
||||
image.save("image.jpg")
|
||||
25
examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2509", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
||||
)
|
||||
|
||||
image_1 = pipe(prompt="一位少女", seed=0, num_inference_steps=40, height=1328, width=1024)
|
||||
image_1.save("image1.jpg")
|
||||
|
||||
image_2 = pipe(prompt="一位老人", seed=0, num_inference_steps=40, height=1328, width=1024)
|
||||
image_2.save("image2.jpg")
|
||||
|
||||
prompt = "生成这两个人的合影"
|
||||
edit_image = [Image.open("image1.jpg"), Image.open("image2.jpg")]
|
||||
image_3 = pipe(prompt, edit_image=edit_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=True)
|
||||
image_3.save("image3.jpg")
|
||||
@@ -0,0 +1,25 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
from modelscope import snapshot_download
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
||||
)
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix", local_dir="models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix", allow_file_pattern="model.safetensors")
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix/model.safetensors")
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1024, width=768)
|
||||
image.save("image.jpg")
|
||||
|
||||
prompt = "将裙子变成粉色"
|
||||
image = image.resize((512, 384))
|
||||
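# Demonstrates the Lowres-Fix use case: the edit image is downscaled to 512x384, then regenerated at 1024x768 with edit_rope_interpolation=True and auto-resize disabled.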
image = pipe(prompt, edit_image=image, seed=1, num_inference_steps=40, height=1024, width=768, edit_rope_interpolation=True, edit_image_auto_resize=False)
|
||||
image.save(f"image2.jpg")
|
||||
25
examples/qwen_image/model_inference/Qwen-Image-Edit.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
input_image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1328, width=1024)
|
||||
input_image.save("image1.jpg")
|
||||
|
||||
prompt = "将裙子改为粉色"
|
||||
# edit_image_auto_resize=True: auto resize input image to match the area of 1024*1024 with the original aspect ratio
|
||||
image = pipe(prompt, edit_image=input_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=True)
|
||||
image.save(f"image2.jpg")
|
||||
|
||||
# edit_image_auto_resize=False: do not resize input image
|
||||
image = pipe(prompt, edit_image=input_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=False)
|
||||
image.save(f"image3.jpg")
|
||||
114
examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from modelscope import dataset_snapshot_download, snapshot_download
|
||||
import random
|
||||
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
|
||||
# Save or display the resulting image
|
||||
result.save(output_path)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts, height=784, width=1280):
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/eligen/poster/example_{example_id}/*.png"
|
||||
)
|
||||
masks = [
|
||||
Image.open(f"./data/examples/eligen/poster/example_{example_id}/{i}.png").convert('RGB').resize((width, height))
|
||||
for i in range(len(entity_prompts))
|
||||
]
|
||||
negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴"
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=4.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=40,
|
||||
seed=seed,
|
||||
height=height,
|
||||
width=width,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"eligen_poster_example_{example_id}_{seed}.png")
|
||||
image = Image.new("RGB", (width, height), (0, 0, 0))
|
||||
visualize_masks(image, masks, entity_prompts, f"eligen_poster_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
snapshot_download(
|
||||
"DiffSynth-Studio/Qwen-Image-EliGen-Poster",
|
||||
local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen-Poster",
|
||||
allow_file_pattern="model.safetensors",
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen-Poster/model.safetensors")
|
||||
global_prompt = "一张以柔粉紫为背景的海报,左侧有大号粉紫色文字“Qwen-Image EliGen-Poster”,粉紫色椭圆框内白色小字:“图像精确分区控制模型”。右侧有一只小兔子在拆礼物,旁边站着一只头顶迷你烟花发射器的小龙(卡通Q版)。背景有一些白云点缀。整体风格卡通可爱,传达节日惊喜的主题。"
|
||||
entity_prompts = ["粉紫色文字“Qwen-Image EliGen-Poster”", "粉紫色椭圆框内白色小字:“图像精确分区控制模型”", "一只小兔子在拆礼物,小兔子旁边站着一只头顶迷你烟花发射器的小龙(卡通Q版)"]
|
||||
seed = [42]
|
||||
example(pipe, seed, 1, global_prompt, entity_prompts)
|
||||
106
examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import torch
|
||||
import random
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from modelscope import dataset_snapshot_download, snapshot_download
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
|
||||
# Save or display the resulting image
|
||||
result.save(output_path)
|
||||
|
||||
return result
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/example_{example_id}/*.png")
|
||||
masks = [Image.open(f"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB').resize((1024, 1024)) for i in range(len(entity_prompts))]
|
||||
negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴"
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=4.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=40,
|
||||
seed=seed,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"eligen_example_{example_id}_{seed}.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen-V2", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen-V2", allow_file_pattern="model.safetensors")
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen-V2/model.safetensors")
|
||||
|
||||
seeds = [0]
|
||||
|
||||
global_prompt = "写实摄影风格. A beautiful asia woman wearing white dress, she is holding a mirror with her right arm, with a beach background."
|
||||
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
|
||||
example(pipe, seeds, 7, global_prompt, entity_prompts)
|
||||
|
||||
global_prompt = "写实摄影风格, 细节丰富。街头一位漂亮的女孩,穿着衬衫和短裤,手持写有“实体控制”的标牌,背景是繁忙的城市街道,阳光明媚,行人匆匆。"
|
||||
entity_prompts = ["一个漂亮的女孩", "标牌 '实体控制'", "短裤", "衬衫"]
|
||||
example(pipe, seeds, 4, global_prompt, entity_prompts)
|
||||
128
examples/qwen_image/model_inference/Qwen-Image-EliGen.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from modelscope import dataset_snapshot_download, snapshot_download
|
||||
import random
|
||||
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
|
||||
# Save or display the resulting image
|
||||
result.save(output_path)
|
||||
|
||||
return result
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/example_{example_id}/*.png")
|
||||
masks = [Image.open(f"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
negative_prompt = ""
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=4.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=30,
|
||||
seed=seed,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"eligen_example_{example_id}_{seed}.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen", allow_file_pattern="model.safetensors")
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen/model.safetensors")
|
||||
|
||||
# example 1
|
||||
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
|
||||
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
|
||||
example(pipe, [0], 1, global_prompt, entity_prompts)
|
||||
|
||||
# example 2
|
||||
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
|
||||
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "yellow belt"]
|
||||
example(pipe, [0], 2, global_prompt, entity_prompts)
|
||||
|
||||
# example 3
|
||||
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
|
||||
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
|
||||
example(pipe, [27], 3, global_prompt, entity_prompts)
|
||||
|
||||
# example 4
|
||||
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
|
||||
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
|
||||
example(pipe, [21], 4, global_prompt, entity_prompts)
|
||||
|
||||
# example 5
|
||||
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
|
||||
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
|
||||
example(pipe, [0], 5, global_prompt, entity_prompts)
|
||||
|
||||
# example 7, same prompt with different seeds
|
||||
seeds = range(5, 9)
|
||||
global_prompt = "A beautiful asia woman wearing white dress, holding a mirror, with a forest background."
|
||||
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
|
||||
example(pipe, seeds, 7, global_prompt, entity_prompts)
|
||||
@@ -0,0 +1,35 @@
|
||||
from PIL import Image
|
||||
import torch
|
||||
from modelscope import dataset_snapshot_download, snapshot_download
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth.controlnets.processors import Annotator
|
||||
|
||||
allow_file_pattern = ["sk_model.pth", "sk_model2.pth", "dpt_hybrid-midas-501f0c75.pt", "ControlNetHED.pth", "body_pose_model.pth", "hand_pose_model.pth", "facenet.pth", "scannet.pt"]
|
||||
snapshot_download("lllyasviel/Annotators", local_dir="models/Annotators", allow_file_pattern=allow_file_pattern)
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-In-Context-Control-Union", local_dir="models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union", allow_file_pattern="model.safetensors")
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union/model.safetensors")
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/qwen-image-context-control/image.jpg")
|
||||
origin_image = Image.open("data/examples/qwen-image-context-control/image.jpg").resize((1024, 1024))
|
||||
annotator_ids = ['openpose', 'canny', 'depth', 'lineart', 'softedge', 'normal']
|
||||
for annotator_id in annotator_ids:
|
||||
annotator = Annotator(processor_id=annotator_id, device="cuda")
|
||||
control_image = annotator(origin_image)
|
||||
control_image.save(f"{annotator.processor_id}.png")
|
||||
|
||||
control_prompt = "Context_Control. "
|
||||
prompt = f"{control_prompt}一个穿着淡蓝色的漂亮女孩正在翩翩起舞,背景是梦幻的星空,光影交错,细节精致。"
|
||||
negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴"
|
||||
image = pipe(prompt, seed=1, negative_prompt=negative_prompt, context_image=control_image, height=1024, width=1024)
|
||||
image.save(f"image_{annotator.processor_id}.png")
|
||||
@@ -0,0 +1,32 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
from PIL import Image
|
||||
import torch
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", origin_file_pattern="model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="canny/image_1.jpg"
|
||||
)
|
||||
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328))
|
||||
|
||||
prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。"
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,33 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
from PIL import Image
|
||||
import torch
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="depth/image_1.jpg"
|
||||
)
|
||||
|
||||
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1328, 1328))
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="inpaint/*.jpg"
|
||||
)
|
||||
prompt = "a cat with sunglasses"
|
||||
controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1328, 1328))
|
||||
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1328, 1328))
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
input_image=controlnet_image, inpaint_mask=inpaint_mask,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],
|
||||
num_inference_steps=40,
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,22 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from modelscope import snapshot_download
|
||||
import torch
|
||||
|
||||
# Please do not use float8 on this model
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-Distill-LoRA", local_dir="models/DiffSynth-Studio/Qwen-Image-Distill-LoRA")
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-Distill-LoRA/model.safetensors")
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,26 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2509", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
image_1 = pipe(prompt="一位少女", seed=0, num_inference_steps=40, height=1328, width=1024)
|
||||
image_1.save("image1.jpg")
|
||||
|
||||
image_2 = pipe(prompt="一位老人", seed=0, num_inference_steps=40, height=1328, width=1024)
|
||||
image_2.save("image2.jpg")
|
||||
|
||||
prompt = "生成这两个人的合影"
|
||||
edit_image = [Image.open("image1.jpg"), Image.open("image2.jpg")]
|
||||
image_3 = pipe(prompt, edit_image=edit_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=True)
|
||||
image_3.save("image3.jpg")
|
||||
@@ -0,0 +1,27 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
from modelscope import snapshot_download
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix", local_dir="models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix", allow_file_pattern="model.safetensors")
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix/model.safetensors")
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1024, width=768)
|
||||
image.save("image.jpg")
|
||||
|
||||
prompt = "将裙子变成粉色"
|
||||
image = image.resize((512, 384))
|
||||
image = pipe(prompt, edit_image=image, seed=1, num_inference_steps=40, height=1024, width=768, edit_rope_interpolation=True, edit_image_auto_resize=False)
|
||||
image.save(f"image2.jpg")
|
||||
@@ -0,0 +1,22 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1024, width=1024)
|
||||
image.save("image1.jpg")
|
||||
|
||||
prompt = "将裙子改为粉色"
|
||||
image = pipe(prompt, edit_image=image, seed=1, num_inference_steps=40, height=1024, width=1024)
|
||||
image.save(f"image2.jpg")
|
||||
@@ -0,0 +1,115 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from modelscope import dataset_snapshot_download, snapshot_download
|
||||
import random
|
||||
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
|
||||
# Save or display the resulting image
|
||||
result.save(output_path)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts, height=784, width=1280):
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/eligen/poster/example_{example_id}/*.png"
|
||||
)
|
||||
masks = [
|
||||
Image.open(f"./data/examples/eligen/poster/example_{example_id}/{i}.png").convert('RGB').resize((width, height))
|
||||
for i in range(len(entity_prompts))
|
||||
]
|
||||
negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴"
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=4.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=40,
|
||||
seed=seed,
|
||||
height=height,
|
||||
width=width,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"eligen_poster_example_{example_id}_{seed}.png")
|
||||
image = Image.new("RGB", (width, height), (0, 0, 0))
|
||||
visualize_masks(image, masks, entity_prompts, f"eligen_poster_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
snapshot_download(
|
||||
"DiffSynth-Studio/Qwen-Image-EliGen-Poster",
|
||||
local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen-Poster",
|
||||
allow_file_pattern="model.safetensors",
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen-Poster/model.safetensors")
|
||||
global_prompt = "一张以柔粉紫为背景的海报,左侧有大号粉紫色文字“Qwen-Image EliGen-Poster”,粉紫色椭圆框内白色小字:“图像精确分区控制模型”。右侧有一只小兔子在拆礼物,旁边站着一只头顶迷你烟花发射器的小龙(卡通Q版)。背景有一些白云点缀。整体风格卡通可爱,传达节日惊喜的主题。"
|
||||
entity_prompts = ["粉紫色文字“Qwen-Image EliGen-Poster”", "粉紫色椭圆框内白色小字:“图像精确分区控制模型”", "一只小兔子在拆礼物,小兔子旁边站着一只头顶迷你烟花发射器的小龙(卡通Q版)"]
|
||||
seed = [42]
|
||||
example(pipe, seed, 1, global_prompt, entity_prompts)
|
||||
@@ -0,0 +1,108 @@
|
||||
import torch
|
||||
import random
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from modelscope import dataset_snapshot_download, snapshot_download
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
|
||||
# Save or display the resulting image
|
||||
result.save(output_path)
|
||||
|
||||
return result
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/example_{example_id}/*.png")
|
||||
masks = [Image.open(f"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB').resize((1024, 1024)) for i in range(len(entity_prompts))]
|
||||
negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴"
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=4.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=40,
|
||||
seed=seed,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"eligen_example_{example_id}_{seed}.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen-V2", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen-V2", allow_file_pattern="model.safetensors")
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen-V2/model.safetensors")
|
||||
|
||||
seeds = [0]
|
||||
|
||||
global_prompt = "写实摄影风格. A beautiful asia woman wearing white dress, she is holding a mirror with her right arm, with a beach background."
|
||||
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
|
||||
example(pipe, seeds, 7, global_prompt, entity_prompts)
|
||||
|
||||
global_prompt = "写实摄影风格, 细节丰富。街头一位漂亮的女孩,穿着衬衫和短裤,手持写有“实体控制”的标牌,背景是繁忙的城市街道,阳光明媚,行人匆匆。"
|
||||
entity_prompts = ["一个漂亮的女孩", "标牌 '实体控制'", "短裤", "衬衫"]
|
||||
example(pipe, seeds, 4, global_prompt, entity_prompts)
|
||||
@@ -0,0 +1,129 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from modelscope import dataset_snapshot_download, snapshot_download
|
||||
import random
|
||||
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
|
||||
# Save or display the resulting image
|
||||
result.save(output_path)
|
||||
|
||||
return result
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/example_{example_id}/*.png")
|
||||
masks = [Image.open(f"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
negative_prompt = ""
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=4.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=30,
|
||||
seed=seed,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"eligen_example_{example_id}_{seed}.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen", allow_file_pattern="model.safetensors")
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen/model.safetensors")
|
||||
|
||||
# example 1
|
||||
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
|
||||
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
|
||||
example(pipe, [0], 1, global_prompt, entity_prompts)
|
||||
|
||||
# example 2
|
||||
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
|
||||
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "yellow belt"]
|
||||
example(pipe, [0], 2, global_prompt, entity_prompts)
|
||||
|
||||
# example 3
|
||||
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
|
||||
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
|
||||
example(pipe, [27], 3, global_prompt, entity_prompts)
|
||||
|
||||
# example 4
|
||||
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
|
||||
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
|
||||
example(pipe, [21], 4, global_prompt, entity_prompts)
|
||||
|
||||
# example 5
|
||||
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
|
||||
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
|
||||
example(pipe, [0], 5, global_prompt, entity_prompts)
|
||||
|
||||
# example 7, same prompt with different seeds
|
||||
seeds = range(5, 9)
|
||||
global_prompt = "A beautiful asia woman wearing white dress, holding a mirror, with a forest background."
|
||||
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
|
||||
example(pipe, seeds, 7, global_prompt, entity_prompts)
|
||||
@@ -0,0 +1,36 @@
|
||||
from PIL import Image
|
||||
import torch
|
||||
from modelscope import dataset_snapshot_download, snapshot_download
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth.controlnets.processors import Annotator
|
||||
|
||||
allow_file_pattern = ["sk_model.pth", "sk_model2.pth", "dpt_hybrid-midas-501f0c75.pt", "ControlNetHED.pth", "body_pose_model.pth", "hand_pose_model.pth", "facenet.pth", "scannet.pt"]
|
||||
snapshot_download("lllyasviel/Annotators", local_dir="models/Annotators", allow_file_pattern=allow_file_pattern)
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-In-Context-Control-Union", local_dir="models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union", allow_file_pattern="model.safetensors")
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union/model.safetensors")
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/qwen-image-context-control/image.jpg")
|
||||
origin_image = Image.open("data/examples/qwen-image-context-control/image.jpg").resize((1024, 1024))
|
||||
annotator_ids = ['openpose', 'canny', 'depth', 'lineart', 'softedge', 'normal']
|
||||
for annotator_id in annotator_ids:
|
||||
annotator = Annotator(processor_id=annotator_id, device="cuda")
|
||||
control_image = annotator(origin_image)
|
||||
control_image.save(f"{annotator.processor_id}.png")
|
||||
|
||||
control_prompt = "Context_Control. "
|
||||
prompt = f"{control_prompt}一个穿着淡蓝色的漂亮女孩正在翩翩起舞,背景是梦幻的星空,光影交错,细节精致。"
|
||||
negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴"
|
||||
image = pipe(prompt, seed=1, negative_prompt=negative_prompt, context_image=control_image, height=1024, width=1024)
|
||||
image.save(f"image_{annotator.processor_id}.png")
|
||||
@@ -0,0 +1,38 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_canny.csv \
|
||||
--data_file_keys "image,blockwise_controlnet_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny:model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
|
||||
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Canny_full" \
|
||||
--trainable_models "blockwise_controlnet" \
|
||||
--extra_inputs "blockwise_controlnet_image" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
|
||||
# If you want to pre-train a Blockwise ControlNet from scratch,
|
||||
# please run the following script to first generate the initialized model weights file,
|
||||
# and then start training with a high learning rate (1e-3).
|
||||
|
||||
# python examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Initialize.py
|
||||
|
||||
# accelerate launch examples/qwen_image/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_canny.csv \
|
||||
# --data_file_keys "image,blockwise_controlnet_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --model_paths '["models/blockwise_controlnet.safetensors"]' \
|
||||
# --learning_rate 1e-3 \
|
||||
# --num_epochs 2 \
|
||||
# --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
|
||||
# --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Canny_full" \
|
||||
# --trainable_models "blockwise_controlnet" \
|
||||
# --extra_inputs "blockwise_controlnet_image" \
|
||||
# --use_gradient_checkpointing \
|
||||
# --find_unused_parameters
|
||||
@@ -0,0 +1,38 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_depth.csv \
|
||||
--data_file_keys "image,blockwise_controlnet_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth:model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
|
||||
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Depth_full" \
|
||||
--trainable_models "blockwise_controlnet" \
|
||||
--extra_inputs "blockwise_controlnet_image" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
|
||||
# If you want to pre-train a Blockwise ControlNet from scratch,
|
||||
# please run the following script to first generate the initialized model weights file,
|
||||
# and then start training with a high learning rate (1e-3).
|
||||
|
||||
# python examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Initialize.py
|
||||
|
||||
# accelerate launch examples/qwen_image/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_depth.csv \
|
||||
# --data_file_keys "image,blockwise_controlnet_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --model_paths '["models/blockwise_controlnet.safetensors"]' \
|
||||
# --learning_rate 1e-3 \
|
||||
# --num_epochs 2 \
|
||||
# --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
|
||||
# --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Depth_full" \
|
||||
# --trainable_models "blockwise_controlnet" \
|
||||
# --extra_inputs "blockwise_controlnet_image" \
|
||||
# --use_gradient_checkpointing \
|
||||
# --find_unused_parameters
|
||||
@@ -0,0 +1,38 @@
|
||||
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config.yaml examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_inpaint.csv \
|
||||
--data_file_keys "image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint:model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
|
||||
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full" \
|
||||
--trainable_models "blockwise_controlnet" \
|
||||
--extra_inputs "blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
|
||||
# If you want to pre-train a Inpaint Blockwise ControlNet from scratch,
|
||||
# please run the following script to first generate the initialized model weights file,
|
||||
# and then start training with a high learning rate (1e-3).
|
||||
|
||||
# python examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Inpaint-Initialize.py
|
||||
|
||||
# accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config.yaml examples/qwen_image/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_inpaint.csv \
|
||||
# --data_file_keys "image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --model_paths '["models/blockwise_controlnet_inpaint.safetensors"]' \
|
||||
# --learning_rate 1e-3 \
|
||||
# --num_epochs 2 \
|
||||
# --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
|
||||
# --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full" \
|
||||
# --trainable_models "blockwise_controlnet" \
|
||||
# --extra_inputs "blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
|
||||
# --use_gradient_checkpointing \
|
||||
# --find_unused_parameters
|
||||
@@ -9,4 +9,5 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Distill-Full_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
--data_file_keys "image,edit_image" \
|
||||
--extra_inputs "edit_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Edit-2509_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
15
examples/qwen_image/model_training/full/Qwen-Image-Edit.sh
Normal file
15
examples/qwen_image/model_training/full/Qwen-Image-Edit.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_edit.csv \
|
||||
--data_file_keys "image,edit_image" \
|
||||
--extra_inputs "edit_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Edit_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
@@ -9,4 +9,5 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 1
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_canny.csv \
|
||||
--data_file_keys "image,blockwise_controlnet_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny:model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Canny_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "blockwise_controlnet_image" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_depth.csv \
|
||||
--data_file_keys "image,blockwise_controlnet_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth:model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Depth_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "blockwise_controlnet_image" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_inpaint.csv \
|
||||
--data_file_keys "image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint:model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
@@ -11,5 +11,5 @@ accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--align_to_opensource_format \
|
||||
--use_gradient_checkpointing
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_distill_qwen_image.csv \
|
||||
--data_file_keys "image" \
|
||||
--extra_inputs "seed,rand_device,num_inference_steps,cfg_scale" \
|
||||
--height 1328 \
|
||||
--width 1328 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Distill-LoRA_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters \
|
||||
--task direct_distill
|
||||
|
||||
# This is an experimental training feature designed to directly distill the model, enabling generation results with fewer steps to approximate those achieved with more steps.
|
||||
# The model (https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) is trained using this script.
|
||||
# The sample dataset is provided solely to demonstrate the dataset format. For actual usage, please construct a larger dataset using the base model.
|
||||
@@ -0,0 +1,18 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
--data_file_keys "image,edit_image" \
|
||||
--extra_inputs "edit_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Edit-2509_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters
|
||||
18
examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh
Normal file
18
examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh
Normal file
@@ -0,0 +1,18 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_edit.csv \
|
||||
--data_file_keys "image,edit_image" \
|
||||
--extra_inputs "edit_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Edit_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters
|
||||
@@ -0,0 +1,18 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path "data/example_image_dataset" \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_eligen.json \
|
||||
--data_file_keys "image,eligen_entity_masks" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-EliGen-Poster_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "eligen_entity_masks,eligen_entity_prompts" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--lora_checkpoint "models/DiffSynth-Studio/Qwen-Image-EliGen-V2/model.safetensors"
|
||||
17
examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh
Normal file
17
examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path "data/example_image_dataset" \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_eligen.json \
|
||||
--data_file_keys "image,eligen_entity_masks" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-EliGen_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "eligen_entity_masks,eligen_entity_prompts" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
@@ -0,0 +1,20 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path "data/example_image_dataset" \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_qwenimage_context.csv \
|
||||
--data_file_keys "image,context_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-In-Context-Control-Union_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 64 \
|
||||
--lora_checkpoint "models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union/model.safetensors" \
|
||||
--extra_inputs "context_image" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
|
||||
# if you want to train from scratch, you can remove the --lora_checkpoint argument
|
||||
@@ -0,0 +1,26 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--output_path "./models/train/Qwen-Image_lora_cache" \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--task data_process
|
||||
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path models/train/Qwen-Image_lora_cache \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters \
|
||||
--enable_fp8_training
|
||||
@@ -11,5 +11,6 @@ accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--align_to_opensource_format \
|
||||
--use_gradient_checkpointing
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
# This script is for initializing a Qwen-Image-Blockwise-ControlNet
|
||||
from diffsynth import hash_state_dict_keys
|
||||
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
controlnet = QwenImageBlockWiseControlNet().to(dtype=torch.bfloat16, device="cuda")
|
||||
controlnet.init_weight()
|
||||
state_dict_controlnet = controlnet.state_dict()
|
||||
|
||||
print(hash_state_dict_keys(state_dict_controlnet))
|
||||
save_file(state_dict_controlnet, "models/blockwise_controlnet.safetensors")
|
||||
@@ -0,0 +1,12 @@
|
||||
# This script is for initializing a Inpaint Qwen-Image-ControlNet
|
||||
import torch
|
||||
from diffsynth import hash_state_dict_keys
|
||||
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||
from safetensors.torch import save_file
|
||||
|
||||
controlnet = QwenImageBlockWiseControlNet(additional_in_dim=4).to(dtype=torch.bfloat16, device="cuda")
|
||||
controlnet.init_weight()
|
||||
state_dict_controlnet = controlnet.state_dict()
|
||||
|
||||
print(hash_state_dict_keys(state_dict_controlnet))
|
||||
save_file(state_dict_controlnet, "models/blockwise_controlnet_inpaint.safetensors")
|
||||
@@ -1,7 +1,9 @@
|
||||
import torch, os, json
|
||||
from diffsynth import load_state_dict
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
|
||||
from diffsynth.models.lora import QwenImageLoRAConverter
|
||||
from diffsynth.pipelines.flux_image_new import ControlNetInput
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task, DPOLoss
from diffsynth.trainers.unified_dataset import UnifiedDataset
os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -10,47 +12,35 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
    def __init__(
        self,
        model_paths=None, model_id_with_origin_paths=None,
        tokenizer_path=None,
        tokenizer_path=None, processor_path=None,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="", lora_rank=32,
        lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
        enable_fp8_training=False,
        task="sft",
    ):
        super().__init__()
        # Load models
        model_configs = []
        if model_paths is not None:
            model_paths = json.loads(model_paths)
            model_configs += [ModelConfig(path=path) for path in model_paths]
        if model_id_with_origin_paths is not None:
            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
        if tokenizer_path is not None:
            self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=ModelConfig(tokenizer_path))
        else:
            self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
        model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training)
        tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
        processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
        self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)

        # Training mode
        self.switch_pipe_to_training_mode(
            self.pipe, trainable_models,
            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
            enable_fp8_training=enable_fp8_training,
        )

        # Reset training scheduler (do it in each training step)
        self.pipe.scheduler.set_timesteps(1000, training=True)

        # Freeze untrainable models
        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))

        # Add LoRA to the base models
        if lora_base_model is not None:
            model = self.add_lora_to_model(
                getattr(self.pipe, lora_base_model),
                target_modules=lora_target_modules.split(","),
                lora_rank=lora_rank
            )
            setattr(self.pipe, lora_base_model, model)

        # Store other configs
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []

        self.task = task


    def forward_preprocess(self, data):
        # CFG-sensitive parameters
@@ -70,11 +60,22 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
            "edit_image_auto_resize": True,
        }

        # Extra inputs
        controlnet_input, blockwise_controlnet_input = {}, {}
        for extra_input in self.extra_inputs:
            inputs_shared[extra_input] = data[extra_input]
            if extra_input.startswith("blockwise_controlnet_"):
                blockwise_controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input]
            elif extra_input.startswith("controlnet_"):
                controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
            else:
                inputs_shared[extra_input] = data[extra_input]
        if len(controlnet_input) > 0:
            inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
        if len(blockwise_controlnet_input) > 0:
            inputs_shared["blockwise_controlnet_inputs"] = [ControlNetInput(**blockwise_controlnet_input)]

        # Pipeline units will automatically process the input parameters.
        for unit in self.pipe.units:
@@ -82,39 +83,71 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
        return {**inputs_shared, **inputs_posi}


    def forward(self, data, inputs=None):
        if inputs is None: inputs = self.forward_preprocess(data)
        models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
        loss = self.pipe.training_loss(**models, **inputs)
        return loss
    def forward(self, data, inputs=None, return_inputs=False):
        # DPO (DPO requires a special training loss)
        if self.task == "dpo":
            loss = DPOLoss().loss(self, data)
            return loss
        else:
            # Inputs
            if inputs is None:
                inputs = self.forward_preprocess(data)
            else:
                inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
            if return_inputs: return inputs

            # Loss
            if self.task == "sft":
                models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
                loss = self.pipe.training_loss(**models, **inputs)
            elif self.task == "data_process":
                loss = inputs
            elif self.task == "direct_distill":
                loss = self.pipe.direct_distill_loss(**inputs)
            else:
                raise NotImplementedError(f"Unsupported task: {self.task}.")
            return loss



if __name__ == "__main__":
    parser = qwen_image_parser()
    args = parser.parse_args()
    dataset = ImageDataset(args=args)
    dataset = UnifiedDataset(
        base_path=args.dataset_base_path,
        metadata_path=args.dataset_metadata_path,
        repeat=args.dataset_repeat,
        data_file_keys=args.data_file_keys.split(","),
        main_data_operator=UnifiedDataset.default_image_operator(
            base_path=args.dataset_base_path,
            max_pixels=args.max_pixels,
            height=args.height,
            width=args.width,
            height_division_factor=16,
            width_division_factor=16,
        )
    )
    model = QwenImageTrainingModule(
        model_paths=args.model_paths,
        model_id_with_origin_paths=args.model_id_with_origin_paths,
        tokenizer_path=args.tokenizer_path,
        processor_path=args.processor_path,
        trainable_models=args.trainable_models,
        lora_base_model=args.lora_base_model,
        lora_target_modules=args.lora_target_modules,
        lora_rank=args.lora_rank,
        lora_checkpoint=args.lora_checkpoint,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        extra_inputs=args.extra_inputs,
        enable_fp8_training=args.enable_fp8_training,
        task=args.task,
    )
    model_logger = ModelLogger(
        args.output_path,
        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
        state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
    )
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    launch_training_task(
        dataset, model, model_logger, optimizer, scheduler,
        num_epochs=args.num_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
    )
    model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
    launcher_map = {
        "sft": launch_training_task,
        "data_process": launch_data_process_task,
        "direct_distill": launch_training_task,
        "dpo": launch_training_task,
    }
    launcher_map[args.task](dataset, model, model_logger, args=args)
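The `forward_preprocess` hunk above routes each name in `--extra_inputs` by prefix: `blockwise_controlnet_*` and `controlnet_*` entries are stripped of their prefix, collected, and wrapped in a `ControlNetInput`, while everything else is passed through unchanged. A minimal standalone sketch of that grouping, with hypothetical column names and placeholder string values instead of real dataset images:

```python
# Sketch of the prefix-based grouping in forward_preprocess above.
# Column names and values are hypothetical placeholders, not real PIL images.
from diffsynth.pipelines.flux_image_new import ControlNetInput

extra_inputs = ["blockwise_controlnet_image", "blockwise_controlnet_inpaint_mask", "edit_image"]
data = {
    "blockwise_controlnet_image": "canny.jpg",
    "blockwise_controlnet_inpaint_mask": "mask.jpg",
    "edit_image": "reference.jpg",
}

inputs_shared, controlnet_input, blockwise_controlnet_input = {}, {}, {}
for name in extra_inputs:
    if name.startswith("blockwise_controlnet_"):
        blockwise_controlnet_input[name.replace("blockwise_controlnet_", "")] = data[name]
    elif name.startswith("controlnet_"):
        controlnet_input[name.replace("controlnet_", "")] = data[name]
    else:
        inputs_shared[name] = data[name]
if controlnet_input:
    inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
if blockwise_controlnet_input:
    inputs_shared["blockwise_controlnet_inputs"] = [ControlNetInput(**blockwise_controlnet_input)]

# inputs_shared now holds "edit_image" plus one ControlNetInput carrying the
# image and inpaint_mask fields, which the pipeline units consume downstream.
```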
@@ -0,0 +1,31 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
from PIL import Image
import torch
from modelscope import dataset_snapshot_download


pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
        ModelConfig(path="models/train/Qwen-Image-Blockwise-ControlNet-Canny_full/epoch-1.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/example_image_dataset",
    local_dir="./data/example_image_dataset",
    allow_file_pattern="canny/image_1.jpg"
)
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328))

prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。"
image = pipe(
    prompt, seed=0,
    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
)
image.save("image.jpg")
@@ -0,0 +1,31 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
from PIL import Image
import torch
from modelscope import dataset_snapshot_download


pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
        ModelConfig(path="models/train/Qwen-Image-Blockwise-ControlNet-Depth_full/epoch-1.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/example_image_dataset",
    local_dir="./data/example_image_dataset",
    allow_file_pattern="depth/image_1.jpg"
)
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1328, 1328))

prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(
    prompt, seed=0,
    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
)
image.save("image.jpg")
@@ -0,0 +1,32 @@
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput


pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
        ModelConfig(path="models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full/epoch-1.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/example_image_dataset",
    local_dir="./data/example_image_dataset",
    allow_file_pattern="inpaint/*.jpg"
)
prompt = "a cat with sunglasses"
controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024))
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024))
image = pipe(
    prompt, seed=0,
    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],
    height=1024, width=1024,
    num_inference_steps=40,
)
image.save("image.jpg")
@@ -0,0 +1,26 @@
import torch
from PIL import Image
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth import load_state_dict

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image-Edit-2509", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=None,
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
state_dict = load_state_dict("models/train/Qwen-Image-Edit-2509_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)

prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2."
images = [
    Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)),
    Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)),
]
image = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024)
image.save("image.jpg")
@@ -0,0 +1,23 @@
import torch
from PIL import Image
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth import load_state_dict

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=None,
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
state_dict = load_state_dict("models/train/Qwen-Image-Edit_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)

prompt = "将裙子改为粉色"
image = Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024))
image = pipe(prompt, edit_image=image, seed=0, num_inference_steps=40, height=1024, width=1024)
image.save(f"image.jpg")
@@ -0,0 +1,32 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
from PIL import Image
import torch
from modelscope import dataset_snapshot_download


pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
        ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", origin_file_pattern="model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Blockwise-ControlNet-Canny_lora/epoch-4.safetensors")

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/example_image_dataset",
    local_dir="./data/example_image_dataset",
    allow_file_pattern="canny/image_1.jpg"
)
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328))

prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。"
image = pipe(
    prompt, seed=0,
    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
)
image.save("image.jpg")
@@ -0,0 +1,33 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
from PIL import Image
import torch
from modelscope import dataset_snapshot_download


pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
        ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Blockwise-ControlNet-Depth_lora/epoch-4.safetensors")

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/example_image_dataset",
    local_dir="./data/example_image_dataset",
    allow_file_pattern="depth/image_1.jpg"
)

controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1328, 1328))

prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(
    prompt, seed=0,
    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
)
image.save("image.jpg")
@@ -0,0 +1,34 @@
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput


pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
        ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_lora/epoch-4.safetensors")

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/example_image_dataset",
    local_dir="./data/example_image_dataset",
    allow_file_pattern="inpaint/*.jpg"
)
prompt = "a cat with sunglasses"
controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024))
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024))
image = pipe(
    prompt, seed=0,
    blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],
    height=1024, width=1024,
    num_inference_steps=40,
)
image.save("image.jpg")
@@ -0,0 +1,23 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch


pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Distill-LoRA_lora/epoch-4.safetensors")
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(
    prompt,
    seed=0,
    num_inference_steps=4,
    cfg_scale=1,
)
image.save("image.jpg")
@@ -0,0 +1,24 @@
import torch
from PIL import Image
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image-Edit-2509", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=None,
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Edit-2509_lora/epoch-4.safetensors")

prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2."
images = [
    Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)),
    Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)),
]
image = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024)
image.save("image.jpg")
@@ -0,0 +1,21 @@
import torch
from PIL import Image
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=None,
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Edit_lora/epoch-4.safetensors")

prompt = "将裙子改为粉色"
image = Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024))
image = pipe(prompt, edit_image=image, seed=0, num_inference_steps=40, height=1024, width=1024)
image.save(f"image.jpg")
@@ -0,0 +1,29 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
from PIL import Image


pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-EliGen-Poster_lora/epoch-4.safetensors")


entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
masks = [Image.open(f"data/example_image_dataset/eligen/{i}.png").convert('RGB') for i in range(len(entity_prompts))]

image = pipe(global_prompt,
    seed=0,
    height=1024,
    width=1024,
    eligen_entity_prompts=entity_prompts,
    eligen_entity_masks=masks)
image.save("Qwen-Image-EliGen-Poster.jpg")
@@ -0,0 +1,29 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
from PIL import Image


pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-EliGen_lora/epoch-4.safetensors")


entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
masks = [Image.open(f"data/example_image_dataset/eligen/{i}.png").convert('RGB') for i in range(len(entity_prompts))]

image = pipe(global_prompt,
    seed=0,
    height=1024,
    width=1024,
    eligen_entity_prompts=entity_prompts,
    eligen_entity_masks=masks)
image.save("Qwen-Image_EliGen.jpg")
@@ -0,0 +1,19 @@
from PIL import Image
import torch
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-In-Context-Control-Union_lora/epoch-4.safetensors")
image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024))
prompt = "Context_Control. a dog"
image = pipe(prompt=prompt, seed=0, context_image=image, height=1024, width=1024)
image.save("image_context.jpg")
@@ -48,9 +48,15 @@ save_video(video, "video1.mp4", fps=15, quality=5)

| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./model_inference/Wan2.2-Animate-14B.py)|[code](./model_training/full/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./model_training/lora/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](./model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](./model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](./model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.2-Fun-A14B-InP.py)|[code](./model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](./model_inference/Wan2.2-Fun-A14B-Control.py)|[code](./model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./model_inference/Wan2.1-T2V-14B.py)|[code](./model_training/full/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./model_training/lora/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-14B.py)|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-480P.py)|[code](./model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
@@ -280,6 +286,7 @@ The script includes the following parameters:
  * `--num_frames`: Number of frames per video. Frames are sampled from the video prefix.
  * `--data_file_keys`: Data file keys in the metadata. Comma-separated.
  * `--dataset_repeat`: Number of times to repeat the dataset per epoch.
  * `--dataset_num_workers`: Number of workers for data loading.
* Models
  * `--model_paths`: Paths to load models. In JSON format.
  * `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated (see the sketch after this parameter list).
@@ -287,14 +294,18 @@ The script includes the following parameters:
  * `--min_timestep_boundary`: Minimum value of the timestep interval, ranging from 0 to 1. Default is 1. This needs to be manually set only when training mixed models with multiple DiTs, for example, [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B).
* Training
  * `--learning_rate`: Learning rate.
  * `--weight_decay`: Weight decay.
  * `--num_epochs`: Number of epochs.
  * `--output_path`: Output save path.
  * `--remove_prefix_in_ckpt`: Remove prefix in ckpt.
  * `--save_steps`: Number of steps between checkpoint saves. If None, checkpoints will be saved every epoch.
  * `--find_unused_parameters`: Whether to find unused parameters in DDP.
* Trainable Modules
  * `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
  * `--lora_base_model`: Which model LoRA is added to.
  * `--lora_target_modules`: Which layers LoRA is added to.
  * `--lora_rank`: Rank of LoRA.
  * `--lora_checkpoint`: Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.
* Extra Inputs
  * `--extra_inputs`: Additional model inputs, comma-separated.
* VRAM Management
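The `--model_id_with_origin_paths` flag packs a ModelScope model ID and a weight-file pattern into one comma-separated string. Judging from the Qwen-Image `train.py` diff earlier in this comparison, each entry is split on `:` and turned into a `ModelConfig`. A minimal sketch under that assumption, reusing the example value from the parameter list above:

```python
# Hedged sketch of how --model_id_with_origin_paths is interpreted; it mirrors
# the parsing shown in the Qwen-Image train.py diff above.
from diffsynth.pipelines.wan_video_new import ModelConfig

arg = "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors"
model_configs = [
    ModelConfig(model_id=item.split(":")[0], origin_file_pattern=item.split(":")[1])
    for item in arg.split(",")
]
# -> [ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B",
#                 origin_file_pattern="diffusion_pytorch_model*.safetensors")]
```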
@@ -48,9 +48,15 @@ save_video(video, "video1.mp4", fps=15, quality=5)

|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./model_inference/Wan2.2-Animate-14B.py)|[code](./model_training/full/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./model_training/lora/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](./model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](./model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](./model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.2-Fun-A14B-InP.py)|[code](./model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](./model_inference/Wan2.2-Fun-A14B-Control.py)|[code](./model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./model_inference/Wan2.1-T2V-14B.py)|[code](./model_training/full/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./model_training/lora/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-14B.py)|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-480P.py)|[code](./model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
@@ -282,6 +288,7 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
  * `--num_frames`: 每个视频中的帧数。帧从视频前缀中采样。
  * `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。
  * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
  * `--dataset_num_workers`: 每个 DataLoader 的进程数量。
* 模型
  * `--model_paths`: 要加载的模型路径。JSON 格式。
  * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors。用逗号分隔。
@@ -289,14 +296,18 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
  * `--min_timestep_boundary`: Timestep 区间最小值,范围为 0~1,默认为 1,仅在多 DiT 的混合模型训练中需要手动设置,例如 [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)。
* 训练
  * `--learning_rate`: 学习率。
  * `--weight_decay`: 权重衰减大小。
  * `--num_epochs`: 轮数(Epoch)。
  * `--output_path`: 保存路径。
  * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
  * `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None,则每个 epoch 保存一次。
  * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
* 可训练模块
  * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。
  * `--lora_base_model`: LoRA 添加到哪个模型上。
  * `--lora_target_modules`: LoRA 添加到哪一层上。
  * `--lora_rank`: LoRA 的秩(Rank)。
  * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。
* 额外模型输入
  * `--extra_inputs`: 额外的模型输入,以逗号分隔。
* 显存管理
examples/wanvideo/model_inference/Wan2.2-Animate-14B.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download, snapshot_download


pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern="data/examples/wan/animate/*",
)

# Animate
input_image = Image.open("data/examples/wan/animate/animate_input_image.png")
animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4").raw_data()[:81-4]
animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4").raw_data()[:81-4]
video = pipe(
    prompt="视频中的人在做动作",
    seed=0, tiled=True,
    input_image=input_image,
    animate_pose_video=animate_pose_video,
    animate_face_video=animate_face_video,
    num_frames=81, height=720, width=1280,
    num_inference_steps=20, cfg_scale=1,
)
save_video(video, "video1.mp4", fps=15, quality=5)

# Replace
snapshot_download("Wan-AI/Wan2.2-Animate-14B", allow_file_pattern="relighting_lora.ckpt", local_dir="models/Wan-AI/Wan2.2-Animate-14B")
lora_state_dict = load_state_dict("models/Wan-AI/Wan2.2-Animate-14B/relighting_lora.ckpt", torch_dtype=torch.float32, device="cuda")["state_dict"]
pipe.load_lora(pipe.dit, state_dict=lora_state_dict)
input_image = Image.open("data/examples/wan/animate/replace_input_image.png")
animate_pose_video = VideoData("data/examples/wan/animate/replace_pose_video.mp4").raw_data()[:81-4]
animate_face_video = VideoData("data/examples/wan/animate/replace_face_video.mp4").raw_data()[:81-4]
animate_inpaint_video = VideoData("data/examples/wan/animate/replace_inpaint_video.mp4").raw_data()[:81-4]
animate_mask_video = VideoData("data/examples/wan/animate/replace_mask_video.mp4").raw_data()[:81-4]
video = pipe(
    prompt="视频中的人在做动作",
    seed=0, tiled=True,
    input_image=input_image,
    animate_pose_video=animate_pose_video,
    animate_face_video=animate_face_video,
    animate_inpaint_video=animate_inpaint_video,
    animate_mask_video=animate_mask_video,
    num_frames=81, height=720, width=1280,
    num_inference_steps=20, cfg_scale=1,
)
save_video(video, "video2.mp4", fps=15, quality=5)
@@ -0,0 +1,43 @@
import torch
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from PIL import Image
from modelscope import dataset_snapshot_download

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()


dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern=f"data/examples/wan/input_image.jpg"
)
input_image = Image.open("data/examples/wan/input_image.jpg")

video = pipe(
    prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    seed=0, tiled=True,
    input_image=input_image,
    camera_control_direction="Left", camera_control_speed=0.01,
)
save_video(video, "video_left.mp4", fps=15, quality=5)

video = pipe(
    prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    seed=0, tiled=True,
    input_image=input_image,
    camera_control_direction="Up", camera_control_speed=0.01,
)
save_video(video, "video_up.mp4", fps=15, quality=5)
examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py (new file, 35 lines)
@@ -0,0 +1,35 @@
import torch
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from PIL import Image
from modelscope import dataset_snapshot_download

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"]
)

# Control video
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832))
video = pipe(
    prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    control_video=control_video, reference_image=reference_image,
    height=832, width=576, num_frames=49,
    seed=1, tiled=True
)
save_video(video, "video.mp4", fps=15, quality=5)
examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py (new file, 35 lines)
@@ -0,0 +1,35 @@
import torch
from diffsynth import save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from PIL import Image
from modelscope import dataset_snapshot_download

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern=f"data/examples/wan/input_image.jpg"
)
image = Image.open("data/examples/wan/input_image.jpg")

# First and last frame to video
video = pipe(
    prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    input_image=image,
    seed=0, tiled=True,
    # You can input `end_image=xxx` to control the last frame of the video.
    # The model will automatically generate the dynamic content between `input_image` and `end_image`.
)
save_video(video, "video.mp4", fps=15, quality=5)
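The comment in the script above notes that an `end_image` can also be supplied, in which case the model generates the motion between `input_image` and `end_image`. A hedged variant of the same call under that assumption; `end_image.jpg` is a placeholder path, not part of the example dataset downloaded above, and the prompts are abbreviated stand-ins for the full strings used earlier:

```python
# Hypothetical first-and-last-frame call; end_image.jpg is a placeholder path.
end_image = Image.open("end_image.jpg")
video = pipe(
    prompt="一艘小船正勇敢地乘风破浪前行。",  # abbreviated placeholder for the prompt above
    negative_prompt="色调艳丽,过曝,静态",    # abbreviated placeholder for the negative prompt above
    input_image=image,
    end_image=end_image,
    seed=0, tiled=True,
)
save_video(video, "video_first_last_frame.mp4", fps=15, quality=5)
```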