mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits
89 Commits
qwen-image
...
dpo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1e8c201d3b | ||
|
|
d96709fb6a | ||
|
|
bf7b339efb | ||
|
|
b0abdaffb4 | ||
|
|
e9f29bc402 | ||
|
|
1a7f482fbd | ||
|
|
d93e8738cd | ||
|
|
7e5ce5d5c9 | ||
|
|
7aef554d83 | ||
|
|
090074e395 | ||
|
|
2dcdeefca8 | ||
|
|
452a6ca5cf | ||
|
|
d6cf20ef33 | ||
|
|
efdd6a59b6 | ||
|
|
42ec7b08eb | ||
|
|
d049fb6d1d | ||
|
|
144365b07d | ||
|
|
cb8de6be1b | ||
|
|
8c13362dcf | ||
|
|
c13fd7e0ee | ||
|
|
958ebf1352 | ||
|
|
b6da77e468 | ||
|
|
260e32217f | ||
|
|
5cee326f92 | ||
|
|
1d240994e7 | ||
|
|
a0bae07825 | ||
|
|
ff71720297 | ||
|
|
dea85643e6 | ||
|
|
6a46f32afe | ||
|
|
4641d0f360 | ||
|
|
826bab5962 | ||
|
|
5b6d112c15 | ||
|
|
febdaf6067 | ||
|
|
0a78bb9d38 | ||
|
|
9cea10cc69 | ||
|
|
caa17da5b9 | ||
|
|
fdeb363fa2 | ||
|
|
4147473c81 | ||
|
|
8a0bd7c377 | ||
|
|
b541b9bed2 | ||
|
|
419d47c195 | ||
|
|
ac2e859960 | ||
|
|
6663dca015 | ||
|
|
86e509ad31 | ||
|
|
8fcfa1dd2d | ||
|
|
2b7a2548b4 | ||
|
|
f0916e6bae | ||
|
|
822e80ec2f | ||
|
|
04e39f7de5 | ||
|
|
ce0b948655 | ||
|
|
c795e35142 | ||
|
|
f7c01f1367 | ||
|
|
cb49f0283f | ||
|
|
6a45815b23 | ||
|
|
8dae8d7bc8 | ||
|
|
f6418004bb | ||
|
|
c4b97cd591 | ||
|
|
b6d1ff01e0 | ||
|
|
0d81626fe7 | ||
|
|
e3f47a799b | ||
|
|
e014cad820 | ||
|
|
89bf3ce5cf | ||
|
|
3ebe118f23 | ||
|
|
7f719cefe6 | ||
|
|
46bd05b54d | ||
|
|
613dafbd09 | ||
|
|
952933eeb1 | ||
|
|
c0172e70b1 | ||
|
|
6ab426e641 | ||
|
|
d0467a7e8d | ||
|
|
36838a05ee | ||
|
|
5e6f9f89f1 | ||
|
|
2dad9a319c | ||
|
|
9ec0652339 | ||
|
|
7e348083ae | ||
|
|
29b12b2f4e | ||
|
|
b3f57ed920 | ||
|
|
c9fea729d8 | ||
|
|
9d0683df25 | ||
|
|
838b8109b1 | ||
|
|
3a9621f6da | ||
|
|
fff2c89360 | ||
|
|
ce61bef2b0 | ||
|
|
123f6dbadb | ||
|
|
f9ce261a0e | ||
|
|
d93de98a21 | ||
|
|
ad1da43476 | ||
|
|
398b1dbd7a | ||
|
|
9f6922bba9 |
31
README.md
31
README.md
@@ -64,6 +64,7 @@ Details: [./examples/qwen_image/](./examples/qwen_image/)
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
@@ -77,7 +78,10 @@ pipe = QwenImagePipeline.from_pretrained(
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "A detailed portrait of a girl underwater, wearing a blue flowing dress, hair gently floating, clear light and shadow, surrounded by bubbles, calm expression, fine details, dreamy and beautiful."
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image = pipe(
|
||||
prompt, seed=0, num_inference_steps=40,
|
||||
# edit_image=Image.open("xxx.jpg").resize((1328, 1328)) # For Qwen-Image-Edit
|
||||
)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
@@ -90,12 +94,16 @@ image.save("image.jpg")
|
||||
|Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
||||
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|-|-|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|
||||
|
||||
</details>
|
||||
|
||||
@@ -197,9 +205,13 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|
||||
| Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||
@@ -368,6 +380,21 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
|
||||
|
||||
|
||||
## Update History
|
||||
|
||||
- **September 22, 2025**: We have supported Direct Preference Optimization (DPO) training for Qwen-Image. Please refer to the [example code](examples/qwen_image/model_training/lora/Qwen-Image-DPO.sh) for the training script.
|
||||
|
||||
- **September 9, 2025**: Our training framework now supports multiple training modes and has been adapted for Qwen-Image. In addition to the standard SFT training mode, Direct Distill is now also supported; please refer to [our example code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will continue to improve it to support comprehensive model training capabilities.
|
||||
|
||||
- **August 28, 2025** We support Wan2.2-S2V, an audio-driven cinematic video generation model open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
|
||||
|
||||
- **August 21, 2025**: [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) is released! Compared to the V1 version, the training dataset has been updated to the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), enabling generated images to better align with the inherent image distribution and style of Qwen-Image. Please refer to [our sample code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py).
|
||||
|
||||
- **August 21, 2025**: We open-sourced the [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) structure control LoRA model. Following "In Context" routine, it supports various types of structural control conditions, including canny, depth, lineart, softedge, normal, and openpose. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py).
|
||||
|
||||
- **August 20, 2025** We open-sourced [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix), which improves the editing performance of Qwen-Image-Edit on low-resolution image inputs. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py).
|
||||
|
||||
- **August 19, 2025** 🔥 Qwen-Image-Edit is now open source. Welcome the new member to the image editing model family!
|
||||
|
||||
- **August 18, 2025** We trained and open-sourced the Inpaint ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).
|
||||
|
||||
- **August 15, 2025** We open-sourced the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset). This is an image dataset generated using the Qwen-Image model, with a total of 160,000 `1024 x 1024` images. It includes the general, English text rendering, and Chinese text rendering subsets. We provide caption, entity and control images annotations for each image. Developers can use this dataset to train models such as ControlNet and EliGen for the Qwen-Image model. We aim to promote technological development through open-source contributions!
|
||||
|
||||
31
README_zh.md
31
README_zh.md
@@ -66,6 +66,7 @@ DiffSynth-Studio 为主流 Diffusion 模型(包括 FLUX、Wan 等)重新设
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
@@ -79,7 +80,10 @@ pipe = QwenImagePipeline.from_pretrained(
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image = pipe(
|
||||
prompt, seed=0, num_inference_steps=40,
|
||||
# edit_image=Image.open("xxx.jpg").resize((1328, 1328)) # For Qwen-Image-Edit
|
||||
)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
@@ -92,12 +96,16 @@ image.save("image.jpg")
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
||||
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|-|-|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|
||||
|
||||
</details>
|
||||
|
||||
@@ -197,9 +205,13 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|
||||
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||
@@ -384,6 +396,21 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
|
||||
|
||||
|
||||
## 更新历史
|
||||
|
||||
- **2025年9月22日** 我们支持了 Qwen-Image 的直接偏好对齐 (DPO) 训练,训练脚本请参考[示例代码](examples/qwen_image/model_training/lora/Qwen-Image-DPO.sh)。
|
||||
|
||||
- **2025年9月9日** 我们的训练框架支持了多种训练模式,目前已适配 Qwen-Image,除标准 SFT 训练模式外,已支持 Direct Distill,请参考[我们的示例代码](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)。这项功能是实验性的,我们将会继续完善已支持更全面的模型训练功能。
|
||||
|
||||
- **2025年8月28日** 我们支持了Wan2.2-S2V,一个音频驱动的电影级视频生成模型。请参见[./examples/wanvideo/](./examples/wanvideo/)。
|
||||
|
||||
- **2025年8月21日** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) 发布!相比于 V1 版本,训练数据集变为 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset),因此,生成的图像更符合 Qwen-Image 本身的图像分布和风格。 请参考[我们的示例代码](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)。
|
||||
|
||||
- **2025年8月21日** 我们开源了 [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) 结构控制 LoRA 模型,采用 In Context 的技术路线,支持多种类别的结构控制条件,包括 canny, depth, lineart, softedge, normal, openpose。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)。
|
||||
|
||||
- **2025年8月20日** 我们开源了 [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) 模型,提升了 Qwen-Image-Edit 对低分辨率图像输入的编辑效果。请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)
|
||||
|
||||
- **2025年8月19日** 🔥 Qwen-Image-Edit 开源,欢迎图像编辑模型新成员!
|
||||
|
||||
- **2025年8月18日** 我们训练并开源了 Qwen-Image 的图像重绘 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)。
|
||||
|
||||
- **2025年8月15日** 我们开源了 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) 数据集。这是一个使用 Qwen-Image 模型生成的图像数据集,共包含 160,000 张`1024 x 1024`图像。它包括通用、英文文本渲染和中文文本渲染子集。我们为每张图像提供了图像描述、实体和结构控制图像的标注。开发者可以使用这个数据集来训练 Qwen-Image 模型的 ControlNet 和 EliGen 等模型,我们旨在通过开源推动技术发展!
|
||||
|
||||
@@ -56,11 +56,13 @@ from ..models.stepvideo_vae import StepVideoVAE
|
||||
from ..models.stepvideo_dit import StepVideoModel
|
||||
|
||||
from ..models.wan_video_dit import WanModel
|
||||
from ..models.wan_video_dit_s2v import WanS2VModel
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
from ..models.wav2vec import WanS2VAudioEncoder
|
||||
|
||||
from ..models.step1x_connector import Qwen2Connector
|
||||
|
||||
@@ -150,9 +152,12 @@ model_loader_configs = [
|
||||
(None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "2267d489f0ceb9f21836532952852ee5", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "47dbeab5e560db3180adf51dc0232fb1", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||
(None, "966cffdcc52f9c46c391768b27637614", ["wan_video_dit"], [WanS2VModel], "civitai"),
|
||||
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||
@@ -170,6 +175,7 @@ model_loader_configs = [
|
||||
(None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
|
||||
(None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
|
||||
(None, "a9e54e480a628f0b956a688a81c33bab", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
|
||||
(None, "06be60f3a4526586d8431cd038a71486", ["wans2v_audio_encoder"], [WanS2VAudioEncoder], "civitai"),
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
|
||||
@@ -1 +1 @@
|
||||
from .video import VideoData, save_video, save_frames
|
||||
from .video import VideoData, save_video, save_frames, merge_video_audio, save_video_with_audio
|
||||
|
||||
@@ -2,6 +2,8 @@ import imageio, os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import subprocess
|
||||
import shutil
|
||||
|
||||
|
||||
class LowMemoryVideo:
|
||||
@@ -146,3 +148,70 @@ def save_frames(frames, save_path):
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
for i, frame in enumerate(tqdm(frames, desc="Saving images")):
|
||||
frame.save(os.path.join(save_path, f"{i}.png"))
|
||||
|
||||
|
||||
def merge_video_audio(video_path: str, audio_path: str):
|
||||
# TODO: may need a in-python implementation to avoid subprocess dependency
|
||||
"""
|
||||
Merge the video and audio into a new video, with the duration set to the shorter of the two,
|
||||
and overwrite the original video file.
|
||||
|
||||
Parameters:
|
||||
video_path (str): Path to the original video file
|
||||
audio_path (str): Path to the audio file
|
||||
"""
|
||||
|
||||
# check
|
||||
if not os.path.exists(video_path):
|
||||
raise FileNotFoundError(f"video file {video_path} does not exist")
|
||||
if not os.path.exists(audio_path):
|
||||
raise FileNotFoundError(f"audio file {audio_path} does not exist")
|
||||
|
||||
base, ext = os.path.splitext(video_path)
|
||||
temp_output = f"{base}_temp{ext}"
|
||||
|
||||
try:
|
||||
# create ffmpeg command
|
||||
command = [
|
||||
'ffmpeg',
|
||||
'-y', # overwrite
|
||||
'-i',
|
||||
video_path,
|
||||
'-i',
|
||||
audio_path,
|
||||
'-c:v',
|
||||
'copy', # copy video stream
|
||||
'-c:a',
|
||||
'aac', # use AAC audio encoder
|
||||
'-b:a',
|
||||
'192k', # set audio bitrate (optional)
|
||||
'-map',
|
||||
'0:v:0', # select the first video stream
|
||||
'-map',
|
||||
'1:a:0', # select the first audio stream
|
||||
'-shortest', # choose the shortest duration
|
||||
temp_output
|
||||
]
|
||||
|
||||
# execute the command
|
||||
result = subprocess.run(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
|
||||
# check result
|
||||
if result.returncode != 0:
|
||||
error_msg = f"FFmpeg execute failed: {result.stderr}"
|
||||
print(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
shutil.move(temp_output, video_path)
|
||||
print(f"Merge completed, saved to {video_path}")
|
||||
|
||||
except Exception as e:
|
||||
if os.path.exists(temp_output):
|
||||
os.remove(temp_output)
|
||||
print(f"merge_video_audio failed with error: {e}")
|
||||
|
||||
|
||||
def save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None):
|
||||
save_video(frames, save_path, fps, quality, ffmpeg_params)
|
||||
merge_video_audio(save_path, audio_path)
|
||||
|
||||
@@ -2,7 +2,8 @@ from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_e
|
||||
import numpy as np
|
||||
import cupy as cp
|
||||
import cv2
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class PatchMatcher:
|
||||
def __init__(
|
||||
@@ -233,13 +234,11 @@ class PyramidPatchMatcher:
|
||||
|
||||
def resample_image(self, images, level):
|
||||
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
|
||||
images = images.get()
|
||||
images_resample = []
|
||||
for image in images:
|
||||
image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
|
||||
images_resample.append(image_resample)
|
||||
images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
|
||||
return images_resample
|
||||
images_torch = torch.as_tensor(images, device='cuda', dtype=torch.float32)
|
||||
images_torch = images_torch.permute(0, 3, 1, 2)
|
||||
images_resample = F.interpolate(images_torch, size=(height, width), mode='area', align_corners=None)
|
||||
images_resample = images_resample.permute(0, 2, 3, 1).contiguous()
|
||||
return cp.asarray(images_resample)
|
||||
|
||||
def initialize_nnf(self, batch_size):
|
||||
if self.initialize == "random":
|
||||
@@ -262,14 +261,16 @@ class PyramidPatchMatcher:
|
||||
def update_nnf(self, nnf, level):
|
||||
# upscale
|
||||
nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
|
||||
nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
|
||||
nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
|
||||
nnf[:, 1::2, :, 0] += 1
|
||||
nnf[:, :, 1::2, 1] += 1
|
||||
# check if scale is 2
|
||||
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
|
||||
if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
|
||||
nnf = nnf.get().astype(np.float32)
|
||||
nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
|
||||
nnf = cp.array(np.stack(nnf), dtype=cp.int32)
|
||||
nnf_torch = torch.as_tensor(nnf, device='cuda', dtype=torch.float32)
|
||||
nnf_torch = nnf_torch.permute(0, 3, 1, 2)
|
||||
nnf_resized = F.interpolate(nnf_torch, size=(height, width), mode='bilinear', align_corners=False)
|
||||
nnf_resized = nnf_resized.permute(0, 2, 3, 1)
|
||||
nnf = cp.asarray(nnf_resized).astype(cp.int32)
|
||||
nnf = self.patch_matchers[level].clamp_bound(nnf)
|
||||
return nnf
|
||||
|
||||
|
||||
@@ -63,8 +63,8 @@ class QwenEmbedRope(nn.Module):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
pos_index = torch.arange(1024)
|
||||
neg_index = torch.arange(1024).flip(0) * -1 - 1
|
||||
pos_index = torch.arange(4096)
|
||||
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
||||
self.pos_freqs = torch.cat([
|
||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||
@@ -94,7 +94,7 @@ class QwenEmbedRope(nn.Module):
|
||||
|
||||
def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3))
|
||||
_, height, width = video_fhw
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2)
|
||||
@@ -127,49 +127,102 @@ class QwenEmbedRope(nn.Module):
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
frame, height, width = video_fhw
|
||||
rope_key = f"{frame}_{height}_{width}"
|
||||
vid_freqs = []
|
||||
max_vid_index = 0
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
rope_key = f"{idx}_{height}_{width}"
|
||||
|
||||
if rope_key not in self.rope_cache:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat(
|
||||
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
|
||||
)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
self.rope_cache[rope_key] = freqs.clone().contiguous()
|
||||
vid_freqs.append(self.rope_cache[rope_key])
|
||||
|
||||
if rope_key not in self.rope_cache:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat(
|
||||
[
|
||||
freqs_neg[1][-(height - height//2):],
|
||||
freqs_pos[1][:height//2]
|
||||
],
|
||||
dim=0
|
||||
)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat(
|
||||
[
|
||||
freqs_neg[2][-(width - width//2):],
|
||||
freqs_pos[2][:width//2]
|
||||
],
|
||||
dim=0
|
||||
)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
self.rope_cache[rope_key] = freqs.clone().contiguous()
|
||||
vid_freqs = self.rope_cache[rope_key]
|
||||
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2)
|
||||
else:
|
||||
max_vid_index = max(height, width)
|
||||
max_vid_index = max(height, width, max_vid_index)
|
||||
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index: max_vid_index + max_len, ...]
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
|
||||
def forward_sampling(self, video_fhw, txt_seq_lens, device):
|
||||
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
|
||||
vid_freqs = []
|
||||
max_vid_index = 0
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
rope_key = f"{idx}_{height}_{width}"
|
||||
if idx > 0 and f"{0}_{height}_{width}" not in self.rope_cache:
|
||||
frame_0, height_0, width_0 = video_fhw[0]
|
||||
|
||||
rope_key_0 = f"0_{height_0}_{width_0}"
|
||||
spatial_freqs_0 = self.rope_cache[rope_key_0].reshape(frame_0, height_0, width_0, -1)
|
||||
h_indices = torch.linspace(0, height_0 - 1, height).long()
|
||||
w_indices = torch.linspace(0, width_0 - 1, width).long()
|
||||
h_grid, w_grid = torch.meshgrid(h_indices, w_indices, indexing='ij')
|
||||
sampled_rope = spatial_freqs_0[:, h_grid, w_grid, :]
|
||||
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
sampled_rope[:, :, :, :freqs_frame.shape[-1]] = freqs_frame
|
||||
|
||||
seq_lens = frame * height * width
|
||||
self.rope_cache[rope_key] = sampled_rope.reshape(seq_lens, -1).clone()
|
||||
if rope_key not in self.rope_cache:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat(
|
||||
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
|
||||
)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
self.rope_cache[rope_key] = freqs.clone()
|
||||
vid_freqs.append(self.rope_cache[rope_key].contiguous())
|
||||
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
||||
else:
|
||||
max_vid_index = max(height, width, max_vid_index)
|
||||
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
|
||||
@@ -414,6 +467,7 @@ class QwenImageDiT(torch.nn.Module):
|
||||
image_start = sum(seq_lens)
|
||||
image_end = total_seq_len
|
||||
cumsum = [0]
|
||||
single_image_seq = image_end - image_start
|
||||
for length in seq_lens:
|
||||
cumsum.append(cumsum[-1] + length)
|
||||
for i in range(N):
|
||||
@@ -421,6 +475,9 @@ class QwenImageDiT(torch.nn.Module):
|
||||
prompt_end = cumsum[i+1]
|
||||
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
|
||||
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
|
||||
# repeat image mask to match the single image sequence length
|
||||
repeat_time = single_image_seq // image_mask.shape[-1]
|
||||
image_mask = image_mask.repeat(1, 1, repeat_time)
|
||||
# prompt update with image
|
||||
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
|
||||
# image update with prompt
|
||||
@@ -440,7 +497,8 @@ class QwenImageDiT(torch.nn.Module):
|
||||
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
|
||||
|
||||
return all_prompt_emb, image_rotary_emb, attention_mask
|
||||
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents=None,
|
||||
|
||||
@@ -182,7 +182,7 @@ def process_pose_file(cam_params, width=672, height=384, original_pose_width=128
|
||||
|
||||
|
||||
def generate_camera_coordinates(
|
||||
direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
|
||||
direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown", "In", "Out"],
|
||||
length: int,
|
||||
speed: float = 1/54,
|
||||
origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
|
||||
@@ -198,5 +198,9 @@ def generate_camera_coordinates(
|
||||
coor[13] += speed
|
||||
if "Down" in direction:
|
||||
coor[13] -= speed
|
||||
if "In" in direction:
|
||||
coor[18] -= speed
|
||||
if "Out" in direction:
|
||||
coor[18] += speed
|
||||
coordinates.append(coor)
|
||||
return coordinates
|
||||
|
||||
@@ -294,6 +294,7 @@ class WanModel(torch.nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.in_dim = in_dim
|
||||
self.freq_dim = freq_dim
|
||||
self.has_image_input = has_image_input
|
||||
self.patch_size = patch_size
|
||||
@@ -713,6 +714,42 @@ class WanModelStateDictConverter:
|
||||
"eps": 1e-6,
|
||||
"require_clip_embedding": False,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "2267d489f0ceb9f21836532952852ee5":
|
||||
# Wan2.2-Fun-A14B-Control
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 52,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": True,
|
||||
"require_clip_embedding": False,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "47dbeab5e560db3180adf51dc0232fb1":
|
||||
# Wan2.2-Fun-A14B-Control-Camera
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": False,
|
||||
"add_control_adapter": True,
|
||||
"in_dim_control_adapter": 24,
|
||||
"require_clip_embedding": False,
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict, config
|
||||
|
||||
625
diffsynth/models/wan_video_dit_s2v.py
Normal file
625
diffsynth/models/wan_video_dit_s2v.py
Normal file
@@ -0,0 +1,625 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
from .utils import hash_state_dict_keys
|
||||
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
|
||||
|
||||
|
||||
def torch_dfs(model: nn.Module, parent_name='root'):
|
||||
module_names, modules = [], []
|
||||
current_name = parent_name if parent_name else 'root'
|
||||
module_names.append(current_name)
|
||||
modules.append(model)
|
||||
|
||||
for name, child in model.named_children():
|
||||
if parent_name:
|
||||
child_name = f'{parent_name}.{name}'
|
||||
else:
|
||||
child_name = name
|
||||
child_modules, child_names = torch_dfs(child, child_name)
|
||||
module_names += child_names
|
||||
modules += child_modules
|
||||
return modules, module_names
|
||||
|
||||
|
||||
def rope_precompute(x, grid_sizes, freqs, start=None):
|
||||
b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
|
||||
|
||||
# split freqs
|
||||
if type(freqs) is list:
|
||||
trainable_freqs = freqs[1]
|
||||
freqs = freqs[0]
|
||||
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
||||
|
||||
# loop over samples
|
||||
output = torch.view_as_complex(x.detach().reshape(b, s, n, -1, 2).to(torch.float64))
|
||||
seq_bucket = [0]
|
||||
if not type(grid_sizes) is list:
|
||||
grid_sizes = [grid_sizes]
|
||||
for g in grid_sizes:
|
||||
if not type(g) is list:
|
||||
g = [torch.zeros_like(g), g]
|
||||
batch_size = g[0].shape[0]
|
||||
for i in range(batch_size):
|
||||
if start is None:
|
||||
f_o, h_o, w_o = g[0][i]
|
||||
else:
|
||||
f_o, h_o, w_o = start[i]
|
||||
|
||||
f, h, w = g[1][i]
|
||||
t_f, t_h, t_w = g[2][i]
|
||||
seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
|
||||
seq_len = int(seq_f * seq_h * seq_w)
|
||||
if seq_len > 0:
|
||||
if t_f > 0:
|
||||
factor_f, factor_h, factor_w = (t_f / seq_f).item(), (t_h / seq_h).item(), (t_w / seq_w).item()
|
||||
# Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item())
|
||||
if f_o >= 0:
|
||||
f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f).astype(int).tolist()
|
||||
else:
|
||||
f_sam = np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f).astype(int).tolist()
|
||||
h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h).astype(int).tolist()
|
||||
w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w).astype(int).tolist()
|
||||
|
||||
assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
|
||||
freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj()
|
||||
freqs_0 = freqs_0.view(seq_f, 1, 1, -1)
|
||||
|
||||
freqs_i = torch.cat(
|
||||
[
|
||||
freqs_0.expand(seq_f, seq_h, seq_w, -1),
|
||||
freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1),
|
||||
freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1),
|
||||
],
|
||||
dim=-1
|
||||
).reshape(seq_len, 1, -1)
|
||||
elif t_f < 0:
|
||||
freqs_i = trainable_freqs.unsqueeze(1)
|
||||
# apply rotary embedding
|
||||
output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i
|
||||
seq_bucket.append(seq_bucket[-1] + seq_len)
|
||||
return output
|
||||
|
||||
|
||||
class CausalConv1d(nn.Module):
|
||||
|
||||
def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode='replicate', **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.pad_mode = pad_mode
|
||||
padding = (kernel_size - 1, 0) # T
|
||||
self.time_causal_padding = padding
|
||||
|
||||
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MotionEncoder_tc(nn.Module):
|
||||
|
||||
def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, need_global=True, dtype=None, device=None):
|
||||
factory_kwargs = {"dtype": dtype, "device": device}
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.need_global = need_global
|
||||
self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1)
|
||||
if need_global:
|
||||
self.conv1_global = CausalConv1d(in_dim, hidden_dim // 4, 3, stride=1)
|
||||
self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.act = nn.SiLU()
|
||||
self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
|
||||
self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)
|
||||
|
||||
if need_global:
|
||||
self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs)
|
||||
|
||||
self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
|
||||
|
||||
def forward(self, x):
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x_ori = x.clone()
|
||||
b, c, t = x.shape
|
||||
x = self.conv1_local(x)
|
||||
x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
|
||||
x = self.norm1(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x = self.conv2(x)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm2(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x = self.conv3(x)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm3(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, '(b n) t c -> b t n c', b=b)
|
||||
padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1).to(device=x.device, dtype=x.dtype)
|
||||
x = torch.cat([x, padding], dim=-2)
|
||||
x_local = x.clone()
|
||||
|
||||
if not self.need_global:
|
||||
return x_local
|
||||
|
||||
x = self.conv1_global(x_ori)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm1(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x = self.conv2(x)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm2(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, 'b t c -> b c t')
|
||||
x = self.conv3(x)
|
||||
x = rearrange(x, 'b c t -> b t c')
|
||||
x = self.norm3(x)
|
||||
x = self.act(x)
|
||||
x = self.final_linear(x)
|
||||
x = rearrange(x, '(b n) t c -> b t n c', b=b)
|
||||
|
||||
return x, x_local
|
||||
|
||||
|
||||
class FramePackMotioner(nn.Module):
|
||||
|
||||
def __init__(self, inner_dim=1024, num_heads=16, zip_frame_buckets=[1, 2, 16], drop_mode="drop", *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
|
||||
self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
|
||||
self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
|
||||
self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long)
|
||||
|
||||
self.inner_dim = inner_dim
|
||||
self.num_heads = num_heads
|
||||
self.freqs = torch.cat(precompute_freqs_cis_3d(inner_dim // num_heads), dim=1)
|
||||
self.drop_mode = drop_mode
|
||||
|
||||
def forward(self, motion_latents, add_last_motion=2):
|
||||
motion_frames = motion_latents[0].shape[1]
|
||||
mot = []
|
||||
mot_remb = []
|
||||
for m in motion_latents:
|
||||
lat_height, lat_width = m.shape[2], m.shape[3]
|
||||
padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height, lat_width).to(device=m.device, dtype=m.dtype)
|
||||
overlap_frame = min(padd_lat.shape[1], m.shape[1])
|
||||
if overlap_frame > 0:
|
||||
padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:]
|
||||
|
||||
if add_last_motion < 2 and self.drop_mode != "drop":
|
||||
zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets.__len__() - add_last_motion - 1].sum()
|
||||
padd_lat[:, -zero_end_frame:] = 0
|
||||
|
||||
padd_lat = padd_lat.unsqueeze(0)
|
||||
clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum():, :, :].split(
|
||||
list(self.zip_frame_buckets)[::-1], dim=2
|
||||
) # 16, 2 ,1
|
||||
|
||||
# patchfy
|
||||
clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2)
|
||||
clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2)
|
||||
clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2)
|
||||
|
||||
if add_last_motion < 2 and self.drop_mode == "drop":
|
||||
clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post
|
||||
clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x
|
||||
|
||||
motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
|
||||
|
||||
# rope
|
||||
start_time_id = -(self.zip_frame_buckets[:1].sum())
|
||||
end_time_id = start_time_id + self.zip_frame_buckets[0]
|
||||
grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \
|
||||
[
|
||||
[torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
|
||||
torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
|
||||
torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
|
||||
]
|
||||
|
||||
start_time_id = -(self.zip_frame_buckets[:2].sum())
|
||||
end_time_id = start_time_id + self.zip_frame_buckets[1] // 2
|
||||
grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \
|
||||
[
|
||||
[torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
|
||||
torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1),
|
||||
torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
|
||||
]
|
||||
|
||||
start_time_id = -(self.zip_frame_buckets[:3].sum())
|
||||
end_time_id = start_time_id + self.zip_frame_buckets[2] // 4
|
||||
grid_sizes_4x = [
|
||||
[
|
||||
torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
|
||||
torch.tensor([end_time_id, lat_height // 8, lat_width // 8]).unsqueeze(0).repeat(1, 1),
|
||||
torch.tensor([self.zip_frame_buckets[2], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
|
||||
]
|
||||
]
|
||||
|
||||
grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x
|
||||
|
||||
motion_rope_emb = rope_precompute(
|
||||
motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads, self.inner_dim // self.num_heads),
|
||||
grid_sizes,
|
||||
self.freqs,
|
||||
start=None
|
||||
)
|
||||
|
||||
mot.append(motion_lat)
|
||||
mot_remb.append(motion_rope_emb)
|
||||
return mot, mot_remb
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
output_dim: int,
|
||||
norm_eps: float = 1e-5,
|
||||
):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, elementwise_affine=False)
|
||||
|
||||
def forward(self, x, temb):
|
||||
temb = self.linear(F.silu(temb))
|
||||
shift, scale = temb.chunk(2, dim=1)
|
||||
shift = shift[:, None, :]
|
||||
scale = scale[:, None, :]
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
|
||||
|
||||
|
||||
class AudioInjector_WAN(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
all_modules,
|
||||
all_modules_names,
|
||||
dim=2048,
|
||||
num_heads=32,
|
||||
inject_layer=[0, 27],
|
||||
enable_adain=False,
|
||||
adain_dim=2048,
|
||||
):
|
||||
super().__init__()
|
||||
self.injected_block_id = {}
|
||||
audio_injector_id = 0
|
||||
for mod_name, mod in zip(all_modules_names, all_modules):
|
||||
if isinstance(mod, DiTBlock):
|
||||
for inject_id in inject_layer:
|
||||
if f'transformer_blocks.{inject_id}' in mod_name:
|
||||
self.injected_block_id[inject_id] = audio_injector_id
|
||||
audio_injector_id += 1
|
||||
|
||||
self.injector = nn.ModuleList([CrossAttention(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
) for _ in range(audio_injector_id)])
|
||||
self.injector_pre_norm_feat = nn.ModuleList([nn.LayerNorm(
|
||||
dim,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
) for _ in range(audio_injector_id)])
|
||||
self.injector_pre_norm_vec = nn.ModuleList([nn.LayerNorm(
|
||||
dim,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
) for _ in range(audio_injector_id)])
|
||||
if enable_adain:
|
||||
self.injector_adain_layers = nn.ModuleList([AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim) for _ in range(audio_injector_id)])
|
||||
|
||||
|
||||
class CausalAudioEncoder(nn.Module):
|
||||
|
||||
def __init__(self, dim=5120, num_layers=25, out_dim=2048, num_token=4, need_global=False):
|
||||
super().__init__()
|
||||
self.encoder = MotionEncoder_tc(in_dim=dim, hidden_dim=out_dim, num_heads=num_token, need_global=need_global)
|
||||
weight = torch.ones((1, num_layers, 1, 1)) * 0.01
|
||||
|
||||
self.weights = torch.nn.Parameter(weight)
|
||||
self.act = torch.nn.SiLU()
|
||||
|
||||
def forward(self, features):
|
||||
# features B * num_layers * dim * video_length
|
||||
weights = self.act(self.weights.to(device=features.device, dtype=features.dtype))
|
||||
weights_sum = weights.sum(dim=1, keepdims=True)
|
||||
weighted_feat = ((features * weights) / weights_sum).sum(dim=1) # b dim f
|
||||
weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
|
||||
res = self.encoder(weighted_feat) # b f n dim
|
||||
return res # b f n dim
|
||||
|
||||
|
||||
class WanS2VDiTBlock(DiTBlock):
|
||||
|
||||
def forward(self, x, context, t_mod, seq_len_x, freqs):
|
||||
t_mod = (self.modulation.unsqueeze(2).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
||||
# t_mod[:, :, 0] for x, t_mod[:, :, 1] for other like ref, motion, etc.
|
||||
t_mod = [
|
||||
torch.cat([element[:, :, 0].expand(1, seq_len_x, x.shape[-1]), element[:, :, 1].expand(1, x.shape[1] - seq_len_x, x.shape[-1])], dim=1)
|
||||
for element in t_mod
|
||||
]
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod
|
||||
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
||||
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
|
||||
x = x + self.cross_attn(self.norm3(x), context)
|
||||
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
x = self.gate(x, gate_mlp, self.ffn(input_x))
|
||||
return x
|
||||
|
||||
|
||||
class WanS2VModel(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
in_dim: int,
|
||||
ffn_dim: int,
|
||||
out_dim: int,
|
||||
text_dim: int,
|
||||
freq_dim: int,
|
||||
eps: float,
|
||||
patch_size: Tuple[int, int, int],
|
||||
num_heads: int,
|
||||
num_layers: int,
|
||||
cond_dim: int,
|
||||
audio_dim: int,
|
||||
num_audio_token: int,
|
||||
enable_adain: bool = True,
|
||||
audio_inject_layers: list = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39],
|
||||
zero_timestep: bool = True,
|
||||
add_last_motion: bool = True,
|
||||
framepack_drop_mode: str = "padd",
|
||||
fuse_vae_embedding_in_latents: bool = True,
|
||||
require_vae_embedding: bool = False,
|
||||
seperated_timestep: bool = False,
|
||||
require_clip_embedding: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.in_dim = in_dim
|
||||
self.freq_dim = freq_dim
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.enbale_adain = enable_adain
|
||||
self.add_last_motion = add_last_motion
|
||||
self.zero_timestep = zero_timestep
|
||||
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
|
||||
self.require_vae_embedding = require_vae_embedding
|
||||
self.seperated_timestep = seperated_timestep
|
||||
self.require_clip_embedding = require_clip_embedding
|
||||
|
||||
self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), nn.Linear(dim, dim))
|
||||
self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
||||
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
||||
|
||||
self.blocks = nn.ModuleList([WanS2VDiTBlock(False, dim, num_heads, ffn_dim, eps) for _ in range(num_layers)])
|
||||
self.head = Head(dim, out_dim, patch_size, eps)
|
||||
self.freqs = torch.cat(precompute_freqs_cis_3d(dim // num_heads), dim=1)
|
||||
|
||||
self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain)
|
||||
all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks")
|
||||
self.audio_injector = AudioInjector_WAN(
|
||||
all_modules,
|
||||
all_modules_names,
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
inject_layer=audio_inject_layers,
|
||||
enable_adain=enable_adain,
|
||||
adain_dim=dim,
|
||||
)
|
||||
self.trainable_cond_mask = nn.Embedding(3, dim)
|
||||
self.frame_packer = FramePackMotioner(inner_dim=dim, num_heads=num_heads, zip_frame_buckets=[1, 2, 16], drop_mode=framepack_drop_mode)
|
||||
|
||||
def patchify(self, x: torch.Tensor):
|
||||
grid_size = x.shape[2:]
|
||||
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
||||
return x, grid_size # x, grid_size: (f, h, w)
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
|
||||
return rearrange(
|
||||
x,
|
||||
'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
|
||||
f=grid_size[0],
|
||||
h=grid_size[1],
|
||||
w=grid_size[2],
|
||||
x=self.patch_size[0],
|
||||
y=self.patch_size[1],
|
||||
z=self.patch_size[2]
|
||||
)
|
||||
|
||||
def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, add_last_motion=2):
|
||||
flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion)
|
||||
if drop_motion_frames:
|
||||
return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb]
|
||||
else:
|
||||
return flattern_mot, mot_remb
|
||||
|
||||
def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
|
||||
# inject the motion frames token to the hidden states
|
||||
mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion)
|
||||
if len(mot) > 0:
|
||||
x = torch.cat([x, mot[0]], dim=1)
|
||||
rope_embs = torch.cat([rope_embs, mot_remb[0]], dim=1)
|
||||
mask_input = torch.cat(
|
||||
[mask_input, 2 * torch.ones([1, x.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype)], dim=1
|
||||
)
|
||||
return x, rope_embs, mask_input
|
||||
|
||||
def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len, use_unified_sequence_parallel=False):
|
||||
if block_idx in self.audio_injector.injected_block_id.keys():
|
||||
audio_attn_id = self.audio_injector.injected_block_id[block_idx]
|
||||
num_frames = audio_emb.shape[1]
|
||||
if use_unified_sequence_parallel:
|
||||
from xfuser.core.distributed import get_sp_group
|
||||
hidden_states = get_sp_group().all_gather(hidden_states, dim=1)
|
||||
|
||||
input_hidden_states = hidden_states[:, :original_seq_len].clone() # b (f h w) c
|
||||
input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)
|
||||
|
||||
audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c")
|
||||
adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0])
|
||||
attn_hidden_states = adain_hidden_states
|
||||
|
||||
audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
|
||||
attn_audio_emb = audio_emb
|
||||
residual_out = self.audio_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb)
|
||||
residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
|
||||
hidden_states[:, :original_seq_len] = hidden_states[:, :original_seq_len] + residual_out
|
||||
if use_unified_sequence_parallel:
|
||||
from xfuser.core.distributed import get_sequence_parallel_world_size, get_sequence_parallel_rank
|
||||
hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||
return hidden_states
|
||||
|
||||
def cal_audio_emb(self, audio_input, motion_frames=[73, 19]):
|
||||
audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)
|
||||
audio_emb_global, audio_emb = self.casual_audio_encoder(audio_input)
|
||||
audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
|
||||
merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
|
||||
return audio_emb_global, merged_audio_emb
|
||||
|
||||
def get_grid_sizes(self, grid_size_x, grid_size_ref):
|
||||
f, h, w = grid_size_x
|
||||
rf, rh, rw = grid_size_ref
|
||||
grid_sizes_x = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0)
|
||||
grid_sizes_x = [[torch.zeros_like(grid_sizes_x), grid_sizes_x, grid_sizes_x]]
|
||||
grid_sizes_ref = [[
|
||||
torch.tensor([30, 0, 0]).unsqueeze(0),
|
||||
torch.tensor([31, rh, rw]).unsqueeze(0),
|
||||
torch.tensor([1, rh, rw]).unsqueeze(0),
|
||||
]]
|
||||
return grid_sizes_x + grid_sizes_ref
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents,
|
||||
timestep,
|
||||
context,
|
||||
audio_input,
|
||||
motion_latents,
|
||||
pose_cond,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
use_gradient_checkpointing=False
|
||||
):
|
||||
origin_ref_latents = latents[:, :, 0:1]
|
||||
x = latents[:, :, 1:]
|
||||
|
||||
# context embedding
|
||||
context = self.text_embedding(context)
|
||||
|
||||
# audio encode
|
||||
audio_emb_global, merged_audio_emb = self.cal_audio_emb(audio_input)
|
||||
|
||||
# x and pose_cond
|
||||
pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
|
||||
x, (f, h, w) = self.patchify(self.patch_embedding(x) + self.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120])
|
||||
seq_len_x = x.shape[1]
|
||||
|
||||
# reference image
|
||||
ref_latents, (rf, rh, rw) = self.patchify(self.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120])
|
||||
grid_sizes = self.get_grid_sizes((f, h, w), (rf, rh, rw))
|
||||
x = torch.cat([x, ref_latents], dim=1)
|
||||
# mask
|
||||
mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)
|
||||
# freqs
|
||||
pre_compute_freqs = rope_precompute(
|
||||
x.detach().view(1, x.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None
|
||||
)
|
||||
# motion
|
||||
x, pre_compute_freqs, mask = self.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
|
||||
|
||||
x = x + self.trainable_cond_mask(mask).to(x.dtype)
|
||||
|
||||
# t_mod
|
||||
timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
|
||||
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||
t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x,
|
||||
context,
|
||||
t_mod,
|
||||
seq_len_x,
|
||||
pre_compute_freqs[0],
|
||||
use_reentrant=False,
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
|
||||
x,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x,
|
||||
context,
|
||||
t_mod,
|
||||
seq_len_x,
|
||||
pre_compute_freqs[0],
|
||||
use_reentrant=False,
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
|
||||
x,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
|
||||
x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
|
||||
|
||||
x = x[:, :seq_len_x]
|
||||
x = self.head(x, t[:-1])
|
||||
x = self.unpatchify(x, (f, h, w))
|
||||
# make compatible with wan video
|
||||
x = torch.cat([origin_ref_latents, x], dim=2)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanS2VModelStateDictConverter()
|
||||
|
||||
|
||||
class WanS2VModelStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
config = {}
|
||||
if hash_state_dict_keys(state_dict) == "966cffdcc52f9c46c391768b27637614":
|
||||
config = {
|
||||
"dim": 5120,
|
||||
"in_dim": 16,
|
||||
"ffn_dim": 13824,
|
||||
"out_dim": 16,
|
||||
"text_dim": 4096,
|
||||
"freq_dim": 256,
|
||||
"eps": 1e-06,
|
||||
"patch_size": (1, 2, 2),
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"cond_dim": 16,
|
||||
"audio_dim": 1024,
|
||||
"num_audio_token": 4,
|
||||
}
|
||||
return state_dict, config
|
||||
@@ -1216,7 +1216,6 @@ class WanVideoVAE(nn.Module):
|
||||
|
||||
|
||||
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
|
||||
videos = [video.to("cpu") for video in videos]
|
||||
hidden_states = []
|
||||
for video in videos:
|
||||
@@ -1234,11 +1233,18 @@ class WanVideoVAE(nn.Module):
|
||||
|
||||
|
||||
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
if tiled:
|
||||
video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
|
||||
else:
|
||||
video = self.single_decode(hidden_states, device)
|
||||
return video
|
||||
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
||||
videos = []
|
||||
for hidden_state in hidden_states:
|
||||
hidden_state = hidden_state.unsqueeze(0)
|
||||
if tiled:
|
||||
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
||||
else:
|
||||
video = self.single_decode(hidden_state, device)
|
||||
video = video.squeeze(0)
|
||||
videos.append(video)
|
||||
videos = torch.stack(videos)
|
||||
return videos
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
||||
204
diffsynth/models/wav2vec.py
Normal file
204
diffsynth/models/wav2vec.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None):
|
||||
required_duration = num_sample / target_fps
|
||||
required_origin_frames = int(np.ceil(required_duration * original_fps))
|
||||
if required_duration > total_frames / original_fps:
|
||||
raise ValueError("required_duration must be less than video length")
|
||||
|
||||
if not fixed_start is None and fixed_start >= 0:
|
||||
start_frame = fixed_start
|
||||
else:
|
||||
max_start = total_frames - required_origin_frames
|
||||
if max_start < 0:
|
||||
raise ValueError("video length is too short")
|
||||
start_frame = np.random.randint(0, max_start + 1)
|
||||
start_time = start_frame / original_fps
|
||||
|
||||
end_time = start_time + required_duration
|
||||
time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)
|
||||
|
||||
frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
|
||||
frame_indices = np.clip(frame_indices, 0, total_frames - 1)
|
||||
return frame_indices
|
||||
|
||||
|
||||
def linear_interpolation(features, input_fps, output_fps, output_len=None):
|
||||
"""
|
||||
features: shape=[1, T, 512]
|
||||
input_fps: fps for audio, f_a
|
||||
output_fps: fps for video, f_m
|
||||
output_len: video length
|
||||
"""
|
||||
features = features.transpose(1, 2)
|
||||
seq_len = features.shape[2] / float(input_fps)
|
||||
if output_len is None:
|
||||
output_len = int(seq_len * output_fps)
|
||||
output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear') # [1, 512, output_len]
|
||||
return output_features.transpose(1, 2)
|
||||
|
||||
|
||||
class WanS2VAudioEncoder(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
from transformers import Wav2Vec2ForCTC, Wav2Vec2Config
|
||||
config = {
|
||||
"_name_or_path": "facebook/wav2vec2-large-xlsr-53",
|
||||
"activation_dropout": 0.05,
|
||||
"apply_spec_augment": True,
|
||||
"architectures": ["Wav2Vec2ForCTC"],
|
||||
"attention_dropout": 0.1,
|
||||
"bos_token_id": 1,
|
||||
"conv_bias": True,
|
||||
"conv_dim": [512, 512, 512, 512, 512, 512, 512],
|
||||
"conv_kernel": [10, 3, 3, 3, 3, 2, 2],
|
||||
"conv_stride": [5, 2, 2, 2, 2, 2, 2],
|
||||
"ctc_loss_reduction": "mean",
|
||||
"ctc_zero_infinity": True,
|
||||
"do_stable_layer_norm": True,
|
||||
"eos_token_id": 2,
|
||||
"feat_extract_activation": "gelu",
|
||||
"feat_extract_dropout": 0.0,
|
||||
"feat_extract_norm": "layer",
|
||||
"feat_proj_dropout": 0.05,
|
||||
"final_dropout": 0.0,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout": 0.05,
|
||||
"hidden_size": 1024,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"layerdrop": 0.05,
|
||||
"mask_channel_length": 10,
|
||||
"mask_channel_min_space": 1,
|
||||
"mask_channel_other": 0.0,
|
||||
"mask_channel_prob": 0.0,
|
||||
"mask_channel_selection": "static",
|
||||
"mask_feature_length": 10,
|
||||
"mask_feature_prob": 0.0,
|
||||
"mask_time_length": 10,
|
||||
"mask_time_min_space": 1,
|
||||
"mask_time_other": 0.0,
|
||||
"mask_time_prob": 0.05,
|
||||
"mask_time_selection": "static",
|
||||
"model_type": "wav2vec2",
|
||||
"num_attention_heads": 16,
|
||||
"num_conv_pos_embedding_groups": 16,
|
||||
"num_conv_pos_embeddings": 128,
|
||||
"num_feat_extract_layers": 7,
|
||||
"num_hidden_layers": 24,
|
||||
"pad_token_id": 0,
|
||||
"transformers_version": "4.7.0.dev0",
|
||||
"vocab_size": 33
|
||||
}
|
||||
self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config))
|
||||
self.video_rate = 30
|
||||
|
||||
def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32, device='cpu'):
|
||||
input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(dtype=dtype, device=device)
|
||||
|
||||
# retrieve logits & take argmax
|
||||
res = self.model(input_values, output_hidden_states=True)
|
||||
if return_all_layers:
|
||||
feat = torch.cat(res.hidden_states)
|
||||
else:
|
||||
feat = res.hidden_states[-1]
|
||||
feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate)
|
||||
return feat
|
||||
|
||||
def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2):
|
||||
num_layers, audio_frame_num, audio_dim = audio_embed.shape
|
||||
|
||||
if num_layers > 1:
|
||||
return_all_layers = True
|
||||
else:
|
||||
return_all_layers = False
|
||||
|
||||
min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1
|
||||
|
||||
bucket_num = min_batch_num * batch_frames
|
||||
batch_idx = [stride * i for i in range(bucket_num)]
|
||||
batch_audio_eb = []
|
||||
for bi in batch_idx:
|
||||
if bi < audio_frame_num:
|
||||
audio_sample_stride = 2
|
||||
chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
|
||||
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
|
||||
chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx]
|
||||
|
||||
if return_all_layers:
|
||||
frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)
|
||||
else:
|
||||
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
|
||||
else:
|
||||
frame_audio_embed = \
|
||||
torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
|
||||
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
|
||||
batch_audio_eb.append(frame_audio_embed)
|
||||
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)
|
||||
|
||||
return batch_audio_eb, min_batch_num
|
||||
|
||||
def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0):
|
||||
num_layers, audio_frame_num, audio_dim = audio_embed.shape
|
||||
|
||||
if num_layers > 1:
|
||||
return_all_layers = True
|
||||
else:
|
||||
return_all_layers = False
|
||||
|
||||
scale = self.video_rate / fps
|
||||
|
||||
min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
|
||||
|
||||
bucket_num = min_batch_num * batch_frames
|
||||
padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num
|
||||
batch_idx = get_sample_indices(
|
||||
original_fps=self.video_rate, total_frames=audio_frame_num + padd_audio_num, target_fps=fps, num_sample=bucket_num, fixed_start=0
|
||||
)
|
||||
batch_audio_eb = []
|
||||
audio_sample_stride = int(self.video_rate / fps)
|
||||
for bi in batch_idx:
|
||||
if bi < audio_frame_num:
|
||||
|
||||
chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
|
||||
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
|
||||
chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx]
|
||||
|
||||
if return_all_layers:
|
||||
frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)
|
||||
else:
|
||||
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
|
||||
else:
|
||||
frame_audio_embed = \
|
||||
torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
|
||||
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
|
||||
batch_audio_eb.append(frame_audio_embed)
|
||||
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)
|
||||
|
||||
return batch_audio_eb, min_batch_num
|
||||
|
||||
def get_audio_feats_per_inference(self, input_audio, sample_rate, processor, fps=16, batch_frames=80, m=0, dtype=torch.float32, device='cpu'):
|
||||
audio_feat = self.extract_audio_feat(input_audio, sample_rate, processor, return_all_layers=True, dtype=dtype, device=device)
|
||||
audio_embed_bucket, min_batch_num = self.get_audio_embed_bucket_fps(audio_feat, fps=fps, batch_frames=batch_frames, m=m)
|
||||
audio_embed_bucket = audio_embed_bucket.unsqueeze(0).permute(0, 2, 3, 1).to(device, dtype)
|
||||
audio_embeds = [audio_embed_bucket[..., i * batch_frames:(i + 1) * batch_frames] for i in range(min_batch_num)]
|
||||
return audio_embeds
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanS2VAudioEncoderStateDictConverter()
|
||||
|
||||
|
||||
class WanS2VAudioEncoderStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = {'model.' + k: v for k, v in state_dict.items()}
|
||||
return state_dict
|
||||
@@ -52,7 +52,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
from transformers import Qwen2Tokenizer
|
||||
from transformers import Qwen2Tokenizer, Qwen2VLProcessor
|
||||
|
||||
self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
|
||||
self.text_encoder: QwenImageTextEncoder = None
|
||||
@@ -60,6 +60,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
self.vae: QwenImageVAE = None
|
||||
self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None
|
||||
self.tokenizer: Qwen2Tokenizer = None
|
||||
self.processor: Qwen2VLProcessor = None
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.in_iteration_models = ("dit", "blockwise_controlnet")
|
||||
self.units = [
|
||||
@@ -67,6 +68,8 @@ class QwenImagePipeline(BasePipeline):
|
||||
QwenImageUnit_NoiseInitializer(),
|
||||
QwenImageUnit_InputImageEmbedder(),
|
||||
QwenImageUnit_Inpaint(),
|
||||
QwenImageUnit_EditImageEmbedder(),
|
||||
QwenImageUnit_ContextImageEmbedder(),
|
||||
QwenImageUnit_PromptEmbedder(),
|
||||
QwenImageUnit_EntityControl(),
|
||||
QwenImageUnit_BlockwiseControlNet(),
|
||||
@@ -74,18 +77,72 @@ class QwenImagePipeline(BasePipeline):
|
||||
self.model_fn = model_fn_qwen_image
|
||||
|
||||
|
||||
def load_lora(self, module, path, alpha=1):
|
||||
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
|
||||
lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
loader.load(module, lora, alpha=alpha)
|
||||
def load_lora(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
lora_config: Union[ModelConfig, str] = None,
|
||||
alpha=1,
|
||||
hotload=False,
|
||||
state_dict=None,
|
||||
):
|
||||
if state_dict is None:
|
||||
if isinstance(lora_config, str):
|
||||
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
lora_config.download_if_necessary()
|
||||
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
lora = state_dict
|
||||
if hotload:
|
||||
for name, module in module.named_modules():
|
||||
if isinstance(module, AutoWrappedLinear):
|
||||
lora_a_name = f'{name}.lora_A.default.weight'
|
||||
lora_b_name = f'{name}.lora_B.default.weight'
|
||||
if lora_a_name in lora and lora_b_name in lora:
|
||||
module.lora_A_weights.append(lora[lora_a_name] * alpha)
|
||||
module.lora_B_weights.append(lora[lora_b_name])
|
||||
else:
|
||||
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
|
||||
loader.load(module, lora, alpha=alpha)
|
||||
|
||||
|
||||
def clear_lora(self):
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, AutoWrappedLinear):
|
||||
if hasattr(module, "lora_A_weights"):
|
||||
module.lora_A_weights.clear()
|
||||
if hasattr(module, "lora_B_weights"):
|
||||
module.lora_B_weights.clear()
|
||||
|
||||
|
||||
def enable_lora_magic(self):
|
||||
if self.dit is not None:
|
||||
if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled):
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device=self.device,
|
||||
onload_dtype=dtype,
|
||||
onload_device=self.device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=None,
|
||||
)
|
||||
|
||||
|
||||
def training_loss(self, **inputs):
|
||||
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
||||
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
|
||||
training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
|
||||
noise = torch.randn_like(inputs["input_latents"])
|
||||
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||
training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||
|
||||
noise_pred = self.model_fn(**inputs, timestep=timestep)
|
||||
|
||||
@@ -94,6 +151,49 @@ class QwenImagePipeline(BasePipeline):
|
||||
return loss
|
||||
|
||||
|
||||
def direct_distill_loss(self, **inputs):
|
||||
self.scheduler.set_timesteps(inputs["num_inference_steps"])
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(self.scheduler.timesteps):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
|
||||
inputs["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
|
||||
loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
|
||||
return loss
|
||||
|
||||
|
||||
def _enable_fp8_lora_training(self, dtype):
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding
|
||||
from ..models.qwen_image_dit import RMSNorm
|
||||
from ..models.qwen_image_vae import QwenImageRMS_norm
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
|
||||
Qwen2RMSNorm: AutoWrappedModule,
|
||||
Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
|
||||
Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
|
||||
QwenImageRMS_norm: AutoWrappedModule,
|
||||
}
|
||||
model_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cuda",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cuda",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device="cuda",
|
||||
)
|
||||
if self.text_encoder is not None:
|
||||
enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
|
||||
if self.dit is not None:
|
||||
enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
|
||||
if self.vae is not None:
|
||||
enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
|
||||
self.vram_management_enabled = True
|
||||
if vram_limit is None:
|
||||
@@ -101,7 +201,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
vram_limit = vram_limit - vram_buffer
|
||||
|
||||
if self.text_encoder is not None:
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding
|
||||
dtype = next(iter(self.text_encoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder,
|
||||
@@ -110,6 +210,8 @@ class QwenImagePipeline(BasePipeline):
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
|
||||
Qwen2RMSNorm: AutoWrappedModule,
|
||||
Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
|
||||
Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -219,6 +321,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
processor_config: ModelConfig = None,
|
||||
):
|
||||
# Download and load models
|
||||
model_manager = ModelManager()
|
||||
@@ -240,6 +343,10 @@ class QwenImagePipeline(BasePipeline):
|
||||
tokenizer_config.download_if_necessary()
|
||||
from transformers import Qwen2Tokenizer
|
||||
pipe.tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_config.path)
|
||||
if processor_config is not None:
|
||||
processor_config.download_if_necessary()
|
||||
from transformers import Qwen2VLProcessor
|
||||
pipe.processor = Qwen2VLProcessor.from_pretrained(processor_config.path)
|
||||
return pipe
|
||||
|
||||
|
||||
@@ -265,12 +372,19 @@ class QwenImagePipeline(BasePipeline):
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 30,
|
||||
exponential_shift_mu: float = None,
|
||||
# Blockwise ControlNet
|
||||
blockwise_controlnet_inputs: list[ControlNetInput] = None,
|
||||
# EliGen
|
||||
eligen_entity_prompts: list[str] = None,
|
||||
eligen_entity_masks: list[Image.Image] = None,
|
||||
eligen_enable_on_negative: bool = False,
|
||||
# Qwen-Image-Edit
|
||||
edit_image: Image.Image = None,
|
||||
edit_image_auto_resize: bool = True,
|
||||
edit_rope_interpolation: bool = False,
|
||||
# In-context control
|
||||
context_image: Image.Image = None,
|
||||
# FP8
|
||||
enable_fp8_attention: bool = False,
|
||||
# Tile
|
||||
@@ -279,10 +393,9 @@ class QwenImagePipeline(BasePipeline):
|
||||
tile_stride: int = 64,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
extra_prompt_emb = None,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16))
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {
|
||||
@@ -302,12 +415,11 @@ class QwenImagePipeline(BasePipeline):
|
||||
"blockwise_controlnet_inputs": blockwise_controlnet_inputs,
|
||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
|
||||
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation,
|
||||
"context_image": context_image,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
if extra_prompt_emb is not None:
|
||||
inputs_posi["prompt_emb"] = torch.concat([inputs_posi["prompt_emb"], extra_prompt_emb], dim=1)
|
||||
inputs_posi["prompt_emb_mask"] = torch.ones((1, inputs_posi["prompt_emb"].shape[1]), dtype=inputs_posi["prompt_emb_mask"].dtype, device=inputs_posi["prompt_emb_mask"].device)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
@@ -395,13 +507,13 @@ class QwenImageUnit_Inpaint(PipelineUnit):
|
||||
return {"inpaint_mask": inpaint_mask}
|
||||
|
||||
|
||||
|
||||
class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
input_params=("edit_image",),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
@@ -412,18 +524,35 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||
return split_result
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, prompt) -> dict:
|
||||
if pipe.text_encoder is not None:
|
||||
def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
|
||||
if pipe.text_encoder is not None and prompt is not None:
|
||||
prompt = [prompt]
|
||||
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 34
|
||||
# If edit_image is None, use the default template for Qwen-Image, otherwise use the template for Qwen-Image-Edit
|
||||
if edit_image is None:
|
||||
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 34
|
||||
else:
|
||||
template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 64
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
|
||||
if txt_tokens.input_ids.shape[1] >= 1024:
|
||||
print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {txt_tokens['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.")
|
||||
hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1]
|
||||
|
||||
split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
||||
|
||||
# Qwen-Image-Edit model
|
||||
if pipe.processor is not None:
|
||||
model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
|
||||
# Qwen-Image model
|
||||
elif pipe.tokenizer is not None:
|
||||
model_inputs = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
|
||||
if model_inputs.input_ids.shape[1] >= 1024:
|
||||
print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {model_inputs['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.")
|
||||
else:
|
||||
assert False, "QwenImagePipeline requires either tokenizer or processor to be loaded."
|
||||
|
||||
if 'pixel_values' in model_inputs:
|
||||
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
|
||||
else:
|
||||
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, output_hidden_states=True,)[-1]
|
||||
|
||||
split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
@@ -557,6 +686,53 @@ class QwenImageUnit_BlockwiseControlNet(PipelineUnit):
|
||||
return {"blockwise_controlnet_conditioning": conditionings}
|
||||
|
||||
|
||||
class QwenImageUnit_EditImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("edit_image", "tiled", "tile_size", "tile_stride", "edit_image_auto_resize"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
|
||||
def calculate_dimensions(self, target_area, ratio):
|
||||
import math
|
||||
width = math.sqrt(target_area * ratio)
|
||||
height = width / ratio
|
||||
width = round(width / 32) * 32
|
||||
height = round(height / 32) * 32
|
||||
return width, height
|
||||
|
||||
|
||||
def edit_image_auto_resize(self, edit_image):
|
||||
calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
|
||||
return edit_image.resize((calculated_width, calculated_height))
|
||||
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, edit_image, tiled, tile_size, tile_stride, edit_image_auto_resize=False):
|
||||
if edit_image is None:
|
||||
return {}
|
||||
resized_edit_image = self.edit_image_auto_resize(edit_image) if edit_image_auto_resize else edit_image
|
||||
pipe.load_models_to_device(['vae'])
|
||||
edit_image = pipe.preprocess_image(resized_edit_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return {"edit_latents": edit_latents, "edit_image": resized_edit_image}
|
||||
|
||||
class QwenImageUnit_ContextImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride):
|
||||
if context_image is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
context_image = pipe.preprocess_image(context_image.resize((width, height))).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
context_latents = pipe.vae.encode(context_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return {"context_latents": context_latents}
|
||||
|
||||
|
||||
def model_fn_qwen_image(
|
||||
dit: QwenImageDiT = None,
|
||||
blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None,
|
||||
@@ -573,9 +749,12 @@ def model_fn_qwen_image(
|
||||
entity_prompt_emb=None,
|
||||
entity_prompt_emb_mask=None,
|
||||
entity_masks=None,
|
||||
edit_latents=None,
|
||||
context_latents=None,
|
||||
enable_fp8_attention=False,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
edit_rope_interpolation=False,
|
||||
**kwargs
|
||||
):
|
||||
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
|
||||
@@ -583,7 +762,17 @@ def model_fn_qwen_image(
|
||||
timestep = timestep / 1000
|
||||
|
||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
|
||||
image_seq_len = image.shape[1]
|
||||
|
||||
if context_latents is not None:
|
||||
img_shapes += [(context_latents.shape[0], context_latents.shape[2]//2, context_latents.shape[3]//2)]
|
||||
context_image = rearrange(context_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=context_latents.shape[2]//2, W=context_latents.shape[3]//2, P=2, Q=2)
|
||||
image = torch.cat([image, context_image], dim=1)
|
||||
if edit_latents is not None:
|
||||
img_shapes += [(edit_latents.shape[0], edit_latents.shape[2]//2, edit_latents.shape[3]//2)]
|
||||
edit_image = rearrange(edit_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=edit_latents.shape[2]//2, W=edit_latents.shape[3]//2, P=2, Q=2)
|
||||
image = torch.cat([image, edit_image], dim=1)
|
||||
|
||||
image = dit.img_in(image)
|
||||
conditioning = dit.time_text_embed(timestep, image.dtype)
|
||||
|
||||
@@ -594,7 +783,10 @@ def model_fn_qwen_image(
|
||||
)
|
||||
else:
|
||||
text = dit.txt_in(dit.txt_norm(prompt_emb))
|
||||
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
if edit_rope_interpolation:
|
||||
image_rotary_emb = dit.pos_embed.forward_sampling(img_shapes, txt_seq_lens, device=latents.device)
|
||||
else:
|
||||
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
attention_mask = None
|
||||
|
||||
if blockwise_controlnet_conditioning is not None:
|
||||
@@ -614,14 +806,17 @@ def model_fn_qwen_image(
|
||||
enable_fp8_attention=enable_fp8_attention,
|
||||
)
|
||||
if blockwise_controlnet_conditioning is not None:
|
||||
image = image + blockwise_controlnet.blockwise_forward(
|
||||
image=image, conditionings=blockwise_controlnet_conditioning,
|
||||
image_slice = image[:, :image_seq_len].clone()
|
||||
controlnet_output = blockwise_controlnet.blockwise_forward(
|
||||
image=image_slice, conditionings=blockwise_controlnet_conditioning,
|
||||
controlnet_inputs=blockwise_controlnet_inputs, block_id=block_id,
|
||||
progress_id=progress_id, num_inference_steps=num_inference_steps,
|
||||
)
|
||||
image[:, :image_seq_len] = image_slice + controlnet_output
|
||||
|
||||
image = dit.norm_out(image, conditioning)
|
||||
image = dit.proj_out(image)
|
||||
image = image[:, :image_seq_len]
|
||||
|
||||
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
return latents
|
||||
|
||||
@@ -15,6 +15,7 @@ from typing_extensions import Literal
|
||||
from ..utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
|
||||
from ..models import ModelManager, load_state_dict
|
||||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_dit_s2v import rope_precompute
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
|
||||
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
@@ -49,8 +50,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.units = [
|
||||
WanVideoUnit_ShapeChecker(),
|
||||
WanVideoUnit_NoiseInitializer(),
|
||||
WanVideoUnit_InputVideoEmbedder(),
|
||||
WanVideoUnit_PromptEmbedder(),
|
||||
WanVideoUnit_S2V(),
|
||||
WanVideoUnit_InputVideoEmbedder(),
|
||||
WanVideoUnit_ImageEmbedderVAE(),
|
||||
WanVideoUnit_ImageEmbedderCLIP(),
|
||||
WanVideoUnit_ImageEmbedderFused(),
|
||||
@@ -63,6 +65,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
WanVideoUnit_TeaCache(),
|
||||
WanVideoUnit_CfgMerger(),
|
||||
]
|
||||
self.post_units = [
|
||||
WanVideoPostUnit_S2V(),
|
||||
]
|
||||
self.model_fn = model_fn_wan_video
|
||||
|
||||
|
||||
@@ -127,6 +132,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
torch.nn.LayerNorm: WanAutoCastLayerNorm,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.Conv1d: AutoWrappedModule,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -254,6 +261,25 @@ class WanVideoPipeline(BasePipeline):
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.audio_encoder is not None:
|
||||
# TODO: need check
|
||||
dtype = next(iter(self.audio_encoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.audio_encoder,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
torch.nn.Conv1d: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def initialize_usp(self):
|
||||
@@ -290,6 +316,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
|
||||
audio_processor_config: ModelConfig = None,
|
||||
redirect_common_files: bool = True,
|
||||
use_usp=False,
|
||||
):
|
||||
@@ -332,7 +359,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||
pipe.vace = model_manager.fetch_model("wan_video_vace")
|
||||
|
||||
pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder")
|
||||
|
||||
# Size division factor
|
||||
if pipe.vae is not None:
|
||||
pipe.height_division_factor = pipe.vae.upsampling_factor * 2
|
||||
@@ -342,7 +370,11 @@ class WanVideoPipeline(BasePipeline):
|
||||
tokenizer_config.download_if_necessary(use_usp=use_usp)
|
||||
pipe.prompter.fetch_models(pipe.text_encoder)
|
||||
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
|
||||
|
||||
|
||||
if audio_processor_config is not None:
|
||||
audio_processor_config.download_if_necessary(use_usp=use_usp)
|
||||
from transformers import Wav2Vec2Processor
|
||||
pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path)
|
||||
# Unified Sequence Parallel
|
||||
if use_usp: pipe.enable_usp()
|
||||
return pipe
|
||||
@@ -361,6 +393,13 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Video-to-video
|
||||
input_video: Optional[list[Image.Image]] = None,
|
||||
denoising_strength: Optional[float] = 1.0,
|
||||
# Speech-to-video
|
||||
input_audio: Optional[np.array] = None,
|
||||
audio_embeds: Optional[torch.Tensor] = None,
|
||||
audio_sample_rate: Optional[int] = 16000,
|
||||
s2v_pose_video: Optional[list[Image.Image]] = None,
|
||||
s2v_pose_latents: Optional[torch.Tensor] = None,
|
||||
motion_video: Optional[list[Image.Image]] = None,
|
||||
# ControlNet
|
||||
control_video: Optional[list[Image.Image]] = None,
|
||||
reference_image: Optional[Image.Image] = None,
|
||||
@@ -429,6 +468,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
"motion_bucket_id": motion_bucket_id,
|
||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
|
||||
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -464,7 +504,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
# VACE (TODO: remove it)
|
||||
if vace_reference_image is not None:
|
||||
inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
|
||||
|
||||
# post-denoising, pre-decoding processing logic
|
||||
for unit in self.post_units:
|
||||
inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
@@ -663,22 +705,23 @@ class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
|
||||
class WanVideoUnit_FunControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
|
||||
input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
|
||||
def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents):
|
||||
if control_video is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
control_video = pipe.preprocess_video(control_video)
|
||||
control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
y_dim = pipe.dit.in_dim-control_latents.shape[1]-latents.shape[1]
|
||||
if clip_feature is None or y is None:
|
||||
clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
|
||||
y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
|
||||
y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
|
||||
else:
|
||||
y = y[:, -16:]
|
||||
y = y[:, -y_dim:]
|
||||
y = torch.concat([control_latents, y], dim=1)
|
||||
return {"clip_feature": clip_feature, "y": y}
|
||||
|
||||
@@ -698,6 +741,8 @@ class WanVideoUnit_FunReference(PipelineUnit):
|
||||
reference_image = reference_image.resize((width, height))
|
||||
reference_latents = pipe.preprocess_video([reference_image])
|
||||
reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
|
||||
if pipe.image_encoder is None:
|
||||
return {"reference_latents": reference_latents}
|
||||
clip_feature = pipe.preprocess_image(reference_image)
|
||||
clip_feature = pipe.image_encoder.encode_image([clip_feature])
|
||||
return {"reference_latents": reference_latents, "clip_feature": clip_feature}
|
||||
@@ -707,13 +752,14 @@ class WanVideoUnit_FunReference(PipelineUnit):
|
||||
class WanVideoUnit_FunCameraControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image"),
|
||||
input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image):
|
||||
def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride):
|
||||
if camera_control_direction is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates(
|
||||
camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin)
|
||||
|
||||
@@ -728,14 +774,27 @@ class WanVideoUnit_FunCameraControl(PipelineUnit):
|
||||
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
|
||||
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
|
||||
control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
|
||||
|
||||
input_image = input_image.resize((width, height))
|
||||
input_latents = pipe.preprocess_video([input_image])
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_latents = pipe.vae.encode(input_latents, device=pipe.device)
|
||||
y = torch.zeros_like(latents).to(pipe.device)
|
||||
y[:, :, :1] = input_latents
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
if y.shape[1] != pipe.dit.in_dim - latents.shape[1]:
|
||||
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
|
||||
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
||||
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
|
||||
msk[:, 1:] = 0
|
||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
y = torch.cat([msk,y])
|
||||
y = y.unsqueeze(0)
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"control_camera_latents_input": control_camera_latents_input, "y": y}
|
||||
|
||||
|
||||
@@ -851,6 +910,98 @@ class WanVideoUnit_CfgMerger(PipelineUnit):
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class WanVideoUnit_S2V(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
onload_model_names=("audio_encoder", "vae",)
|
||||
)
|
||||
|
||||
def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False):
|
||||
if audio_embeds is not None:
|
||||
return {"audio_embeds": audio_embeds}
|
||||
pipe.load_models_to_device(["audio_encoder"])
|
||||
audio_embeds = pipe.audio_encoder.get_audio_feats_per_inference(input_audio, audio_sample_rate, pipe.audio_processor, fps=fps, batch_frames=num_frames-1, dtype=pipe.torch_dtype, device=pipe.device)
|
||||
if return_all:
|
||||
return audio_embeds
|
||||
else:
|
||||
return {"audio_embeds": audio_embeds[0]}
|
||||
|
||||
def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride, motion_video=None):
|
||||
pipe.load_models_to_device(["vae"])
|
||||
motion_frames = 73
|
||||
kwargs = {}
|
||||
if motion_video is not None and len(motion_video) > 0:
|
||||
assert len(motion_video) == motion_frames, f"motion video must have {motion_frames} frames, but got {len(motion_video)}"
|
||||
motion_latents = pipe.preprocess_video(motion_video)
|
||||
kwargs["drop_motion_frames"] = False
|
||||
else:
|
||||
motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
|
||||
kwargs["drop_motion_frames"] = True
|
||||
motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
kwargs.update({"motion_latents": motion_latents})
|
||||
return kwargs
|
||||
|
||||
def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=None, num_repeats=1, return_all=False):
|
||||
if s2v_pose_latents is not None:
|
||||
return {"s2v_pose_latents": s2v_pose_latents}
|
||||
if s2v_pose_video is None:
|
||||
return {"s2v_pose_latents": None}
|
||||
pipe.load_models_to_device(["vae"])
|
||||
infer_frames = num_frames - 1
|
||||
input_video = pipe.preprocess_video(s2v_pose_video)[:, :, :infer_frames * num_repeats]
|
||||
# pad if not enough frames
|
||||
padding_frames = infer_frames * num_repeats - input_video.shape[2]
|
||||
input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
|
||||
input_videos = input_video.chunk(num_repeats, dim=2)
|
||||
pose_conds = []
|
||||
for r in range(num_repeats):
|
||||
cond = input_videos[r]
|
||||
cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2)
|
||||
cond_latents = pipe.vae.encode(cond, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
pose_conds.append(cond_latents[:,:,1:])
|
||||
if return_all:
|
||||
return pose_conds
|
||||
else:
|
||||
return {"s2v_pose_latents": pose_conds[0]}
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
if (inputs_shared.get("input_audio") is None and inputs_shared.get("audio_embeds") is None) or pipe.audio_encoder is None or pipe.audio_processor is None:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
|
||||
input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio"), inputs_shared.pop("audio_embeds"), inputs_shared.get("audio_sample_rate")
|
||||
s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video"), inputs_shared.pop("s2v_pose_latents"), inputs_shared.pop("motion_video")
|
||||
|
||||
audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds)
|
||||
inputs_posi.update(audio_input_positive)
|
||||
inputs_nega.update({"audio_embeds": 0.0 * audio_input_positive["audio_embeds"]})
|
||||
|
||||
inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride, motion_video))
|
||||
inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=s2v_pose_latents))
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
@staticmethod
|
||||
def pre_calculate_audio_pose(pipe: WanVideoPipeline, input_audio=None, audio_sample_rate=16000, s2v_pose_video=None, num_frames=81, height=448, width=832, fps=16, tiled=True, tile_size=(30, 52), tile_stride=(15, 26)):
|
||||
assert pipe.audio_encoder is not None and pipe.audio_processor is not None, "Please load audio encoder and audio processor first."
|
||||
shapes = WanVideoUnit_ShapeChecker().process(pipe, height, width, num_frames)
|
||||
height, width, num_frames = shapes["height"], shapes["width"], shapes["num_frames"]
|
||||
unit = WanVideoUnit_S2V()
|
||||
audio_embeds = unit.process_audio(pipe, input_audio, audio_sample_rate, num_frames, fps, return_all=True)
|
||||
pose_latents = unit.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, num_repeats=len(audio_embeds), return_all=True, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
pose_latents = None if s2v_pose_video is None else pose_latents
|
||||
return audio_embeds, pose_latents, len(audio_embeds)
|
||||
|
||||
|
||||
class WanVideoPostUnit_S2V(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames"))
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_frames):
|
||||
if pipe.audio_encoder is None or motion_latents is None or drop_motion_frames:
|
||||
return {}
|
||||
latents = torch.cat([motion_latents, latents[:,:,1:]], dim=2)
|
||||
return {"latents": latents}
|
||||
|
||||
|
||||
class TeaCache:
|
||||
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
||||
@@ -970,6 +1121,10 @@ def model_fn_wan_video(
|
||||
reference_latents = None,
|
||||
vace_context = None,
|
||||
vace_scale = 1.0,
|
||||
audio_embeds: Optional[torch.Tensor] = None,
|
||||
motion_latents: Optional[torch.Tensor] = None,
|
||||
s2v_pose_latents: Optional[torch.Tensor] = None,
|
||||
drop_motion_frames: bool = True,
|
||||
tea_cache: TeaCache = None,
|
||||
use_unified_sequence_parallel: bool = False,
|
||||
motion_bucket_id: Optional[torch.Tensor] = None,
|
||||
@@ -1007,7 +1162,22 @@ def model_fn_wan_video(
|
||||
tensor_names=["latents", "y"],
|
||||
batch_size=2 if cfg_merge else 1
|
||||
)
|
||||
|
||||
# wan2.2 s2v
|
||||
if audio_embeds is not None:
|
||||
return model_fn_wans2v(
|
||||
dit=dit,
|
||||
latents=latents,
|
||||
timestep=timestep,
|
||||
context=context,
|
||||
audio_embeds=audio_embeds,
|
||||
motion_latents=motion_latents,
|
||||
s2v_pose_latents=s2v_pose_latents,
|
||||
drop_motion_frames=drop_motion_frames,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_unified_sequence_parallel=use_unified_sequence_parallel,
|
||||
)
|
||||
|
||||
if use_unified_sequence_parallel:
|
||||
import torch.distributed as dist
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
@@ -1048,7 +1218,7 @@ def model_fn_wan_video(
|
||||
if clip_feature is not None and dit.require_clip_embedding:
|
||||
clip_embdding = dit.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
|
||||
|
||||
# Add camera control
|
||||
x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
|
||||
|
||||
@@ -1126,3 +1296,105 @@ def model_fn_wan_video(
|
||||
f -= 1
|
||||
x = dit.unpatchify(x, (f, h, w))
|
||||
return x
|
||||
|
||||
|
||||
def model_fn_wans2v(
|
||||
dit,
|
||||
latents,
|
||||
timestep,
|
||||
context,
|
||||
audio_embeds,
|
||||
motion_latents,
|
||||
s2v_pose_latents,
|
||||
drop_motion_frames=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
use_gradient_checkpointing=False,
|
||||
use_unified_sequence_parallel=False,
|
||||
):
|
||||
if use_unified_sequence_parallel:
|
||||
import torch.distributed as dist
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
get_sequence_parallel_world_size,
|
||||
get_sp_group)
|
||||
origin_ref_latents = latents[:, :, 0:1]
|
||||
x = latents[:, :, 1:]
|
||||
|
||||
# context embedding
|
||||
context = dit.text_embedding(context)
|
||||
|
||||
# audio encode
|
||||
audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_embeds)
|
||||
|
||||
# x and s2v_pose_latents
|
||||
s2v_pose_latents = torch.zeros_like(x) if s2v_pose_latents is None else s2v_pose_latents
|
||||
x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents))
|
||||
seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel
|
||||
|
||||
# reference image
|
||||
ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents))
|
||||
grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw))
|
||||
x = torch.cat([x, ref_latents], dim=1)
|
||||
# mask
|
||||
mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)
|
||||
# freqs
|
||||
pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None)
|
||||
# motion
|
||||
x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=2)
|
||||
|
||||
x = x + dit.trainable_cond_mask(mask).to(x.dtype)
|
||||
|
||||
# tmod
|
||||
timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2)
|
||||
|
||||
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
||||
world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank()
|
||||
assert x.shape[1] % world_size == 0, f"the dimension after chunk must be divisible by world size, but got {x.shape[1]} and {get_sequence_parallel_world_size()}"
|
||||
x = torch.chunk(x, world_size, dim=1)[sp_rank]
|
||||
seg_idxs = [0] + list(torch.cumsum(torch.tensor([x.shape[1]] * world_size), dim=0).cpu().numpy())
|
||||
seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)]
|
||||
seq_len_x = seq_len_x_list[sp_rank]
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
|
||||
use_reentrant=False,
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
|
||||
x,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
|
||||
use_reentrant=False,
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
|
||||
x,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
|
||||
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel)
|
||||
|
||||
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = get_sp_group().all_gather(x, dim=1)
|
||||
|
||||
x = x[:, :seq_len_x_global]
|
||||
x = dit.head(x, t[:-1])
|
||||
x = dit.unpatchify(x, (f, h, w))
|
||||
# make compatible with wan video
|
||||
x = torch.cat([origin_ref_latents, x], dim=2)
|
||||
return x
|
||||
|
||||
@@ -31,7 +31,7 @@ class FlowMatchScheduler():
|
||||
self.set_timesteps(num_inference_steps)
|
||||
|
||||
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None):
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None, exponential_shift_mu=None):
|
||||
if shift is not None:
|
||||
self.shift = shift
|
||||
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
||||
@@ -42,7 +42,12 @@ class FlowMatchScheduler():
|
||||
if self.inverse_timesteps:
|
||||
self.sigmas = torch.flip(self.sigmas, dims=[0])
|
||||
if self.exponential_shift:
|
||||
mu = self.calculate_shift(dynamic_shift_len) if dynamic_shift_len is not None else self.exponential_shift_mu
|
||||
if exponential_shift_mu is not None:
|
||||
mu = exponential_shift_mu
|
||||
elif dynamic_shift_len is not None:
|
||||
mu = self.calculate_shift(dynamic_shift_len)
|
||||
else:
|
||||
mu = self.exponential_shift_mu
|
||||
self.sigmas = math.exp(mu) / (math.exp(mu) + (1 / self.sigmas - 1))
|
||||
else:
|
||||
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
||||
|
||||
334
diffsynth/trainers/unified_dataset.py
Normal file
334
diffsynth/trainers/unified_dataset.py
Normal file
@@ -0,0 +1,334 @@
|
||||
import torch, torchvision, imageio, os, json, pandas
|
||||
import imageio.v3 as iio
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
class DataProcessingPipeline:
|
||||
def __init__(self, operators=None):
|
||||
self.operators: list[DataProcessingOperator] = [] if operators is None else operators
|
||||
|
||||
def __call__(self, data):
|
||||
for operator in self.operators:
|
||||
data = operator(data)
|
||||
return data
|
||||
|
||||
def __rshift__(self, pipe):
|
||||
if isinstance(pipe, DataProcessingOperator):
|
||||
pipe = DataProcessingPipeline([pipe])
|
||||
return DataProcessingPipeline(self.operators + pipe.operators)
|
||||
|
||||
|
||||
|
||||
class DataProcessingOperator:
|
||||
def __call__(self, data):
|
||||
raise NotImplementedError("DataProcessingOperator cannot be called directly.")
|
||||
|
||||
def __rshift__(self, pipe):
|
||||
if isinstance(pipe, DataProcessingOperator):
|
||||
pipe = DataProcessingPipeline([pipe])
|
||||
return DataProcessingPipeline([self]).__rshift__(pipe)
|
||||
|
||||
|
||||
|
||||
class DataProcessingOperatorRaw(DataProcessingOperator):
|
||||
def __call__(self, data):
|
||||
return data
|
||||
|
||||
|
||||
|
||||
class ToInt(DataProcessingOperator):
|
||||
def __call__(self, data):
|
||||
return int(data)
|
||||
|
||||
|
||||
|
||||
class ToFloat(DataProcessingOperator):
|
||||
def __call__(self, data):
|
||||
return float(data)
|
||||
|
||||
|
||||
|
||||
class ToStr(DataProcessingOperator):
|
||||
def __init__(self, none_value=""):
|
||||
self.none_value = none_value
|
||||
|
||||
def __call__(self, data):
|
||||
if data is None: data = self.none_value
|
||||
return str(data)
|
||||
|
||||
|
||||
|
||||
class LoadImage(DataProcessingOperator):
|
||||
def __init__(self, convert_RGB=True):
|
||||
self.convert_RGB = convert_RGB
|
||||
|
||||
def __call__(self, data: str):
|
||||
image = Image.open(data)
|
||||
if self.convert_RGB: image = image.convert("RGB")
|
||||
return image
|
||||
|
||||
|
||||
|
||||
class ImageCropAndResize(DataProcessingOperator):
|
||||
def __init__(self, height, width, max_pixels, height_division_factor, width_division_factor):
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.max_pixels = max_pixels
|
||||
self.height_division_factor = height_division_factor
|
||||
self.width_division_factor = width_division_factor
|
||||
|
||||
def crop_and_resize(self, image, target_height, target_width):
|
||||
width, height = image.size
|
||||
scale = max(target_width / width, target_height / height)
|
||||
image = torchvision.transforms.functional.resize(
|
||||
image,
|
||||
(round(height*scale), round(width*scale)),
|
||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||
)
|
||||
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
|
||||
return image
|
||||
|
||||
def get_height_width(self, image):
|
||||
if self.height is None or self.width is None:
|
||||
width, height = image.size
|
||||
if width * height > self.max_pixels:
|
||||
scale = (width * height / self.max_pixels) ** 0.5
|
||||
height, width = int(height / scale), int(width / scale)
|
||||
height = height // self.height_division_factor * self.height_division_factor
|
||||
width = width // self.width_division_factor * self.width_division_factor
|
||||
else:
|
||||
height, width = self.height, self.width
|
||||
return height, width
|
||||
|
||||
|
||||
def __call__(self, data: Image.Image):
|
||||
image = self.crop_and_resize(data, *self.get_height_width(data))
|
||||
return image
|
||||
|
||||
|
||||
|
||||
class ToList(DataProcessingOperator):
|
||||
def __call__(self, data):
|
||||
return [data]
|
||||
|
||||
|
||||
|
||||
class LoadVideo(DataProcessingOperator):
|
||||
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
|
||||
self.num_frames = num_frames
|
||||
self.time_division_factor = time_division_factor
|
||||
self.time_division_remainder = time_division_remainder
|
||||
# frame_processor is build in the video loader for high efficiency.
|
||||
self.frame_processor = frame_processor
|
||||
|
||||
def get_num_frames(self, reader):
|
||||
num_frames = self.num_frames
|
||||
if int(reader.count_frames()) < num_frames:
|
||||
num_frames = int(reader.count_frames())
|
||||
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
|
||||
num_frames -= 1
|
||||
return num_frames
|
||||
|
||||
def __call__(self, data: str):
|
||||
reader = imageio.get_reader(data)
|
||||
num_frames = self.get_num_frames(reader)
|
||||
frames = []
|
||||
for frame_id in range(num_frames):
|
||||
frame = reader.get_data(frame_id)
|
||||
frame = Image.fromarray(frame)
|
||||
frame = self.frame_processor(frame)
|
||||
frames.append(frame)
|
||||
reader.close()
|
||||
return frames
|
||||
|
||||
|
||||
|
||||
class SequencialProcess(DataProcessingOperator):
|
||||
def __init__(self, operator=lambda x: x):
|
||||
self.operator = operator
|
||||
|
||||
def __call__(self, data):
|
||||
return [self.operator(i) for i in data]
|
||||
|
||||
|
||||
|
||||
class LoadGIF(DataProcessingOperator):
|
||||
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
|
||||
self.num_frames = num_frames
|
||||
self.time_division_factor = time_division_factor
|
||||
self.time_division_remainder = time_division_remainder
|
||||
# frame_processor is build in the video loader for high efficiency.
|
||||
self.frame_processor = frame_processor
|
||||
|
||||
def get_num_frames(self, path):
|
||||
num_frames = self.num_frames
|
||||
images = iio.imread(path, mode="RGB")
|
||||
if len(images) < num_frames:
|
||||
num_frames = len(images)
|
||||
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
|
||||
num_frames -= 1
|
||||
return num_frames
|
||||
|
||||
def __call__(self, data: str):
|
||||
num_frames = self.get_num_frames(data)
|
||||
frames = []
|
||||
images = iio.imread(data, mode="RGB")
|
||||
for img in images:
|
||||
frame = Image.fromarray(img)
|
||||
frame = self.frame_processor(frame)
|
||||
frames.append(frame)
|
||||
if len(frames) >= num_frames:
|
||||
break
|
||||
return frames
|
||||
|
||||
|
||||
|
||||
class RouteByExtensionName(DataProcessingOperator):
|
||||
def __init__(self, operator_map):
|
||||
self.operator_map = operator_map
|
||||
|
||||
def __call__(self, data: str):
|
||||
file_ext_name = data.split(".")[-1].lower()
|
||||
for ext_names, operator in self.operator_map:
|
||||
if ext_names is None or file_ext_name in ext_names:
|
||||
return operator(data)
|
||||
raise ValueError(f"Unsupported file: {data}")
|
||||
|
||||
|
||||
|
||||
class RouteByType(DataProcessingOperator):
|
||||
def __init__(self, operator_map):
|
||||
self.operator_map = operator_map
|
||||
|
||||
def __call__(self, data):
|
||||
for dtype, operator in self.operator_map:
|
||||
if dtype is None or isinstance(data, dtype):
|
||||
return operator(data)
|
||||
raise ValueError(f"Unsupported data: {data}")
|
||||
|
||||
|
||||
|
||||
class LoadTorchPickle(DataProcessingOperator):
|
||||
def __init__(self, map_location="cpu"):
|
||||
self.map_location = map_location
|
||||
|
||||
def __call__(self, data):
|
||||
return torch.load(data, map_location=self.map_location, weights_only=False)
|
||||
|
||||
|
||||
|
||||
class ToAbsolutePath(DataProcessingOperator):
|
||||
def __init__(self, base_path=""):
|
||||
self.base_path = base_path
|
||||
|
||||
def __call__(self, data):
|
||||
return os.path.join(self.base_path, data)
|
||||
|
||||
|
||||
|
||||
class UnifiedDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
base_path=None, metadata_path=None,
|
||||
repeat=1,
|
||||
data_file_keys=tuple(),
|
||||
main_data_operator=lambda x: x,
|
||||
special_operator_map=None,
|
||||
):
|
||||
self.base_path = base_path
|
||||
self.metadata_path = metadata_path
|
||||
self.repeat = repeat
|
||||
self.data_file_keys = data_file_keys
|
||||
self.main_data_operator = main_data_operator
|
||||
self.cached_data_operator = LoadTorchPickle()
|
||||
self.special_operator_map = {} if special_operator_map is None else special_operator_map
|
||||
self.data = []
|
||||
self.cached_data = []
|
||||
self.load_from_cache = metadata_path is None
|
||||
self.load_metadata(metadata_path)
|
||||
|
||||
@staticmethod
|
||||
def default_image_operator(
|
||||
base_path="",
|
||||
max_pixels=1920*1080, height=None, width=None,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
):
|
||||
return RouteByType(operator_map=[
|
||||
(str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
|
||||
(list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def default_video_operator(
|
||||
base_path="",
|
||||
max_pixels=1920*1080, height=None, width=None,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
num_frames=81, time_division_factor=4, time_division_remainder=1,
|
||||
):
|
||||
return RouteByType(operator_map=[
|
||||
(str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
|
||||
(("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
|
||||
(("gif",), LoadGIF(num_frames, time_division_factor, time_division_remainder) >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
|
||||
(("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
|
||||
num_frames, time_division_factor, time_division_remainder,
|
||||
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
|
||||
)),
|
||||
])),
|
||||
])
|
||||
|
||||
def search_for_cached_data_files(self, path):
|
||||
for file_name in os.listdir(path):
|
||||
subpath = os.path.join(path, file_name)
|
||||
if os.path.isdir(subpath):
|
||||
self.search_for_cached_data_files(subpath)
|
||||
elif subpath.endswith(".pth"):
|
||||
self.cached_data.append(subpath)
|
||||
|
||||
def load_metadata(self, metadata_path):
|
||||
if metadata_path is None:
|
||||
print("No metadata_path. Searching for cached data files.")
|
||||
self.search_for_cached_data_files(self.base_path)
|
||||
print(f"{len(self.cached_data)} cached data files found.")
|
||||
elif metadata_path.endswith(".json"):
|
||||
with open(metadata_path, "r") as f:
|
||||
metadata = json.load(f)
|
||||
self.data = metadata
|
||||
elif metadata_path.endswith(".jsonl"):
|
||||
metadata = []
|
||||
with open(metadata_path, 'r') as f:
|
||||
for line in f:
|
||||
metadata.append(json.loads(line.strip()))
|
||||
self.data = metadata
|
||||
else:
|
||||
metadata = pandas.read_csv(metadata_path)
|
||||
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||
|
||||
def __getitem__(self, data_id):
|
||||
if self.load_from_cache:
|
||||
data = self.cached_data[data_id % len(self.cached_data)]
|
||||
data = self.cached_data_operator(data)
|
||||
else:
|
||||
data = self.data[data_id % len(self.data)].copy()
|
||||
for key in self.data_file_keys:
|
||||
if key in data:
|
||||
if key in self.special_operator_map:
|
||||
data[key] = self.special_operator_map[key]
|
||||
elif key in self.data_file_keys:
|
||||
data[key] = self.main_data_operator(data[key])
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
if self.load_from_cache:
|
||||
return len(self.cached_data) * self.repeat
|
||||
else:
|
||||
return len(self.data) * self.repeat
|
||||
|
||||
def check_data_equal(self, data1, data2):
|
||||
# Debug only
|
||||
if len(data1) != len(data2):
|
||||
return False
|
||||
for k in data1:
|
||||
if data1[k] != data2[k]:
|
||||
return False
|
||||
return True
|
||||
@@ -1,4 +1,6 @@
|
||||
import imageio, os, torch, warnings, torchvision, argparse, json
|
||||
from ..utils import ModelConfig
|
||||
from ..models.utils import load_state_dict
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
@@ -154,7 +156,7 @@ class VideoDataset(torch.utils.data.Dataset):
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
data_file_keys=("video",),
|
||||
image_file_extension=("jpg", "jpeg", "png", "webp"),
|
||||
video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"),
|
||||
video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm", "gif"),
|
||||
repeat=1,
|
||||
args=None,
|
||||
):
|
||||
@@ -259,8 +261,53 @@ class VideoDataset(torch.utils.data.Dataset):
|
||||
num_frames -= 1
|
||||
return num_frames
|
||||
|
||||
|
||||
def _load_gif(self, file_path):
|
||||
gif_img = Image.open(file_path)
|
||||
frame_count = 0
|
||||
delays, frames = [], []
|
||||
while True:
|
||||
delay = gif_img.info.get('duration', 100) # ms
|
||||
delays.append(delay)
|
||||
rgb_frame = gif_img.convert("RGB")
|
||||
croped_frame = self.crop_and_resize(rgb_frame, *self.get_height_width(rgb_frame))
|
||||
frames.append(croped_frame)
|
||||
frame_count += 1
|
||||
try:
|
||||
gif_img.seek(frame_count)
|
||||
except:
|
||||
break
|
||||
# delays canbe used to calculate framerates
|
||||
# i guess it is better to sample images with stable interval,
|
||||
# and using minimal_interval as the interval,
|
||||
# and framerate = 1000 / minimal_interval
|
||||
if any((delays[0] != i) for i in delays):
|
||||
minimal_interval = min([i for i in delays if i > 0])
|
||||
# make a ((start,end),frameid) struct
|
||||
start_end_idx_map = [((sum(delays[:i]), sum(delays[:i+1])), i) for i in range(len(delays))]
|
||||
_frames = []
|
||||
# according gemini-code-assist, make it more efficient to locate
|
||||
# where to sample the frame
|
||||
last_match = 0
|
||||
for i in range(sum(delays) // minimal_interval):
|
||||
current_time = minimal_interval * i
|
||||
for idx, ((start, end), frame_idx) in enumerate(start_end_idx_map[last_match:]):
|
||||
if start <= current_time < end:
|
||||
_frames.append(frames[frame_idx])
|
||||
last_match = idx + last_match
|
||||
break
|
||||
frames = _frames
|
||||
num_frames = len(frames)
|
||||
if num_frames > self.num_frames:
|
||||
num_frames = self.num_frames
|
||||
else:
|
||||
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
|
||||
num_frames -= 1
|
||||
frames = frames[:num_frames]
|
||||
return frames
|
||||
|
||||
def load_video(self, file_path):
|
||||
if file_path.lower().endswith(".gif"):
|
||||
return self._load_gif(file_path)
|
||||
reader = imageio.get_reader(file_path)
|
||||
num_frames = self.get_num_frames(reader)
|
||||
frames = []
|
||||
@@ -338,13 +385,26 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
return trainable_param_names
|
||||
|
||||
|
||||
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
|
||||
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
|
||||
if lora_alpha is None:
|
||||
lora_alpha = lora_rank
|
||||
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
|
||||
model = inject_adapter_in_model(lora_config, model)
|
||||
if upcast_dtype is not None:
|
||||
for param in model.parameters():
|
||||
if param.requires_grad:
|
||||
param.data = param.to(upcast_dtype)
|
||||
return model
|
||||
|
||||
def disable_all_lora_layers(self, model):
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, 'enable_adapters'):
|
||||
module.enable_adapters(False)
|
||||
|
||||
def enable_all_lora_layers(self, model):
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, 'enable_adapters'):
|
||||
module.enable_adapters(True)
|
||||
|
||||
def mapping_lora_state_dict(self, state_dict):
|
||||
new_state_dict = {}
|
||||
@@ -352,6 +412,8 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
if "lora_A.weight" in key or "lora_B.weight" in key:
|
||||
new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
|
||||
new_state_dict[new_key] = value
|
||||
elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
|
||||
new_state_dict[key] = value
|
||||
return new_state_dict
|
||||
|
||||
|
||||
@@ -366,7 +428,62 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
state_dict_[name] = param
|
||||
state_dict = state_dict_
|
||||
return state_dict
|
||||
|
||||
|
||||
|
||||
def transfer_data_to_device(self, data, device, torch_float_dtype=None):
|
||||
for key in data:
|
||||
if isinstance(data[key], torch.Tensor):
|
||||
data[key] = data[key].to(device)
|
||||
if torch_float_dtype is not None and data[key].dtype in [torch.float, torch.float16, torch.bfloat16]:
|
||||
data[key] = data[key].to(torch_float_dtype)
|
||||
return data
|
||||
|
||||
|
||||
def parse_model_configs(self, model_paths, model_id_with_origin_paths, enable_fp8_training=False):
|
||||
offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
|
||||
model_configs = []
|
||||
if model_paths is not None:
|
||||
model_paths = json.loads(model_paths)
|
||||
model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
|
||||
if model_id_with_origin_paths is not None:
|
||||
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
|
||||
return model_configs
|
||||
|
||||
|
||||
def switch_pipe_to_training_mode(
|
||||
self,
|
||||
pipe,
|
||||
trainable_models,
|
||||
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=None,
|
||||
enable_fp8_training=False,
|
||||
):
|
||||
# Scheduler
|
||||
pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
# Freeze untrainable models
|
||||
pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
||||
|
||||
# Enable FP8 if pipeline supports
|
||||
if enable_fp8_training and hasattr(pipe, "_enable_fp8_lora_training"):
|
||||
pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
|
||||
|
||||
# Add LoRA to the base models
|
||||
if lora_base_model is not None:
|
||||
model = self.add_lora_to_model(
|
||||
getattr(pipe, lora_base_model),
|
||||
target_modules=lora_target_modules.split(","),
|
||||
lora_rank=lora_rank,
|
||||
upcast_dtype=pipe.torch_dtype,
|
||||
)
|
||||
if lora_checkpoint is not None:
|
||||
state_dict = load_state_dict(lora_checkpoint)
|
||||
state_dict = self.mapping_lora_state_dict(state_dict)
|
||||
load_result = model.load_state_dict(state_dict, strict=False)
|
||||
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
|
||||
if len(load_result[1]) > 0:
|
||||
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
||||
setattr(pipe, lora_base_model, model)
|
||||
|
||||
|
||||
class ModelLogger:
|
||||
@@ -414,14 +531,26 @@ def launch_training_task(
|
||||
dataset: torch.utils.data.Dataset,
|
||||
model: DiffusionTrainingModule,
|
||||
model_logger: ModelLogger,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
||||
learning_rate: float = 1e-5,
|
||||
weight_decay: float = 1e-2,
|
||||
num_workers: int = 8,
|
||||
save_steps: int = None,
|
||||
num_epochs: int = 1,
|
||||
gradient_accumulation_steps: int = 1,
|
||||
find_unused_parameters: bool = False,
|
||||
args = None,
|
||||
):
|
||||
if args is not None:
|
||||
learning_rate = args.learning_rate
|
||||
weight_decay = args.weight_decay
|
||||
num_workers = args.dataset_num_workers
|
||||
save_steps = args.save_steps
|
||||
num_epochs = args.num_epochs
|
||||
gradient_accumulation_steps = args.gradient_accumulation_steps
|
||||
find_unused_parameters = args.find_unused_parameters
|
||||
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
@@ -433,7 +562,10 @@ def launch_training_task(
|
||||
for data in tqdm(dataloader):
|
||||
with accelerator.accumulate(model):
|
||||
optimizer.zero_grad()
|
||||
loss = model(data)
|
||||
if dataset.load_from_cache:
|
||||
loss = model({}, inputs=data, accelerator=accelerator)
|
||||
else:
|
||||
loss = model(data, accelerator=accelerator)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
model_logger.on_step_end(accelerator, model, save_steps)
|
||||
@@ -443,16 +575,28 @@ def launch_training_task(
|
||||
model_logger.on_training_end(accelerator, model, save_steps)
|
||||
|
||||
|
||||
def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0])
|
||||
def launch_data_process_task(
|
||||
dataset: torch.utils.data.Dataset,
|
||||
model: DiffusionTrainingModule,
|
||||
model_logger: ModelLogger,
|
||||
num_workers: int = 8,
|
||||
args = None,
|
||||
):
|
||||
if args is not None:
|
||||
num_workers = args.dataset_num_workers
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||
accelerator = Accelerator()
|
||||
model, dataloader = accelerator.prepare(model, dataloader)
|
||||
os.makedirs(os.path.join(output_path, "data_cache"), exist_ok=True)
|
||||
for data_id, data in enumerate(tqdm(dataloader)):
|
||||
with torch.no_grad():
|
||||
inputs = model.forward_preprocess(data)
|
||||
inputs = {key: inputs[key] for key in model.model_input_keys if key in inputs}
|
||||
torch.save(inputs, os.path.join(output_path, "data_cache", f"{data_id}.pth"))
|
||||
|
||||
for data_id, data in tqdm(enumerate(dataloader)):
|
||||
with accelerator.accumulate(model):
|
||||
with torch.no_grad():
|
||||
folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
|
||||
data = model(data, return_inputs=True)
|
||||
torch.save(data, save_path)
|
||||
|
||||
|
||||
|
||||
@@ -552,4 +696,8 @@ def qwen_image_parser():
|
||||
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
|
||||
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
|
||||
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
|
||||
parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
|
||||
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
|
||||
parser.add_argument("--beta_dpo", type=float, default=1000, help="hyperparameter beta for DPO loss, only used when task is dpo.")
|
||||
return parser
|
||||
|
||||
@@ -153,6 +153,19 @@ class BasePipeline(torch.nn.Module):
|
||||
latents_next = scheduler.step(noise_pred, timestep, latents)
|
||||
return latents_next
|
||||
|
||||
def sample_timestep(self):
|
||||
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
||||
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
|
||||
return timestep
|
||||
|
||||
def training_loss_minimum(self, noise, timestep, **inputs):
|
||||
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||
training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||
noise_pred = self.model_fn(**inputs, timestep=timestep)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * self.scheduler.training_weight(timestep)
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import torch, os, json
|
||||
from diffsynth import load_state_dict
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, flux_parser
|
||||
from diffsynth.models.lora import FluxLoRAConverter
|
||||
from diffsynth.trainers.unified_dataset import UnifiedDataset
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@@ -19,36 +20,16 @@ class FluxTrainingModule(DiffusionTrainingModule):
|
||||
):
|
||||
super().__init__()
|
||||
# Load models
|
||||
model_configs = []
|
||||
if model_paths is not None:
|
||||
model_paths = json.loads(model_paths)
|
||||
model_configs += [ModelConfig(path=path) for path in model_paths]
|
||||
if model_id_with_origin_paths is not None:
|
||||
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
|
||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
|
||||
self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
|
||||
|
||||
# Reset training scheduler
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
# Training mode
|
||||
self.switch_pipe_to_training_mode(
|
||||
self.pipe, trainable_models,
|
||||
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
|
||||
enable_fp8_training=False,
|
||||
)
|
||||
|
||||
# Freeze untrainable models
|
||||
self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
||||
|
||||
# Add LoRA to the base models
|
||||
if lora_base_model is not None:
|
||||
model = self.add_lora_to_model(
|
||||
getattr(self.pipe, lora_base_model),
|
||||
target_modules=lora_target_modules.split(","),
|
||||
lora_rank=lora_rank
|
||||
)
|
||||
if lora_checkpoint is not None:
|
||||
state_dict = load_state_dict(lora_checkpoint)
|
||||
state_dict = self.mapping_lora_state_dict(state_dict)
|
||||
load_result = model.load_state_dict(state_dict, strict=False)
|
||||
if len(load_result[1]) > 0:
|
||||
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
||||
setattr(self.pipe, lora_base_model, model)
|
||||
|
||||
# Store other configs
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
@@ -94,7 +75,7 @@ class FluxTrainingModule(DiffusionTrainingModule):
|
||||
return {**inputs_shared, **inputs_posi}
|
||||
|
||||
|
||||
def forward(self, data, inputs=None):
|
||||
def forward(self, data, inputs=None, **kwargs):
|
||||
if inputs is None: inputs = self.forward_preprocess(data)
|
||||
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
||||
loss = self.pipe.training_loss(**models, **inputs)
|
||||
@@ -105,7 +86,20 @@ class FluxTrainingModule(DiffusionTrainingModule):
|
||||
if __name__ == "__main__":
|
||||
parser = flux_parser()
|
||||
args = parser.parse_args()
|
||||
dataset = ImageDataset(args=args)
|
||||
dataset = UnifiedDataset(
|
||||
base_path=args.dataset_base_path,
|
||||
metadata_path=args.dataset_metadata_path,
|
||||
repeat=args.dataset_repeat,
|
||||
data_file_keys=args.data_file_keys.split(","),
|
||||
main_data_operator=UnifiedDataset.default_image_operator(
|
||||
base_path=args.dataset_base_path,
|
||||
max_pixels=args.max_pixels,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
height_division_factor=16,
|
||||
width_division_factor=16,
|
||||
)
|
||||
)
|
||||
model = FluxTrainingModule(
|
||||
model_paths=args.model_paths,
|
||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||
@@ -123,13 +117,4 @@ if __name__ == "__main__":
|
||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
|
||||
)
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||
launch_training_task(
|
||||
dataset, model, model_logger, optimizer, scheduler,
|
||||
num_epochs=args.num_epochs,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
save_steps=args.save_steps,
|
||||
find_unused_parameters=args.find_unused_parameters,
|
||||
num_workers=args.dataset_num_workers,
|
||||
)
|
||||
launch_training_task(dataset, model, model_logger, args=args)
|
||||
|
||||
@@ -20,9 +20,9 @@ Run the following code to quickly load the [Qwen/Qwen-Image](https://www.modelsc
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
@@ -34,7 +34,10 @@ pipe = QwenImagePipeline.from_pretrained(
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "A detailed portrait of a girl underwater, wearing a blue flowing dress, hair gently floating, clear light and shadow, surrounded by bubbles, calm expression, fine details, dreamy and beautiful."
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image = pipe(
|
||||
prompt, seed=0, num_inference_steps=40,
|
||||
# edit_image=Image.open("xxx.jpg").resize((1328, 1328)) # For Qwen-Image-Edit
|
||||
)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
@@ -43,12 +46,16 @@ image.save("image.jpg")
|
||||
|Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_inference_low_vram/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|
||||
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./model_inference/Qwen-Image-Edit.py)|[code](./model_inference_low_vram/Qwen-Image-Edit.py)|[code](./model_training/full/Qwen-Image-Edit.sh)|[code](./model_training/validate_full/Qwen-Image-Edit.py)|[code](./model_training/lora/Qwen-Image-Edit.sh)|[code](./model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./model_inference/Qwen-Image-Distill-LoRA.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|-|-|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./model_inference/Qwen-Image-Distill-LoRA.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](./model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](./model_inference/Qwen-Image-EliGen-V2.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](./model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](./model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](./model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](./model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](./model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](./model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|
||||
|
||||
## Model Inference
|
||||
|
||||
@@ -236,6 +243,7 @@ The script includes the following parameters:
|
||||
* `--model_paths`: Model paths to load. In JSON format.
|
||||
* `--model_id_with_origin_paths`: Model ID with original paths, e.g., Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors. Separate with commas.
|
||||
* `--tokenizer_path`: Tokenizer path. Leave empty to auto-download.
|
||||
* `--processor_path`: Path to the processor of Qwen-Image-Edit. Leave empty to auto-download.
|
||||
* Training
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--weight_decay`: Weight decay.
|
||||
|
||||
@@ -20,9 +20,9 @@ pip install -e .
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
@@ -34,7 +34,10 @@ pipe = QwenImagePipeline.from_pretrained(
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image = pipe(
|
||||
prompt, seed=0, num_inference_steps=40,
|
||||
# edit_image=Image.open("xxx.jpg").resize((1328, 1328)) # For Qwen-Image-Edit
|
||||
)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
@@ -43,12 +46,16 @@ image.save("image.jpg")
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_inference_low_vram/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|
||||
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./model_inference/Qwen-Image-Edit.py)|[code](./model_inference_low_vram/Qwen-Image-Edit.py)|[code](./model_training/full/Qwen-Image-Edit.sh)|[code](./model_training/validate_full/Qwen-Image-Edit.py)|[code](./model_training/lora/Qwen-Image-Edit.sh)|[code](./model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./model_inference/Qwen-Image-Distill-LoRA.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|-|-|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./model_inference/Qwen-Image-Distill-LoRA.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](./model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](./model_inference/Qwen-Image-EliGen-V2.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](./model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](./model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](./model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](./model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](./model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](./model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|
||||
|
||||
## 模型推理
|
||||
|
||||
@@ -236,6 +243,7 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod
|
||||
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors。用逗号分隔。
|
||||
* `--tokenizer_path`: tokenizer 路径,留空将会自动下载。
|
||||
* `--processor_path`:Qwen-Image-Edit 的 processor 路径。留空则自动下载。
|
||||
* 训练
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--weight_decay`:权重衰减大小。
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, load_state_dict
|
||||
from modelscope import snapshot_download
|
||||
import torch, math
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
|
||||
snapshot_download("MusePublic/Qwen-Image-Distill", allow_file_pattern="qwen_image_distill_3step.safetensors", cache_dir="models")
|
||||
lora_state_dict = load_state_dict("models/MusePublic/Qwen-Image-Distill/qwen_image_distill_3step.safetensors")
|
||||
lora_state_dict = {i.replace("base_model.model.", ""): j for i, j in lora_state_dict.items()}
|
||||
pipe.load_lora(pipe.dit, state_dict=lora_state_dict)
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=3, cfg_scale=1, exponential_shift_mu=math.log(2.5))
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,26 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
from modelscope import snapshot_download
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=None,
|
||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
||||
)
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix", local_dir="models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix", allow_file_pattern="model.safetensors")
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix/model.safetensors")
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1024, width=768)
|
||||
image.save("image.jpg")
|
||||
|
||||
prompt = "将裙子变成粉色"
|
||||
image = image.resize((512, 384))
|
||||
image = pipe(prompt, edit_image=image, seed=1, num_inference_steps=40, height=1024, width=768, edit_rope_interpolation=True, edit_image_auto_resize=False)
|
||||
image.save(f"image2.jpg")
|
||||
26
examples/qwen_image/model_inference/Qwen-Image-Edit.py
Normal file
26
examples/qwen_image/model_inference/Qwen-Image-Edit.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=None,
|
||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
input_image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1328, width=1024)
|
||||
input_image.save("image1.jpg")
|
||||
|
||||
prompt = "将裙子改为粉色"
|
||||
# edit_image_auto_resize=True: auto resize input image to match the area of 1024*1024 with the original aspect ratio
|
||||
image = pipe(prompt, edit_image=input_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=True)
|
||||
image.save(f"image2.jpg")
|
||||
|
||||
# edit_image_auto_resize=False: do not resize input image
|
||||
image = pipe(prompt, edit_image=input_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=False)
|
||||
image.save(f"image3.jpg")
|
||||
106
examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py
Normal file
106
examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import torch
|
||||
import random
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from modelscope import dataset_snapshot_download, snapshot_download
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
|
||||
# Save or display the resulting image
|
||||
result.save(output_path)
|
||||
|
||||
return result
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/example_{example_id}/*.png")
|
||||
masks = [Image.open(f"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB').resize((1024, 1024)) for i in range(len(entity_prompts))]
|
||||
negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴"
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=4.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=40,
|
||||
seed=seed,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"eligen_example_{example_id}_{seed}.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen-V2", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen-V2", allow_file_pattern="model.safetensors")
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen-V2/model.safetensors")
|
||||
|
||||
seeds = [0]
|
||||
|
||||
global_prompt = "写实摄影风格. A beautiful asia woman wearing white dress, she is holding a mirror with her right arm, with a beach background."
|
||||
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
|
||||
example(pipe, seeds, 7, global_prompt, entity_prompts)
|
||||
|
||||
global_prompt = "写实摄影风格, 细节丰富。街头一位漂亮的女孩,穿着衬衫和短裤,手持写有“实体控制”的标牌,背景是繁忙的城市街道,阳光明媚,行人匆匆。"
|
||||
entity_prompts = ["一个漂亮的女孩", "标牌 '实体控制'", "短裤", "衬衫"]
|
||||
example(pipe, seeds, 4, global_prompt, entity_prompts)
|
||||
@@ -0,0 +1,35 @@
|
||||
from PIL import Image
|
||||
import torch
|
||||
from modelscope import dataset_snapshot_download, snapshot_download
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth.controlnets.processors import Annotator
|
||||
|
||||
allow_file_pattern = ["sk_model.pth", "sk_model2.pth", "dpt_hybrid-midas-501f0c75.pt", "ControlNetHED.pth", "body_pose_model.pth", "hand_pose_model.pth", "facenet.pth", "scannet.pt"]
|
||||
snapshot_download("lllyasviel/Annotators", local_dir="models/Annotators", allow_file_pattern=allow_file_pattern)
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-In-Context-Control-Union", local_dir="models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union", allow_file_pattern="model.safetensors")
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union/model.safetensors")
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/qwen-image-context-control/image.jpg")
|
||||
origin_image = Image.open("data/examples/qwen-image-context-control/image.jpg").resize((1024, 1024))
|
||||
annotator_ids = ['openpose', 'canny', 'depth', 'lineart', 'softedge', 'normal']
|
||||
for annotator_id in annotator_ids:
|
||||
annotator = Annotator(processor_id=annotator_id, device="cuda")
|
||||
control_image = annotator(origin_image)
|
||||
control_image.save(f"{annotator.processor_id}.png")
|
||||
|
||||
control_prompt = "Context_Control. "
|
||||
prompt = f"{control_prompt}一个穿着淡蓝色的漂亮女孩正在翩翩起舞,背景是梦幻的星空,光影交错,细节精致。"
|
||||
negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴"
|
||||
image = pipe(prompt, seed=1, negative_prompt=negative_prompt, context_image=control_image, height=1024, width=1024)
|
||||
image.save(f"image_{annotator.processor_id}.png")
|
||||
@@ -0,0 +1,28 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
from modelscope import snapshot_download
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=None,
|
||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix", local_dir="models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix", allow_file_pattern="model.safetensors")
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix/model.safetensors")
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1024, width=768)
|
||||
image.save("image.jpg")
|
||||
|
||||
prompt = "将裙子变成粉色"
|
||||
image = image.resize((512, 384))
|
||||
image = pipe(prompt, edit_image=image, seed=1, num_inference_steps=40, height=1024, width=768, edit_rope_interpolation=True, edit_image_auto_resize=False)
|
||||
image.save(f"image2.jpg")
|
||||
@@ -0,0 +1,23 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=None,
|
||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1024, width=1024)
|
||||
image.save("image1.jpg")
|
||||
|
||||
prompt = "将裙子改为粉色"
|
||||
image = pipe(prompt, edit_image=image, seed=1, num_inference_steps=40, height=1024, width=1024)
|
||||
image.save(f"image2.jpg")
|
||||
@@ -0,0 +1,108 @@
|
||||
import torch
|
||||
import random
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from modelscope import dataset_snapshot_download, snapshot_download
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
|
||||
# Save or display the resulting image
|
||||
result.save(output_path)
|
||||
|
||||
return result
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/example_{example_id}/*.png")
|
||||
masks = [Image.open(f"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB').resize((1024, 1024)) for i in range(len(entity_prompts))]
|
||||
negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴"
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=4.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=40,
|
||||
seed=seed,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"eligen_example_{example_id}_{seed}.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen-V2", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen-V2", allow_file_pattern="model.safetensors")
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen-V2/model.safetensors")
|
||||
|
||||
seeds = [0]
|
||||
|
||||
global_prompt = "写实摄影风格. A beautiful asia woman wearing white dress, she is holding a mirror with her right arm, with a beach background."
|
||||
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
|
||||
example(pipe, seeds, 7, global_prompt, entity_prompts)
|
||||
|
||||
global_prompt = "写实摄影风格, 细节丰富。街头一位漂亮的女孩,穿着衬衫和短裤,手持写有“实体控制”的标牌,背景是繁忙的城市街道,阳光明媚,行人匆匆。"
|
||||
entity_prompts = ["一个漂亮的女孩", "标牌 '实体控制'", "短裤", "衬衫"]
|
||||
example(pipe, seeds, 4, global_prompt, entity_prompts)
|
||||
@@ -0,0 +1,36 @@
|
||||
from PIL import Image
|
||||
import torch
|
||||
from modelscope import dataset_snapshot_download, snapshot_download
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth.controlnets.processors import Annotator
|
||||
|
||||
allow_file_pattern = ["sk_model.pth", "sk_model2.pth", "dpt_hybrid-midas-501f0c75.pt", "ControlNetHED.pth", "body_pose_model.pth", "hand_pose_model.pth", "facenet.pth", "scannet.pt"]
|
||||
snapshot_download("lllyasviel/Annotators", local_dir="models/Annotators", allow_file_pattern=allow_file_pattern)
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-In-Context-Control-Union", local_dir="models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union", allow_file_pattern="model.safetensors")
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union/model.safetensors")
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/qwen-image-context-control/image.jpg")
|
||||
origin_image = Image.open("data/examples/qwen-image-context-control/image.jpg").resize((1024, 1024))
|
||||
annotator_ids = ['openpose', 'canny', 'depth', 'lineart', 'softedge', 'normal']
|
||||
for annotator_id in annotator_ids:
|
||||
annotator = Annotator(processor_id=annotator_id, device="cuda")
|
||||
control_image = annotator(origin_image)
|
||||
control_image.save(f"{annotator.processor_id}.png")
|
||||
|
||||
control_prompt = "Context_Control. "
|
||||
prompt = f"{control_prompt}一个穿着淡蓝色的漂亮女孩正在翩翩起舞,背景是梦幻的星空,光影交错,细节精致。"
|
||||
negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴"
|
||||
image = pipe(prompt, seed=1, negative_prompt=negative_prompt, context_image=control_image, height=1024, width=1024)
|
||||
image.save(f"image_{annotator.processor_id}.png")
|
||||
15
examples/qwen_image/model_training/full/Qwen-Image-Edit.sh
Normal file
15
examples/qwen_image/model_training/full/Qwen-Image-Edit.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_edit.csv \
|
||||
--data_file_keys "image,edit_image" \
|
||||
--extra_inputs "edit_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Edit_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
25
examples/qwen_image/model_training/lora/Qwen-Image-DPO.sh
Normal file
25
examples/qwen_image/model_training/lora/Qwen-Image-DPO.sh
Normal file
@@ -0,0 +1,25 @@
|
||||
# dataset format:
|
||||
# {
|
||||
# "image": "path/to/win_image.png", # win image
|
||||
# "lose_image": "path/to/lose_image.png", # lose image
|
||||
# "prompt": "a photo of ...",
|
||||
# }
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/dpo.jsonl \
|
||||
--data_file_keys "image,lose_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image_DPO_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--task dpo \
|
||||
--beta_dpo 2500 \
|
||||
--find_unused_parameters
|
||||
@@ -0,0 +1,24 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_distill_qwen_image.csv \
|
||||
--data_file_keys "image" \
|
||||
--extra_inputs "seed,rand_device,num_inference_steps,cfg_scale" \
|
||||
--height 1328 \
|
||||
--width 1328 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Distill-LoRA_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters \
|
||||
--task direct_distill
|
||||
|
||||
# This is an experimental training feature designed to directly distill the model, enabling generation results with fewer steps to approximate those achieved with more steps.
|
||||
# The model (https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) is trained using this script.
|
||||
# The sample dataset is provided solely to demonstrate the dataset format. For actual usage, please construct a larger dataset using the base model.
|
||||
18
examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh
Normal file
18
examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh
Normal file
@@ -0,0 +1,18 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_edit.csv \
|
||||
--data_file_keys "image,edit_image" \
|
||||
--extra_inputs "edit_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Edit_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters
|
||||
@@ -0,0 +1,20 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path "data/example_image_dataset" \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_qwenimage_context.csv \
|
||||
--data_file_keys "image,context_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-In-Context-Control-Union_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 64 \
|
||||
--lora_checkpoint "models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union/model.safetensors" \
|
||||
--extra_inputs "context_image" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
|
||||
# if you want to train from scratch, you can remove the --lora_checkpoint argument
|
||||
@@ -0,0 +1,26 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--output_path "./models/train/Qwen-Image_lora_cache" \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--task data_process
|
||||
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path models/train/Qwen-Image_lora_cache \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters \
|
||||
--enable_fp8_training
|
||||
@@ -1,64 +1,47 @@
|
||||
import torch, os, json
|
||||
from diffsynth import load_state_dict
|
||||
import torch, os
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth.pipelines.flux_image_new import ControlNetInput
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task
|
||||
from diffsynth.trainers.unified_dataset import UnifiedDataset
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
|
||||
class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
def __init__(
|
||||
self,
|
||||
model_paths=None, model_id_with_origin_paths=None,
|
||||
tokenizer_path=None,
|
||||
tokenizer_path=None, processor_path=None,
|
||||
trainable_models=None,
|
||||
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
extra_inputs=None,
|
||||
enable_fp8_training=False,
|
||||
task="sft",
|
||||
beta_dpo=1000.,
|
||||
):
|
||||
super().__init__()
|
||||
# Load models
|
||||
model_configs = []
|
||||
if model_paths is not None:
|
||||
model_paths = json.loads(model_paths)
|
||||
model_configs += [ModelConfig(path=path) for path in model_paths]
|
||||
if model_id_with_origin_paths is not None:
|
||||
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
|
||||
if tokenizer_path is not None:
|
||||
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=ModelConfig(tokenizer_path))
|
||||
else:
|
||||
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
|
||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training)
|
||||
tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
|
||||
processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
|
||||
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
|
||||
|
||||
# Reset training scheduler (do it in each training step)
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
# Training mode
|
||||
self.switch_pipe_to_training_mode(
|
||||
self.pipe, trainable_models,
|
||||
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
|
||||
enable_fp8_training=enable_fp8_training,
|
||||
)
|
||||
|
||||
# Freeze untrainable models
|
||||
self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
||||
|
||||
# Add LoRA to the base models
|
||||
if lora_base_model is not None:
|
||||
model = self.add_lora_to_model(
|
||||
getattr(self.pipe, lora_base_model),
|
||||
target_modules=lora_target_modules.split(","),
|
||||
lora_rank=lora_rank
|
||||
)
|
||||
if lora_checkpoint is not None:
|
||||
state_dict = load_state_dict(lora_checkpoint)
|
||||
state_dict = self.mapping_lora_state_dict(state_dict)
|
||||
load_result = model.load_state_dict(state_dict, strict=False)
|
||||
if len(load_result[1]) > 0:
|
||||
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
||||
setattr(self.pipe, lora_base_model, model)
|
||||
|
||||
# Store other configs
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||
self.task = task
|
||||
self.lora_base_model = lora_base_model
|
||||
self.beta_dpo = beta_dpo
|
||||
|
||||
|
||||
def forward_preprocess(self, data):
|
||||
# CFG-sensitive parameters
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
@@ -77,6 +60,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
"rand_device": self.pipe.device,
|
||||
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||
"edit_image_auto_resize": True,
|
||||
}
|
||||
|
||||
# Extra inputs
|
||||
@@ -97,12 +81,61 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
for unit in self.pipe.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||
return {**inputs_shared, **inputs_posi}
|
||||
|
||||
|
||||
def forward(self, data, inputs=None):
|
||||
if inputs is None: inputs = self.forward_preprocess(data)
|
||||
|
||||
def forward_dpo(self, data, accelerator=None):
|
||||
# Loss DPO: -logσ(−β(diff_policy − diff_ref))
|
||||
# Prepare inputs
|
||||
win_data = {key: data[key] for key in ["prompt", "image"]}
|
||||
lose_data = {"prompt": None, "image": data["lose_image"]}
|
||||
inputs_win = self.forward_preprocess(win_data)
|
||||
inputs_lose = self.forward_preprocess(lose_data)
|
||||
inputs_lose.update({key: inputs_win[key] for key in ["prompt", "prompt_emb", "prompt_emb_mask"]})
|
||||
inputs_win.pop('noise')
|
||||
inputs_lose.pop('noise')
|
||||
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
||||
loss = self.pipe.training_loss(**models, **inputs)
|
||||
# sample timestep and noise
|
||||
timestep = self.pipe.sample_timestep()
|
||||
noise = torch.rand_like(inputs_win["latents"])
|
||||
# compute diff_policy = loss_win - loss_lose
|
||||
loss_win = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_win)
|
||||
loss_lose = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_lose)
|
||||
diff_policy = loss_win - loss_lose
|
||||
# compute diff_ref
|
||||
if self.lora_base_model is not None:
|
||||
self.disable_all_lora_layers(accelerator.unwrap_model(self).pipe.dit)
|
||||
# load the original model weights
|
||||
with torch.no_grad():
|
||||
loss_win_ref = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_win)
|
||||
loss_lose_ref = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_lose)
|
||||
diff_ref = loss_win_ref - loss_lose_ref
|
||||
self.enable_all_lora_layers(accelerator.unwrap_model(self).pipe.dit)
|
||||
else:
|
||||
# TODO: may support full model training
|
||||
raise NotImplementedError("DPO with full model training is not supported yet.")
|
||||
# compute loss
|
||||
loss = -1. * torch.nn.functional.logsigmoid(self.beta_dpo * (diff_ref - diff_policy)).mean()
|
||||
return loss
|
||||
|
||||
def forward(self, data, inputs=None, return_inputs=False, accelerator=None, **kwargs):
|
||||
if self.task == "dpo":
|
||||
return self.forward_dpo(data, accelerator=accelerator)
|
||||
# Inputs
|
||||
if inputs is None:
|
||||
inputs = self.forward_preprocess(data)
|
||||
else:
|
||||
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||
if return_inputs: return inputs
|
||||
|
||||
# Loss
|
||||
if self.task == "sft":
|
||||
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
||||
loss = self.pipe.training_loss(**models, **inputs)
|
||||
elif self.task == "data_process":
|
||||
loss = inputs
|
||||
elif self.task == "direct_distill":
|
||||
loss = self.pipe.direct_distill_loss(**inputs)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported task: {self.task}.")
|
||||
return loss
|
||||
|
||||
|
||||
@@ -110,11 +143,25 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
if __name__ == "__main__":
|
||||
parser = qwen_image_parser()
|
||||
args = parser.parse_args()
|
||||
dataset = ImageDataset(args=args)
|
||||
dataset = UnifiedDataset(
|
||||
base_path=args.dataset_base_path,
|
||||
metadata_path=args.dataset_metadata_path,
|
||||
repeat=args.dataset_repeat,
|
||||
data_file_keys=args.data_file_keys.split(","),
|
||||
main_data_operator=UnifiedDataset.default_image_operator(
|
||||
base_path=args.dataset_base_path,
|
||||
max_pixels=args.max_pixels,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
height_division_factor=16,
|
||||
width_division_factor=16,
|
||||
)
|
||||
)
|
||||
model = QwenImageTrainingModule(
|
||||
model_paths=args.model_paths,
|
||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||
tokenizer_path=args.tokenizer_path,
|
||||
processor_path=args.processor_path,
|
||||
trainable_models=args.trainable_models,
|
||||
lora_base_model=args.lora_base_model,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
@@ -123,15 +170,15 @@ if __name__ == "__main__":
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
enable_fp8_training=args.enable_fp8_training,
|
||||
task=args.task,
|
||||
beta_dpo=args.beta_dpo,
|
||||
)
|
||||
model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||
launch_training_task(
|
||||
dataset, model, model_logger, optimizer, scheduler,
|
||||
num_epochs=args.num_epochs,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
save_steps=args.save_steps,
|
||||
find_unused_parameters=args.find_unused_parameters,
|
||||
num_workers=args.dataset_num_workers,
|
||||
)
|
||||
launcher_map = {
|
||||
"sft": launch_training_task,
|
||||
"data_process": launch_data_process_task,
|
||||
"direct_distill": launch_training_task,
|
||||
"dpo": launch_training_task,
|
||||
}
|
||||
launcher_map[args.task](dataset, model, model_logger, args=args)
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth import load_state_dict
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=None,
|
||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
||||
)
|
||||
state_dict = load_state_dict("models/train/Qwen-Image-Edit_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
|
||||
prompt = "将裙子改为粉色"
|
||||
image = Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024))
|
||||
image = pipe(prompt, edit_image=image, seed=0, num_inference_steps=40, height=1024, width=1024)
|
||||
image.save(f"image.jpg")
|
||||
@@ -0,0 +1,19 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Qwen-Image_DPO_lora/epoch-4.safetensors")
|
||||
prompt = "黑板上写着“群起效尤,心灵手巧”,字的颜色分别是 “群”: 橙色、“起”: 黑色、“效”: 蓝色、“尤”: 绿色、“心”: 紫色、“灵”: 粉色、“手”: 红色、“巧”: 白色"
|
||||
for seed in range(0, 5):
|
||||
image = pipe(prompt, seed=seed)
|
||||
image.save(f"image_dpo_{seed}.jpg")
|
||||
@@ -0,0 +1,23 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Distill-LoRA_lora/epoch-4.safetensors")
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(
|
||||
prompt,
|
||||
seed=0,
|
||||
num_inference_steps=4,
|
||||
cfg_scale=1,
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,21 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=None,
|
||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Edit_lora/epoch-4.safetensors")
|
||||
|
||||
prompt = "将裙子改为粉色"
|
||||
image = Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024))
|
||||
image = pipe(prompt, edit_image=image, seed=0, num_inference_steps=40, height=1024, width=1024)
|
||||
image.save(f"image.jpg")
|
||||
@@ -0,0 +1,19 @@
|
||||
from PIL import Image
|
||||
import torch
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-In-Context-Control-Union_lora/epoch-4.safetensors")
|
||||
image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024))
|
||||
prompt = "Context_Control. a dog"
|
||||
image = pipe(prompt=prompt, seed=0, context_image=image, height=1024, width=1024)
|
||||
image.save("image_context.jpg")
|
||||
@@ -48,9 +48,13 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|
||||
| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.2-Fun-A14B-InP.py)|[code](./model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](./model_inference/Wan2.2-Fun-A14B-Control.py)|[code](./model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./model_inference/Wan2.1-T2V-14B.py)|[code](./model_training/full/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./model_training/lora/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-480P.py)|[code](./model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||
|
||||
@@ -48,9 +48,13 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|
||||
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|
||||
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.2-Fun-A14B-InP.py)|[code](./model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](./model_inference/Wan2.2-Fun-A14B-Control.py)|[code](./model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./model_inference/Wan2.1-T2V-14B.py)|[code](./model_training/full/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./model_training/lora/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-480P.py)|[code](./model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
from diffsynth import save_video,VideoData
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/input_image.jpg"
|
||||
)
|
||||
input_image = Image.open("data/examples/wan/input_image.jpg")
|
||||
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
camera_control_direction="Left", camera_control_speed=0.01,
|
||||
)
|
||||
save_video(video, "video_left.mp4", fps=15, quality=5)
|
||||
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
camera_control_direction="Up", camera_control_speed=0.01,
|
||||
)
|
||||
save_video(video, "video_up.mp4", fps=15, quality=5)
|
||||
35
examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py
Normal file
35
examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
from diffsynth import save_video,VideoData
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"]
|
||||
)
|
||||
|
||||
# Control video
|
||||
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
|
||||
reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832))
|
||||
video = pipe(
|
||||
prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
control_video=control_video, reference_image=reference_image,
|
||||
height=832, width=576, num_frames=49,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video.mp4", fps=15, quality=5)
|
||||
35
examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py
Normal file
35
examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
from diffsynth import save_video
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/input_image.jpg"
|
||||
)
|
||||
image = Image.open("data/examples/wan/input_image.jpg")
|
||||
|
||||
# First and last frame to video
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=image,
|
||||
seed=0, tiled=True,
|
||||
# You can input `end_image=xxx` to control the last frame of the video.
|
||||
# The model will automatically generate the dynamic content between `input_image` and `end_image`.
|
||||
)
|
||||
save_video(video, "video.mp4", fps=15, quality=5)
|
||||
69
examples/wanvideo/model_inference/Wan2.2-S2V-14B.py
Normal file
69
examples/wanvideo/model_inference/Wan2.2-S2V-14B.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
import librosa
|
||||
from diffsynth import VideoData, save_video_with_audio
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"),
|
||||
)
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_video_dataset",
|
||||
local_dir="./data/example_video_dataset",
|
||||
allow_file_pattern=f"wans2v/*"
|
||||
)
|
||||
|
||||
num_frames = 81 # 4n+1
|
||||
height = 448
|
||||
width = 832
|
||||
|
||||
prompt = "a person is singing"
|
||||
negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
|
||||
input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height))
|
||||
# s2v audio input, recommend 16kHz sampling rate
|
||||
audio_path = 'data/example_video_dataset/wans2v/sing.MP3'
|
||||
input_audio, sample_rate = librosa.load(audio_path, sr=16000)
|
||||
|
||||
# Speech-to-video
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
input_image=input_image,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=0,
|
||||
num_frames=num_frames,
|
||||
height=height,
|
||||
width=width,
|
||||
audio_sample_rate=sample_rate,
|
||||
input_audio=input_audio,
|
||||
num_inference_steps=40,
|
||||
)
|
||||
save_video_with_audio(video[1:], "video_with_audio.mp4", audio_path, fps=16, quality=5)
|
||||
|
||||
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
|
||||
pose_video_path = 'data/example_video_dataset/wans2v/pose.mp4'
|
||||
pose_video = VideoData(pose_video_path, height=height, width=width)
|
||||
|
||||
# Speech-to-video with pose
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
input_image=input_image,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=0,
|
||||
num_frames=num_frames,
|
||||
height=height,
|
||||
width=width,
|
||||
audio_sample_rate=sample_rate,
|
||||
input_audio=input_audio,
|
||||
s2v_pose_video=pose_video,
|
||||
num_inference_steps=40,
|
||||
)
|
||||
save_video_with_audio(video[1:], "video_pose_with_audio.mp4", audio_path, fps=16, quality=5)
|
||||
116
examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py
Normal file
116
examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
import librosa
|
||||
from diffsynth import VideoData, save_video_with_audio
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig, WanVideoUnit_S2V
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
def speech_to_video(
|
||||
prompt,
|
||||
input_image,
|
||||
audio_path,
|
||||
negative_prompt="",
|
||||
num_clip=None,
|
||||
audio_sample_rate=16000,
|
||||
pose_video_path=None,
|
||||
infer_frames=80,
|
||||
height=448,
|
||||
width=832,
|
||||
num_inference_steps=40,
|
||||
fps=16, # recommend fixing fps as 16 for s2v
|
||||
motion_frames=73, # hyperparameter of wan2.2-s2v
|
||||
save_path=None,
|
||||
):
|
||||
# s2v audio input, recommend 16kHz sampling rate
|
||||
input_audio, sample_rate = librosa.load(audio_path, sr=audio_sample_rate)
|
||||
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
|
||||
pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None
|
||||
|
||||
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
|
||||
pipe=pipe,
|
||||
input_audio=input_audio,
|
||||
audio_sample_rate=sample_rate,
|
||||
s2v_pose_video=pose_video,
|
||||
num_frames=infer_frames + 1,
|
||||
height=height,
|
||||
width=width,
|
||||
fps=fps,
|
||||
)
|
||||
num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat
|
||||
print(f"Generating {num_repeat} video clips...")
|
||||
motion_videos = []
|
||||
video = []
|
||||
for r in range(num_repeat):
|
||||
s2v_pose_latents = pose_latents[r] if pose_latents is not None else None
|
||||
current_clip = pipe(
|
||||
prompt=prompt,
|
||||
input_image=input_image,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=0,
|
||||
num_frames=infer_frames + 1,
|
||||
height=height,
|
||||
width=width,
|
||||
audio_embeds=audio_embeds[r],
|
||||
s2v_pose_latents=s2v_pose_latents,
|
||||
motion_video=motion_videos,
|
||||
num_inference_steps=num_inference_steps,
|
||||
)
|
||||
current_clip = current_clip[-infer_frames:]
|
||||
if r == 0:
|
||||
current_clip = current_clip[3:]
|
||||
overlap_frames_num = min(motion_frames, len(current_clip))
|
||||
motion_videos = motion_videos[overlap_frames_num:] + current_clip[-overlap_frames_num:]
|
||||
video.extend(current_clip)
|
||||
save_video_with_audio(video, save_path, audio_path, fps=16, quality=5)
|
||||
print(f"processed the {r+1}th clip of total {num_repeat} clips.")
|
||||
return video
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"),
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_video_dataset",
|
||||
local_dir="./data/example_video_dataset",
|
||||
allow_file_pattern=f"wans2v/*",
|
||||
)
|
||||
|
||||
infer_frames = 80 # 4n
|
||||
height = 448
|
||||
width = 832
|
||||
|
||||
prompt = "a person is singing"
|
||||
negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
|
||||
input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height))
|
||||
|
||||
video_with_audio = speech_to_video(
|
||||
prompt=prompt,
|
||||
input_image=input_image,
|
||||
audio_path='data/example_video_dataset/wans2v/sing.MP3',
|
||||
negative_prompt=negative_prompt,
|
||||
pose_video_path='data/example_video_dataset/wans2v/pose.mp4',
|
||||
save_path="video_with_audio_full.mp4",
|
||||
infer_frames=infer_frames,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
# num_clip means generating only the first n clips with n * infer_frames frames.
|
||||
video_with_audio_pose = speech_to_video(
|
||||
prompt=prompt,
|
||||
input_image=input_image,
|
||||
audio_path='data/example_video_dataset/wans2v/sing.MP3',
|
||||
negative_prompt=negative_prompt,
|
||||
pose_video_path='data/example_video_dataset/wans2v/pose.mp4',
|
||||
save_path="video_with_audio_pose_clip_2.mp4",
|
||||
num_clip=2
|
||||
)
|
||||
@@ -0,0 +1,35 @@
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \
|
||||
--data_file_keys "video,control_video,reference_image" \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-Control-Camera:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control-Camera:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-Fun-A14B-Control-Camera_high_niose_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "input_image,camera_control_direction,camera_control_speed" \
|
||||
--max_timestep_boundary 0.358 \
|
||||
--min_timestep_boundary 0
|
||||
# boundary corresponds to timesteps [900, 1000]
|
||||
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \
|
||||
--data_file_keys "video,control_video,reference_image" \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-Control-Camera:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control-Camera:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-Fun-A14B-Control-Camera_low_noise_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "input_image,camera_control_direction,camera_control_speed" \
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.358
|
||||
# boundary corresponds to timesteps [0, 900]
|
||||
@@ -0,0 +1,35 @@
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \
|
||||
--data_file_keys "video,control_video,reference_image" \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-Control:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-Fun-A14B-Control_high_niose_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "control_video,reference_image" \
|
||||
--max_timestep_boundary 0.358 \
|
||||
--min_timestep_boundary 0
|
||||
# boundary corresponds to timesteps [900, 1000]
|
||||
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \
|
||||
--data_file_keys "video,control_video,reference_image" \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-Control:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-Fun-A14B-Control_low_noise_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "control_video,reference_image" \
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.358
|
||||
# boundary corresponds to timesteps [0, 900]
|
||||
33
examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh
Normal file
33
examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh
Normal file
@@ -0,0 +1,33 @@
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-InP:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-InP:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-Fun-A14B-InP_high_niose_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "input_image,end_image" \
|
||||
--max_timestep_boundary 0.358 \
|
||||
--min_timestep_boundary 0
|
||||
# boundary corresponds to timesteps [900, 1000]
|
||||
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-InP:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-InP:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-Fun-A14B-InP_low_noise_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "input_image,end_image" \
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.358
|
||||
# boundary corresponds to timesteps [0, 900]
|
||||
@@ -0,0 +1,39 @@
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \
|
||||
--data_file_keys "video,control_video,reference_image" \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-Control-Camera:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control-Camera:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-Fun-A14B-Control-Camera_high_niose_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "input_image,camera_control_direction,camera_control_speed" \
|
||||
--max_timestep_boundary 0.358 \
|
||||
--min_timestep_boundary 0
|
||||
# boundary corresponds to timesteps [900, 1000]
|
||||
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \
|
||||
--data_file_keys "video,control_video,reference_image" \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-Control-Camera:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control-Camera:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-Fun-A14B-Control-Camera_low_noise_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "input_image,camera_control_direction,camera_control_speed" \
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.358
|
||||
# boundary corresponds to timesteps [0, 900]
|
||||
@@ -0,0 +1,39 @@
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \
|
||||
--data_file_keys "video,control_video,reference_image" \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-Control:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-Fun-A14B-Control_high_niose_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "control_video,reference_image" \
|
||||
--max_timestep_boundary 0.358 \
|
||||
--min_timestep_boundary 0
|
||||
# boundary corresponds to timesteps [900, 1000]
|
||||
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \
|
||||
--data_file_keys "video,control_video,reference_image" \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-Control:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-Fun-A14B-Control_low_noise_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "control_video,reference_image" \
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.358
|
||||
# boundary corresponds to timesteps [0, 900]
|
||||
37
examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh
Normal file
37
examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh
Normal file
@@ -0,0 +1,37 @@
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-InP:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-InP:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-Fun-A14B-InP_high_niose_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "input_image,end_image" \
|
||||
--max_timestep_boundary 0.358 \
|
||||
--min_timestep_boundary 0
|
||||
# boundary corresponds to timesteps [900, 1000]
|
||||
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-InP:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-InP:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-Fun-A14B-InP_low_noise_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "input_image,end_image" \
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.358
|
||||
# boundary corresponds to timesteps [0, 900]
|
||||
@@ -1,7 +1,8 @@
|
||||
import torch, os, json
|
||||
from diffsynth import load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, ModelLogger, launch_training_task, wan_parser
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser
|
||||
from diffsynth.trainers.unified_dataset import UnifiedDataset
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@@ -20,36 +21,16 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
):
|
||||
super().__init__()
|
||||
# Load models
|
||||
model_configs = []
|
||||
if model_paths is not None:
|
||||
model_paths = json.loads(model_paths)
|
||||
model_configs += [ModelConfig(path=path) for path in model_paths]
|
||||
if model_id_with_origin_paths is not None:
|
||||
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
|
||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
|
||||
self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
|
||||
|
||||
# Reset training scheduler
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
# Training mode
|
||||
self.switch_pipe_to_training_mode(
|
||||
self.pipe, trainable_models,
|
||||
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
|
||||
enable_fp8_training=False,
|
||||
)
|
||||
|
||||
# Freeze untrainable models
|
||||
self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
||||
|
||||
# Add LoRA to the base models
|
||||
if lora_base_model is not None:
|
||||
model = self.add_lora_to_model(
|
||||
getattr(self.pipe, lora_base_model),
|
||||
target_modules=lora_target_modules.split(","),
|
||||
lora_rank=lora_rank
|
||||
)
|
||||
if lora_checkpoint is not None:
|
||||
state_dict = load_state_dict(lora_checkpoint)
|
||||
state_dict = self.mapping_lora_state_dict(state_dict)
|
||||
load_result = model.load_state_dict(state_dict, strict=False)
|
||||
if len(load_result[1]) > 0:
|
||||
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
||||
setattr(self.pipe, lora_base_model, model)
|
||||
|
||||
# Store other configs
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
@@ -101,7 +82,7 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
return {**inputs_shared, **inputs_posi}
|
||||
|
||||
|
||||
def forward(self, data, inputs=None):
|
||||
def forward(self, data, inputs=None, **kwargs):
|
||||
if inputs is None: inputs = self.forward_preprocess(data)
|
||||
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
||||
loss = self.pipe.training_loss(**models, **inputs)
|
||||
@@ -111,7 +92,23 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
if __name__ == "__main__":
|
||||
parser = wan_parser()
|
||||
args = parser.parse_args()
|
||||
dataset = VideoDataset(args=args)
|
||||
dataset = UnifiedDataset(
|
||||
base_path=args.dataset_base_path,
|
||||
metadata_path=args.dataset_metadata_path,
|
||||
repeat=args.dataset_repeat,
|
||||
data_file_keys=args.data_file_keys.split(","),
|
||||
main_data_operator=UnifiedDataset.default_video_operator(
|
||||
base_path=args.dataset_base_path,
|
||||
max_pixels=args.max_pixels,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
height_division_factor=16,
|
||||
width_division_factor=16,
|
||||
num_frames=args.num_frames,
|
||||
time_division_factor=4,
|
||||
time_division_remainder=1,
|
||||
),
|
||||
)
|
||||
model = WanTrainingModule(
|
||||
model_paths=args.model_paths,
|
||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||
@@ -129,13 +126,4 @@ if __name__ == "__main__":
|
||||
args.output_path,
|
||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt
|
||||
)
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||
launch_training_task(
|
||||
dataset, model, model_logger, optimizer, scheduler,
|
||||
num_epochs=args.num_epochs,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
save_steps=args.save_steps,
|
||||
find_unused_parameters=args.find_unused_parameters,
|
||||
num_workers=args.dataset_num_workers,
|
||||
)
|
||||
launch_training_task(dataset, model, model_logger, args=args)
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData, load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/Wan2.2-Fun-A14B-Control-Camera_high_noise_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
state_dict = load_state_dict("models/train/Wan2.2-Fun-A14B-Control-Camera_low_noise_full/epoch-1.safetensors")
|
||||
pipe.dit2.load_state_dict(state_dict)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)
|
||||
|
||||
# First and last frame to video
|
||||
video = pipe(
|
||||
prompt="from sunset to night, a small town, light, house, river",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=video[0],
|
||||
camera_control_direction="Left", camera_control_speed=0.0,
|
||||
seed=0, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.2-Fun-A14B-Control-Camera.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData, load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/Wan2.2-Fun-A14B-Control_high_noise_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
state_dict = load_state_dict("models/train/Wan2.2-Fun-A14B-Control_low_noise_full/epoch-1.safetensors")
|
||||
pipe.dit2.load_state_dict(state_dict)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832)
|
||||
video = [video[i] for i in range(81)]
|
||||
reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
|
||||
|
||||
# Control video
|
||||
video = pipe(
|
||||
prompt="from sunset to night, a small town, light, house, river",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
control_video=video, reference_image=reference_image,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.2-Fun-A14B-Control.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData, load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/Wan2.2-Fun-A14B-InP_high_noise_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
state_dict = load_state_dict("models/train/Wan2.2-Fun-A14B-InP_low_noise_full/epoch-1.safetensors")
|
||||
pipe.dit2.load_state_dict(state_dict)
|
||||
pipe.enable_vram_management()
|
||||
video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)
|
||||
|
||||
# First and last frame to video
|
||||
video = pipe(
|
||||
prompt="from sunset to night, a small town, light, house, river",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=video[0], end_image=video[80],
|
||||
seed=0, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.2-Fun-A14B-InP.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData, load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Wan2.2-Fun-A14B-Control-Camera_high_noise_lora/epoch-4.safetensors", alpha=1)
|
||||
pipe.load_lora(pipe.dit2, "models/train/Wan2.2-Fun-A14B-Control-Camera_low_noise_lora/epoch-4.safetensors", alpha=1)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)
|
||||
|
||||
# First and last frame to video
|
||||
video = pipe(
|
||||
prompt="from sunset to night, a small town, light, house, river",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=video[0],
|
||||
camera_control_direction="Left", camera_control_speed=0.0,
|
||||
seed=0, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.2-Fun-A14B-Control-Camera.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Wan2.2-Fun-A14B-Control_high_noise_lora/epoch-4.safetensors", alpha=1)
|
||||
pipe.load_lora(pipe.dit2, "models/train/Wan2.2-Fun-A14B-Control_low_noise_lora/epoch-4.safetensors", alpha=1)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832)
|
||||
video = [video[i] for i in range(81)]
|
||||
reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
|
||||
|
||||
# Control video
|
||||
video = pipe(
|
||||
prompt="from sunset to night, a small town, light, house, river",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
control_video=video, reference_image=reference_image,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.2-Fun-A14B-Control.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Wan2.2-Fun-A14B-InP_high_noise_lora/epoch-4.safetensors", alpha=1)
|
||||
pipe.load_lora(pipe.dit2, "models/train/Wan2.2-Fun-A14B-InP_low_noise_lora/epoch-4.safetensors", alpha=1)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)
|
||||
|
||||
# First and last frame to video
|
||||
video = pipe(
|
||||
prompt="from sunset to night, a small town, light, house, river",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=video[0], end_image=video[80],
|
||||
seed=0, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.2-Fun-A14B-InP.mp4", fps=15, quality=5)
|
||||
@@ -1,8 +1,6 @@
|
||||
torch>=2.0.0
|
||||
torchvision
|
||||
cupy-cuda12x
|
||||
transformers
|
||||
controlnet-aux==0.0.7
|
||||
imageio
|
||||
imageio[ffmpeg]
|
||||
safetensors
|
||||
@@ -14,3 +12,4 @@ ftfy
|
||||
pynvml
|
||||
pandas
|
||||
accelerate
|
||||
peft
|
||||
|
||||
2
setup.py
2
setup.py
@@ -14,7 +14,7 @@ else:
|
||||
|
||||
setup(
|
||||
name="diffsynth",
|
||||
version="1.1.7",
|
||||
version="1.1.8",
|
||||
description="Enjoy the magic of Diffusion models!",
|
||||
author="Artiprocher",
|
||||
packages=find_packages(),
|
||||
|
||||
@@ -1,119 +0,0 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, QwenImageUnit_PromptEmbedder, load_state_dict
|
||||
import torch, os
|
||||
from tqdm import tqdm
|
||||
from diffsynth.models.svd_unet import TemporalTimesteps
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
|
||||
class ValueEncoder(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32):
|
||||
super().__init__()
|
||||
self.value_emb = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.positional_emb = torch.nn.Parameter(torch.randn(1, value_emb_length, dim_out))
|
||||
self.proj_value = torch.nn.Linear(dim_in, dim_out)
|
||||
self.proj_out = torch.nn.Linear(dim_out, dim_out)
|
||||
self.value_emb_length = value_emb_length
|
||||
|
||||
def forward(self, value):
|
||||
value = value * 1
|
||||
emb = self.value_emb(value).to(value.dtype)
|
||||
emb = self.proj_value(emb)
|
||||
emb = repeat(emb, "b d -> b s d", s=self.value_emb_length)
|
||||
emb = emb + self.positional_emb.to(dtype=emb.dtype, device=emb.device)
|
||||
emb = torch.nn.functional.silu(emb)
|
||||
emb = self.proj_out(emb)
|
||||
return emb
|
||||
|
||||
|
||||
class TextInterpolationModel(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32, num_heads=32):
|
||||
super().__init__()
|
||||
self.to_q = ValueEncoder(dim_in=dim_in, dim_out=dim_out, value_emb_length=value_emb_length)
|
||||
self.xk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.yk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.xv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.yv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.to_k = torch.nn.Linear(dim_out, dim_out, bias=False)
|
||||
self.to_v = torch.nn.Linear(dim_out, dim_out, bias=False)
|
||||
self.to_out = torch.nn.Linear(dim_out, dim_out)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, value, x, y):
|
||||
q = self.to_q(value)
|
||||
k = self.to_k(torch.concat([x + self.xk_emb, y + self.yk_emb], dim=1))
|
||||
v = self.to_v(torch.concat([x + self.xv_emb, y + self.yv_emb], dim=1))
|
||||
q = rearrange(q, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
k = rearrange(k, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
v = rearrange(v, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
out = rearrange(out, 'b h s d -> b s (h d)')
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
unit = QwenImageUnit_PromptEmbedder()
|
||||
|
||||
dataset_prompt = [
|
||||
(
|
||||
"超级黑暗的画面,整体在黑暗中,暗无天日,暗淡无光,阴森黑暗,几乎全黑",
|
||||
"超级明亮的画面,爆闪,相机过曝,整个画面都是白色的眩光,几乎全是白色",
|
||||
),
|
||||
]
|
||||
dataset_tensors = []
|
||||
for prompt_x, prompt_y in tqdm(dataset_prompt):
|
||||
with torch.no_grad():
|
||||
x = unit.process(pipe, prompt_x)["prompt_emb"]
|
||||
y = unit.process(pipe, prompt_y)["prompt_emb"]
|
||||
dataset_tensors.append((x, y))
|
||||
|
||||
model = TextInterpolationModel().to(dtype=torch.bfloat16, device="cuda")
|
||||
model.load_state_dict(load_state_dict("models/interpolate.pth"))
|
||||
|
||||
def sample_tokens(emb, p):
|
||||
perm = torch.randperm(emb.shape[1])[:max(0, int(emb.shape[1]*p))]
|
||||
return emb[:, perm]
|
||||
|
||||
|
||||
def loss_fn(x, y):
|
||||
s, l = x.shape[1], y.shape[1]
|
||||
x = repeat(x, "b s d -> b s l d", l=l)
|
||||
y = repeat(y, "b l d -> b s l d", s=s)
|
||||
d = torch.square(x - y).mean(dim=-1)
|
||||
loss_x = d.min(dim=1).values.mean()
|
||||
loss_y = d.min(dim=2).values.mean()
|
||||
return loss_x + loss_y
|
||||
|
||||
|
||||
def get_target(x, y, p):
|
||||
x = sample_tokens(x, 1-p)
|
||||
y = sample_tokens(y, p)
|
||||
return torch.concat([x, y], dim=1)
|
||||
|
||||
name = "brightness"
|
||||
for i in range(6):
|
||||
v = i/5
|
||||
with torch.no_grad():
|
||||
data_id = 0
|
||||
x, y = dataset_tensors[data_id]
|
||||
x, y = x.to("cuda"), y.to("cuda")
|
||||
value = torch.tensor([v], dtype=torch.bfloat16, device="cuda")
|
||||
value_emb = model(value, x, y)
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40, extra_prompt_emb=value_emb)
|
||||
os.makedirs(f"data/qwen_image_value/{name}", exist_ok=True)
|
||||
image.save(f"data/qwen_image_value/{name}/image_{v}.jpg")
|
||||
@@ -1,121 +0,0 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, QwenImageUnit_PromptEmbedder
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from diffsynth.models.svd_unet import TemporalTimesteps
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
|
||||
class ValueEncoder(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32):
|
||||
super().__init__()
|
||||
self.value_emb = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.positional_emb = torch.nn.Parameter(torch.randn(1, value_emb_length, dim_out))
|
||||
self.proj_value = torch.nn.Linear(dim_in, dim_out)
|
||||
self.proj_out = torch.nn.Linear(dim_out, dim_out)
|
||||
self.value_emb_length = value_emb_length
|
||||
|
||||
def forward(self, value):
|
||||
value = value * 1
|
||||
emb = self.value_emb(value).to(value.dtype)
|
||||
emb = self.proj_value(emb)
|
||||
emb = repeat(emb, "b d -> b s d", s=self.value_emb_length)
|
||||
emb = emb + self.positional_emb.to(dtype=emb.dtype, device=emb.device)
|
||||
emb = torch.nn.functional.silu(emb)
|
||||
emb = self.proj_out(emb)
|
||||
return emb
|
||||
|
||||
|
||||
class TextInterpolationModel(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32, num_heads=32):
|
||||
super().__init__()
|
||||
self.to_q = ValueEncoder(dim_in=dim_in, dim_out=dim_out, value_emb_length=value_emb_length)
|
||||
self.xk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.yk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.xv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.yv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.to_k = torch.nn.Linear(dim_out, dim_out, bias=False)
|
||||
self.to_v = torch.nn.Linear(dim_out, dim_out, bias=False)
|
||||
self.to_out = torch.nn.Linear(dim_out, dim_out)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, value, x, y):
|
||||
q = self.to_q(value)
|
||||
k = self.to_k(torch.concat([x + self.xk_emb, y + self.yk_emb], dim=1))
|
||||
v = self.to_v(torch.concat([x + self.xv_emb, y + self.yv_emb], dim=1))
|
||||
q = rearrange(q, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
k = rearrange(k, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
v = rearrange(v, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
out = rearrange(out, 'b h s d -> b s (h d)')
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
|
||||
def sample_tokens(emb, p):
|
||||
perm = torch.randperm(emb.shape[1])[:max(0, int(emb.shape[1]*p))]
|
||||
return emb[:, perm]
|
||||
|
||||
|
||||
def loss_fn(x, y):
|
||||
s, l = x.shape[1], y.shape[1]
|
||||
x = repeat(x, "b s d -> b s l d", l=l)
|
||||
y = repeat(y, "b l d -> b s l d", s=s)
|
||||
d = torch.square(x - y).mean(dim=-1)
|
||||
loss_x = d.min(dim=1).values.mean()
|
||||
loss_y = d.min(dim=2).values.mean()
|
||||
return loss_x + loss_y
|
||||
|
||||
|
||||
def get_target(x, y, p):
|
||||
x = sample_tokens(x, 1-p)
|
||||
y = sample_tokens(y, p)
|
||||
return torch.concat([x, y], dim=1)
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
unit = QwenImageUnit_PromptEmbedder()
|
||||
|
||||
|
||||
dataset_prompt = [
|
||||
(
|
||||
"超级黑暗的画面,整体在黑暗中,暗无天日,暗淡无光,阴森黑暗,几乎全黑",
|
||||
"超级明亮的画面,爆闪,相机过曝,整个画面都是白色的眩光,几乎全是白色",
|
||||
),
|
||||
]
|
||||
|
||||
dataset_tensors = []
|
||||
for prompt_x, prompt_y in tqdm(dataset_prompt):
|
||||
with torch.no_grad():
|
||||
x = unit.process(pipe, prompt_x)["prompt_emb"].to(dtype=torch.float32, device="cpu")
|
||||
y = unit.process(pipe, prompt_y)["prompt_emb"].to(dtype=torch.float32, device="cpu")
|
||||
dataset_tensors.append((x, y))
|
||||
|
||||
model = TextInterpolationModel().to(dtype=torch.float32, device="cuda")
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
|
||||
|
||||
for step_id, step in enumerate(tqdm(range(100000))):
|
||||
optimizer.zero_grad()
|
||||
|
||||
data_id = torch.randint(0, len(dataset_tensors), size=(1,)).item()
|
||||
x, y = dataset_tensors[data_id]
|
||||
x, y = x.to("cuda"), y.to("cuda")
|
||||
|
||||
value = torch.rand((1,), dtype=torch.float32, device="cuda")
|
||||
out = model(value, x, y)
|
||||
loss = loss_fn(out, x) * (1 - value) + loss_fn(out, y) * value
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if (step_id + 1) % 1000 == 0:
|
||||
print(loss)
|
||||
|
||||
torch.save(model.state_dict(), f"models/interpolate_{step+1}.pth")
|
||||
Reference in New Issue
Block a user