mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
Compare commits
35 Commits
qwen-image
...
qwen-image
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4abc2c1cd1 | ||
|
|
f11a91e610 | ||
|
|
7ed09bb78d | ||
|
|
ac931856d5 | ||
|
|
2d09318236 | ||
|
|
7dc49bd036 | ||
|
|
4d16bdf853 | ||
|
|
01a1f48f70 | ||
|
|
6a9d875d65 | ||
|
|
f1c96d31b4 | ||
|
|
aafcca8d77 | ||
|
|
bf369cad4d | ||
|
|
024fdad76d | ||
|
|
e1c2eda5f5 | ||
|
|
0b574cc0c2 | ||
|
|
3212c83398 | ||
|
|
49f9a11eb3 | ||
|
|
fa36739f01 | ||
|
|
42e9764b60 | ||
|
|
f7f5c07570 | ||
|
|
ec1a936624 | ||
|
|
6e6136586c | ||
|
|
34766863f8 | ||
|
|
1d76d5e828 | ||
|
|
250540a398 | ||
|
|
46f3c38c37 | ||
|
|
9a8982efb1 | ||
|
|
3c815cce4b | ||
|
|
39d199c8bb | ||
|
|
f5506d1e13 | ||
|
|
166a8734fe | ||
|
|
b2273ec568 | ||
|
|
89c4e3bdb6 | ||
|
|
051ebf3439 | ||
|
|
7cfadc2ca8 |
25
README.md
25
README.md
@@ -87,11 +87,16 @@ image.save("image.jpg")
|
||||
|
||||
<summary>Model Overview</summary>
|
||||
|
||||
|Model ID|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|
||||
|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|-|-|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
||||
|
||||
</details>
|
||||
|
||||
### FLUX Series
|
||||
@@ -363,6 +368,16 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
|
||||
|
||||
|
||||
## Update History
|
||||
- **August 18, 2025** We trained and open-sourced the Inpaint ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).
|
||||
|
||||
- **August 15, 2025** We open-sourced the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset). This is an image dataset generated using the Qwen-Image model, with a total of 160,000 `1024 x 1024` images. It includes the general, English text rendering, and Chinese text rendering subsets. We provide caption, entity and control images annotations for each image. Developers can use this dataset to train models such as ControlNet and EliGen for the Qwen-Image model. We aim to promote technological development through open-source contributions!
|
||||
|
||||
- **August 13, 2025** We trained and open-sourced the ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py).
|
||||
|
||||
- **August 12, 2025** We trained and open-sourced the ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py).
|
||||
|
||||
- **August 11, 2025** We released another distilled acceleration model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA). It uses the same training process as [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), but the model structure is changed to LoRA. This makes it work better with other open-source models.
|
||||
|
||||
- **August 7, 2025** We open-sourced the entity control LoRA of Qwen-Image, [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen). Qwen-Image-EliGen is able to achieve entity-level controlled text-to-image generation. See the [paper](https://arxiv.org/abs/2501.01097) for technical details. Training dataset: [EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet).
|
||||
|
||||
- **August 5, 2025** We open-sourced the distilled acceleration model of Qwen-Image, [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), achieving approximately 5x speedup.
|
||||
|
||||
24
README_zh.md
24
README_zh.md
@@ -89,11 +89,15 @@ image.save("image.jpg")
|
||||
|
||||
<summary>模型总览</summary>
|
||||
|
||||
|模型 ID|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|-|-|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
||||
|
||||
</details>
|
||||
|
||||
@@ -380,6 +384,16 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
|
||||
|
||||
|
||||
## 更新历史
|
||||
- **2025年8月18日** 我们训练并开源了 Qwen-Image 的图像重绘 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)。
|
||||
|
||||
- **2025年8月15日** 我们开源了 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) 数据集。这是一个使用 Qwen-Image 模型生成的图像数据集,共包含 160,000 张`1024 x 1024`图像。它包括通用、英文文本渲染和中文文本渲染子集。我们为每张图像提供了图像描述、实体和结构控制图像的标注。开发者可以使用这个数据集来训练 Qwen-Image 模型的 ControlNet 和 EliGen 等模型,我们旨在通过开源推动技术发展!
|
||||
|
||||
- **2025年8月13日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)。
|
||||
|
||||
- **2025年8月12日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)。
|
||||
|
||||
- **2025年8月11日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA),沿用了与 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) 相同的训练流程,但模型结构修改为了 LoRA,因此能够更好地与其他开源生态模型兼容。
|
||||
|
||||
- **2025年8月7日** 我们开源了 Qwen-Image 的实体控制 LoRA 模型 [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)。Qwen-Image-EliGen 能够实现实体级可控的文生图。技术细节请参见[论文](https://arxiv.org/abs/2501.01097)。训练数据集:[EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)。
|
||||
|
||||
- **2025年8月5日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full),实现了约 5 倍加速。
|
||||
|
||||
@@ -72,10 +72,10 @@ from ..models.flux_lora_encoder import FluxLoRAEncoder
|
||||
from ..models.nexus_gen_projector import NexusGenAdapter, NexusGenImageEmbeddingMerger
|
||||
from ..models.nexus_gen import NexusGenAutoregressiveModel
|
||||
|
||||
from ..models.qwen_image_accelerate_adapter import QwenImageAccelerateAdapter
|
||||
from ..models.qwen_image_dit import QwenImageDiT
|
||||
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from ..models.qwen_image_vae import QwenImageVAE
|
||||
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||
|
||||
model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -166,9 +166,10 @@ model_loader_configs = [
|
||||
(None, "63c969fd37cce769a90aa781fbff5f81", ["flux_dit", "nexus_gen_editing_adapter"], [FluxDiT, NexusGenImageEmbeddingMerger], "civitai"),
|
||||
(None, "2bd19e845116e4f875a0a048e27fc219", ["nexus_gen_llm"], [NexusGenAutoregressiveModel], "civitai"),
|
||||
(None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"),
|
||||
(None, "ae9d13bfc578702baf6445d2cf3d1d46", ["qwen_image_accelerate_adapter"], [QwenImageAccelerateAdapter], "civitai"),
|
||||
(None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
|
||||
(None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
|
||||
(None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
|
||||
(None, "a9e54e480a628f0b956a688a81c33bab", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
from .qwen_image_dit import QwenImageTransformerBlock, AdaLayerNorm, TimestepEmbeddings
|
||||
from einops import rearrange
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class QwenImageAccelerateAdapter(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.proj_latents_in = torch.nn.Linear(64, 3072)
|
||||
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
|
||||
self.transformer_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
QwenImageTransformerBlock(
|
||||
dim=3072,
|
||||
num_attention_heads=24,
|
||||
attention_head_dim=128,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = AdaLayerNorm(3072, single=True)
|
||||
self.proj_out = torch.nn.Linear(3072, 64)
|
||||
self.proj_latents_out = torch.nn.Linear(64, 64)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents=None,
|
||||
image=None,
|
||||
text=None,
|
||||
image_rotary_emb=None,
|
||||
timestep=None,
|
||||
):
|
||||
latents = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
||||
image = image + self.proj_latents_in(latents)
|
||||
conditioning = self.time_text_embed(timestep, image.dtype)
|
||||
for block in self.transformer_blocks:
|
||||
text, image = block(
|
||||
image=image,
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
image = self.norm_out(image, conditioning)
|
||||
image = self.proj_out(image)
|
||||
image = image + self.proj_latents_out(latents)
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return QwenImageAccelerateAdapterStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class QwenImageAccelerateAdapterStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
74
diffsynth/models/qwen_image_controlnet.py
Normal file
74
diffsynth/models/qwen_image_controlnet.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .sd3_dit import RMSNorm
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
|
||||
class BlockWiseControlBlock(torch.nn.Module):
|
||||
# [linear, gelu, linear]
|
||||
def __init__(self, dim: int = 3072):
|
||||
super().__init__()
|
||||
self.x_rms = RMSNorm(dim, eps=1e-6)
|
||||
self.y_rms = RMSNorm(dim, eps=1e-6)
|
||||
self.input_proj = nn.Linear(dim, dim)
|
||||
self.act = nn.GELU()
|
||||
self.output_proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x, y):
|
||||
x, y = self.x_rms(x), self.y_rms(y)
|
||||
x = self.input_proj(x + y)
|
||||
x = self.act(x)
|
||||
x = self.output_proj(x)
|
||||
return x
|
||||
|
||||
def init_weights(self):
|
||||
# zero initialize output_proj
|
||||
nn.init.zeros_(self.output_proj.weight)
|
||||
nn.init.zeros_(self.output_proj.bias)
|
||||
|
||||
|
||||
class QwenImageBlockWiseControlNet(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 60,
|
||||
in_dim: int = 64,
|
||||
additional_in_dim: int = 0,
|
||||
dim: int = 3072,
|
||||
):
|
||||
super().__init__()
|
||||
self.img_in = nn.Linear(in_dim + additional_in_dim, dim)
|
||||
self.controlnet_blocks = nn.ModuleList(
|
||||
[
|
||||
BlockWiseControlBlock(dim)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def init_weight(self):
|
||||
nn.init.zeros_(self.img_in.weight)
|
||||
nn.init.zeros_(self.img_in.bias)
|
||||
for block in self.controlnet_blocks:
|
||||
block.init_weights()
|
||||
|
||||
def process_controlnet_conditioning(self, controlnet_conditioning):
|
||||
return self.img_in(controlnet_conditioning)
|
||||
|
||||
def blockwise_forward(self, img, controlnet_conditioning, block_id):
|
||||
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return QwenImageBlockWiseControlNetStateDictConverter()
|
||||
|
||||
|
||||
class QwenImageBlockWiseControlNetStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
hash_value = hash_state_dict_keys(state_dict)
|
||||
extra_kwargs = {}
|
||||
if hash_value == "a9e54e480a628f0b956a688a81c33bab":
|
||||
# inpaint controlnet
|
||||
extra_kwargs = {"additional_in_dim": 4}
|
||||
return state_dict, extra_kwargs
|
||||
@@ -90,8 +90,39 @@ class QwenEmbedRope(nn.Module):
|
||||
)
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
|
||||
|
||||
def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
_, height, width = video_fhw
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2)
|
||||
else:
|
||||
max_vid_index = max(height, width)
|
||||
required_len = max_vid_index + max(txt_seq_lens)
|
||||
cur_max_len = self.pos_freqs.shape[0]
|
||||
if required_len <= cur_max_len:
|
||||
return
|
||||
|
||||
new_max_len = math.ceil(required_len / 512) * 512
|
||||
pos_index = torch.arange(new_max_len)
|
||||
neg_index = torch.arange(new_max_len).flip(0) * -1 - 1
|
||||
self.pos_freqs = torch.cat([
|
||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||
], dim=1)
|
||||
self.neg_freqs = torch.cat([
|
||||
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||
], dim=1)
|
||||
return
|
||||
|
||||
|
||||
def forward(self, video_fhw, txt_seq_lens, device):
|
||||
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
@@ -422,7 +453,7 @@ class QwenImageDiT(torch.nn.Module):
|
||||
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
|
||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
||||
|
||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (P Q C)", H=height//16, W=width//16, P=2, Q=2)
|
||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
image = self.img_in(image)
|
||||
text = self.txt_in(self.txt_norm(prompt_emb))
|
||||
|
||||
@@ -441,7 +472,7 @@ class QwenImageDiT(torch.nn.Module):
|
||||
image = self.norm_out(image, conditioning)
|
||||
image = self.proj_out(image)
|
||||
|
||||
latents = rearrange(image, "B (H W) (P Q C) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -4,19 +4,46 @@ from typing import Union
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
|
||||
from ..models import ModelManager, load_state_dict
|
||||
from ..models.qwen_image_accelerate_adapter import QwenImageAccelerateAdapter
|
||||
from ..models.qwen_image_dit import QwenImageDiT
|
||||
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from ..models.qwen_image_vae import QwenImageVAE
|
||||
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora import GeneralLoRALoader
|
||||
from .flux_image_new import ControlNetInput
|
||||
|
||||
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
|
||||
|
||||
class QwenImageBlockwiseMultiControlNet(torch.nn.Module):
|
||||
def __init__(self, models: list[QwenImageBlockWiseControlNet]):
|
||||
super().__init__()
|
||||
if not isinstance(models, list):
|
||||
models = [models]
|
||||
self.models = torch.nn.ModuleList(models)
|
||||
|
||||
def preprocess(self, controlnet_inputs: list[ControlNetInput], conditionings: list[torch.Tensor], **kwargs):
|
||||
processed_conditionings = []
|
||||
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
|
||||
conditioning = rearrange(conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
||||
model_output = self.models[controlnet_input.controlnet_id].process_controlnet_conditioning(conditioning)
|
||||
processed_conditionings.append(model_output)
|
||||
return processed_conditionings
|
||||
|
||||
def blockwise_forward(self, image, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, block_id, **kwargs):
|
||||
res = 0
|
||||
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
|
||||
progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
|
||||
if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4):
|
||||
continue
|
||||
model_output = self.models[controlnet_input.controlnet_id].blockwise_forward(image, conditioning, block_id)
|
||||
res = res + model_output * controlnet_input.scale
|
||||
return res
|
||||
|
||||
|
||||
class QwenImagePipeline(BasePipeline):
|
||||
|
||||
@@ -28,19 +55,21 @@ class QwenImagePipeline(BasePipeline):
|
||||
from transformers import Qwen2Tokenizer
|
||||
|
||||
self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
|
||||
self.accelerate_adapter: QwenImageAccelerateAdapter = None
|
||||
self.text_encoder: QwenImageTextEncoder = None
|
||||
self.dit: QwenImageDiT = None
|
||||
self.vae: QwenImageVAE = None
|
||||
self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None
|
||||
self.tokenizer: Qwen2Tokenizer = None
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.in_iteration_models = ("accelerate_adapter", "dit",)
|
||||
self.in_iteration_models = ("dit", "blockwise_controlnet")
|
||||
self.units = [
|
||||
QwenImageUnit_ShapeChecker(),
|
||||
QwenImageUnit_NoiseInitializer(),
|
||||
QwenImageUnit_InputImageEmbedder(),
|
||||
QwenImageUnit_Inpaint(),
|
||||
QwenImageUnit_PromptEmbedder(),
|
||||
QwenImageUnit_EntityControl(),
|
||||
QwenImageUnit_BlockwiseControlNet(),
|
||||
]
|
||||
self.model_fn = model_fn_qwen_image
|
||||
|
||||
@@ -165,6 +194,23 @@ class QwenImagePipeline(BasePipeline):
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
if self.blockwise_controlnet is not None:
|
||||
enable_vram_management(
|
||||
self.blockwise_controlnet,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
vram_limit=vram_limit,
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -189,7 +235,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
|
||||
pipe.dit = model_manager.fetch_model("qwen_image_dit")
|
||||
pipe.vae = model_manager.fetch_model("qwen_image_vae")
|
||||
pipe.accelerate_adapter = model_manager.fetch_model("qwen_image_accelerate_adapter")
|
||||
pipe.blockwise_controlnet = QwenImageBlockwiseMultiControlNet(model_manager.fetch_model("qwen_image_blockwise_controlnet", index="all"))
|
||||
if tokenizer_config is not None and pipe.text_encoder is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
from transformers import Qwen2Tokenizer
|
||||
@@ -207,6 +253,10 @@ class QwenImagePipeline(BasePipeline):
|
||||
# Image
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Inpaint
|
||||
inpaint_mask: Image.Image = None,
|
||||
inpaint_blur_size: int = None,
|
||||
inpaint_blur_sigma: float = None,
|
||||
# Shape
|
||||
height: int = 1328,
|
||||
width: int = 1328,
|
||||
@@ -215,6 +265,8 @@ class QwenImagePipeline(BasePipeline):
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 30,
|
||||
# Blockwise ControlNet
|
||||
blockwise_controlnet_inputs: list[ControlNetInput] = None,
|
||||
# EliGen
|
||||
eligen_entity_prompts: list[str] = None,
|
||||
eligen_entity_masks: list[Image.Image] = None,
|
||||
@@ -227,6 +279,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
tile_stride: int = 64,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
extra_prompt_emb = None,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16))
|
||||
@@ -241,14 +294,20 @@ class QwenImagePipeline(BasePipeline):
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||
"inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"enable_fp8_attention": enable_fp8_attention,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"blockwise_controlnet_inputs": blockwise_controlnet_inputs,
|
||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
if extra_prompt_emb is not None:
|
||||
inputs_posi["prompt_emb"] = torch.concat([inputs_posi["prompt_emb"], extra_prompt_emb], dim=1)
|
||||
inputs_posi["prompt_emb_mask"] = torch.ones((1, inputs_posi["prompt_emb"].shape[1]), dtype=inputs_posi["prompt_emb_mask"].dtype, device=inputs_posi["prompt_emb_mask"].device)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
@@ -265,7 +324,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# Scheduler
|
||||
inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
@@ -314,7 +373,26 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents, "input_latents": None}
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
|
||||
class QwenImageUnit_Inpaint(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("inpaint_mask", "height", "width", "inpaint_blur_size", "inpaint_blur_sigma"),
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, inpaint_mask, height, width, inpaint_blur_size, inpaint_blur_sigma):
|
||||
if inpaint_mask is None:
|
||||
return {}
|
||||
inpaint_mask = pipe.preprocess_image(inpaint_mask.convert("RGB").resize((width // 8, height // 8)), min_value=0, max_value=1)
|
||||
inpaint_mask = inpaint_mask.mean(dim=1, keepdim=True)
|
||||
if inpaint_blur_size is not None and inpaint_blur_sigma is not None:
|
||||
from torchvision.transforms import GaussianBlur
|
||||
blur = GaussianBlur(kernel_size=inpaint_blur_size * 2 + 1, sigma=inpaint_blur_sigma)
|
||||
inpaint_mask = blur(inpaint_mask)
|
||||
return {"inpaint_mask": inpaint_mask}
|
||||
|
||||
|
||||
|
||||
@@ -340,7 +418,9 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
||||
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 34
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = pipe.tokenizer(txt, max_length=1024+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
|
||||
txt_tokens = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
|
||||
if txt_tokens.input_ids.shape[1] >= 1024:
|
||||
print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {txt_tokens['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.")
|
||||
hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1]
|
||||
|
||||
split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
||||
@@ -359,7 +439,7 @@ class QwenImageUnit_EntityControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
onload_model_names=("text_encoder")
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
||||
@@ -434,15 +514,62 @@ class QwenImageUnit_EntityControl(PipelineUnit):
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
|
||||
class QwenImageUnit_BlockwiseControlNet(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("blockwise_controlnet_inputs", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def apply_controlnet_mask_on_latents(self, pipe, latents, mask):
|
||||
mask = (pipe.preprocess_image(mask) + 1) / 2
|
||||
mask = mask.mean(dim=1, keepdim=True)
|
||||
mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
|
||||
latents = torch.concat([latents, mask], dim=1)
|
||||
return latents
|
||||
|
||||
def apply_controlnet_mask_on_image(self, pipe, image, mask):
|
||||
mask = mask.resize(image.size)
|
||||
mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()
|
||||
image = np.array(image)
|
||||
image[mask > 0] = 0
|
||||
image = Image.fromarray(image)
|
||||
return image
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, blockwise_controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
|
||||
if blockwise_controlnet_inputs is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
conditionings = []
|
||||
for controlnet_input in blockwise_controlnet_inputs:
|
||||
image = controlnet_input.image
|
||||
if controlnet_input.inpaint_mask is not None:
|
||||
image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)
|
||||
|
||||
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
image = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
if controlnet_input.inpaint_mask is not None:
|
||||
image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
|
||||
conditionings.append(image)
|
||||
|
||||
return {"blockwise_controlnet_conditioning": conditionings}
|
||||
|
||||
|
||||
def model_fn_qwen_image(
|
||||
dit: QwenImageDiT = None,
|
||||
accelerate_adapter: QwenImageAccelerateAdapter = None,
|
||||
blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_emb=None,
|
||||
prompt_emb_mask=None,
|
||||
height=None,
|
||||
width=None,
|
||||
blockwise_controlnet_conditioning=None,
|
||||
blockwise_controlnet_inputs=None,
|
||||
progress_id=0,
|
||||
num_inference_steps=1,
|
||||
entity_prompt_emb=None,
|
||||
entity_prompt_emb_mask=None,
|
||||
entity_masks=None,
|
||||
@@ -469,8 +596,12 @@ def model_fn_qwen_image(
|
||||
text = dit.txt_in(dit.txt_norm(prompt_emb))
|
||||
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
attention_mask = None
|
||||
|
||||
if blockwise_controlnet_conditioning is not None:
|
||||
blockwise_controlnet_conditioning = blockwise_controlnet.preprocess(
|
||||
blockwise_controlnet_inputs, blockwise_controlnet_conditioning)
|
||||
|
||||
for block in dit.transformer_blocks:
|
||||
for block_id, block in enumerate(dit.transformer_blocks):
|
||||
text, image = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
@@ -482,12 +613,15 @@ def model_fn_qwen_image(
|
||||
attention_mask=attention_mask,
|
||||
enable_fp8_attention=enable_fp8_attention,
|
||||
)
|
||||
|
||||
if accelerate_adapter is not None:
|
||||
image = accelerate_adapter(latents, image, text, image_rotary_emb, timestep)
|
||||
else:
|
||||
image = dit.norm_out(image, conditioning)
|
||||
image = dit.proj_out(image)
|
||||
if blockwise_controlnet_conditioning is not None:
|
||||
image = image + blockwise_controlnet.blockwise_forward(
|
||||
image=image, conditionings=blockwise_controlnet_conditioning,
|
||||
controlnet_inputs=blockwise_controlnet_inputs, block_id=block_id,
|
||||
progress_id=progress_id, num_inference_steps=num_inference_steps,
|
||||
)
|
||||
|
||||
image = dit.norm_out(image, conditioning)
|
||||
image = dit.proj_out(image)
|
||||
|
||||
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
return latents
|
||||
|
||||
@@ -1021,6 +1021,10 @@ def model_fn_wan_video(
|
||||
torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep
|
||||
]).flatten()
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0))
|
||||
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
||||
t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1)
|
||||
t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks]
|
||||
t = t_chunks[get_sequence_parallel_rank()]
|
||||
t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
|
||||
else:
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||
|
||||
@@ -31,13 +31,11 @@ class FlowMatchScheduler():
|
||||
self.set_timesteps(num_inference_steps)
|
||||
|
||||
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None, random_sigmas=False):
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None):
|
||||
if shift is not None:
|
||||
self.shift = shift
|
||||
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
||||
if random_sigmas:
|
||||
self.sigmas = torch.Tensor(sorted([torch.rand((1,)).item() * (sigma_start - self.sigma_min) for i in range(num_inference_steps - 1)] + [sigma_start], reverse=True))
|
||||
elif self.extra_one_step:
|
||||
if self.extra_one_step:
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
||||
else:
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
|
||||
|
||||
@@ -344,8 +344,17 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
|
||||
model = inject_adapter_in_model(lora_config, model)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
|
||||
def mapping_lora_state_dict(self, state_dict):
|
||||
new_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if "lora_A.weight" in key or "lora_B.weight" in key:
|
||||
new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
|
||||
new_state_dict[new_key] = value
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def export_trainable_state_dict(self, state_dict, remove_prefix=None):
|
||||
trainable_param_names = self.trainable_param_names()
|
||||
state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
|
||||
@@ -467,6 +476,7 @@ def wan_parser():
|
||||
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
|
||||
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
|
||||
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
|
||||
parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
|
||||
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||
@@ -475,6 +485,7 @@ def wan_parser():
|
||||
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
|
||||
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
|
||||
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
|
||||
return parser
|
||||
|
||||
|
||||
@@ -498,6 +509,7 @@ def flux_parser():
|
||||
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
|
||||
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
|
||||
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
|
||||
parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
|
||||
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||
parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.")
|
||||
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
|
||||
@@ -506,6 +518,7 @@ def flux_parser():
|
||||
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
|
||||
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
|
||||
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
|
||||
return parser
|
||||
|
||||
|
||||
@@ -530,12 +543,13 @@ def qwen_image_parser():
|
||||
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
|
||||
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
|
||||
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
|
||||
parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
|
||||
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||
parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.")
|
||||
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
|
||||
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
|
||||
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
|
||||
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
|
||||
return parser
|
||||
|
||||
@@ -139,6 +139,20 @@ class BasePipeline(torch.nn.Module):
|
||||
else:
|
||||
model.eval()
|
||||
model.requires_grad_(False)
|
||||
|
||||
|
||||
def blend_with_mask(self, base, addition, mask):
|
||||
return base * (1 - mask) + addition * mask
|
||||
|
||||
|
||||
def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
|
||||
timestep = scheduler.timesteps[progress_id]
|
||||
if inpaint_mask is not None:
|
||||
noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
|
||||
noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
|
||||
latents_next = scheduler.step(noise_pred, timestep, latents)
|
||||
return latents_next
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -255,6 +255,7 @@ The script includes the following parameters:
|
||||
* `--model_id_with_origin_paths`: Model ID with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas.
|
||||
* Training
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--weight_decay`: Weight decay.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--output_path`: Save path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
|
||||
@@ -265,6 +266,7 @@ The script includes the following parameters:
|
||||
* `--lora_base_model`: Which model to add LoRA to.
|
||||
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* `--lora_checkpoint`: Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.
|
||||
* Extra Model Inputs
|
||||
* `--extra_inputs`: Extra model inputs, separated by commas.
|
||||
* VRAM Management
|
||||
|
||||
@@ -255,6 +255,7 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 black-forest-labs/FLUX.1-dev:flux1-dev.safetensors。用逗号分隔。
|
||||
* 训练
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--weight_decay`:权重衰减大小。
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--output_path`: 保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
|
||||
@@ -265,6 +266,7 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
* `--lora_target_modules`: LoRA 添加到哪一层上。
|
||||
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||
* `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。
|
||||
* 额外模型输入
|
||||
* `--extra_inputs`: 额外的模型输入,以逗号分隔。
|
||||
* 显存管理
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import torch, os, json
|
||||
from diffsynth import load_state_dict
|
||||
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser
|
||||
from diffsynth.models.lora import FluxLoRAConverter
|
||||
@@ -11,7 +12,7 @@ class FluxTrainingModule(DiffusionTrainingModule):
|
||||
self,
|
||||
model_paths=None, model_id_with_origin_paths=None,
|
||||
trainable_models=None,
|
||||
lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32,
|
||||
lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32, lora_checkpoint=None,
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
extra_inputs=None,
|
||||
@@ -40,6 +41,12 @@ class FluxTrainingModule(DiffusionTrainingModule):
|
||||
target_modules=lora_target_modules.split(","),
|
||||
lora_rank=lora_rank
|
||||
)
|
||||
if lora_checkpoint is not None:
|
||||
state_dict = load_state_dict(lora_checkpoint)
|
||||
state_dict = self.mapping_lora_state_dict(state_dict)
|
||||
load_result = model.load_state_dict(state_dict, strict=False)
|
||||
if len(load_result[1]) > 0:
|
||||
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
||||
setattr(self.pipe, lora_base_model, model)
|
||||
|
||||
# Store other configs
|
||||
@@ -106,6 +113,7 @@ if __name__ == "__main__":
|
||||
lora_base_model=args.lora_base_model,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_checkpoint=args.lora_checkpoint,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
@@ -115,7 +123,7 @@ if __name__ == "__main__":
|
||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
|
||||
)
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||
launch_training_task(
|
||||
dataset, model, model_logger, optimizer, scheduler,
|
||||
|
||||
@@ -40,11 +40,15 @@ image.save("image.jpg")
|
||||
|
||||
## Model Overview
|
||||
|
||||
|Model ID|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|
||||
|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image )|[code](./model_inference/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_inference_low_vram/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./model_inference/Qwen-Image-Distill-LoRA.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|-|-|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
||||
|
||||
## Model Inference
|
||||
|
||||
@@ -174,10 +178,13 @@ After enabling VRAM management, the framework will automatically choose a memory
|
||||
<summary>Inference Acceleration</summary>
|
||||
|
||||
* FP8 Quantization: Choose the appropriate quantization method based on your hardware and requirements.
|
||||
* GPUs that do not support FP8 computation (e.g., A100, 4090, etc.): FP8 quantization will only reduce VRAM usage without speeding up inference. Code: [./model_inference_lor_vram/Qwen-Image.py](./model_inference_lor_vram/Qwen-Image.py)
|
||||
* GPUs that do not support FP8 computation (e.g., A100, 4090, etc.): FP8 quantization will only reduce VRAM usage without speeding up inference. Code: [./model_inference_low_vram/Qwen-Image.py](./model_inference_low_vram/Qwen-Image.py)
|
||||
* GPUs that support FP8 operations (e.g., H200, etc.): Please install [Flash Attention 3](https://github.com/Dao-AILab/flash-attention). Otherwise, FP8 acceleration will only apply to Linear layers.
|
||||
* Faster inference but higher VRAM usage: Use [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
|
||||
* Slightly slower inference but lower VRAM usage: Use [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
|
||||
* Distillation acceleration: We trained two distillation models for fast inference at `cfg_scale=1` and `num_inference_steps=15`.
|
||||
* [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full): Full distillation version. Better image quality but lower LoRA compatibility. Use [./model_inference/Qwen-Image-Distill-Full.py](./model_inference/Qwen-Image-Distill-Full.py).
|
||||
* [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA): LoRA distillation version. Slightly lower image quality but better LoRA compatibility. Use [./model_inference/Qwen-Image-Distill-LoRA.py](./model_inference/Qwen-Image-Distill-LoRA.py).
|
||||
|
||||
</details>
|
||||
|
||||
@@ -231,6 +238,7 @@ The script includes the following parameters:
|
||||
* `--tokenizer_path`: Tokenizer path. Leave empty to auto-download.
|
||||
* Training
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--weight_decay`: Weight decay.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--output_path`: Save path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
|
||||
@@ -241,14 +249,13 @@ The script includes the following parameters:
|
||||
* `--lora_base_model`: Which model to add LoRA to.
|
||||
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* `--lora_checkpoint`: Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.
|
||||
* Extra Model Inputs
|
||||
* `--extra_inputs`: Extra model inputs, separated by commas.
|
||||
* VRAM Management
|
||||
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
|
||||
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||
* Others
|
||||
* `--align_to_opensource_format`: Whether to align DiT LoRA format with open-source version. Only works for LoRA training.
|
||||
|
||||
In addition, the training framework is built on [`accelerate`](https://huggingface.co/docs/accelerate/index). Run `accelerate config` before training to set GPU-related settings. For some training tasks (e.g., full training of 20B model), we provide suggested `accelerate` config files. Check the corresponding training script for details.
|
||||
|
||||
|
||||
@@ -40,11 +40,15 @@ image.save("image.jpg")
|
||||
|
||||
## 模型总览
|
||||
|
||||
|模型 ID|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_inference_low_vram/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./model_inference/Qwen-Image-Distill-LoRA.py)|[code](./model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|-|-|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./model_inference/Qwen-Image-EliGen.py)|[code](./model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./model_training/lora/Qwen-Image-EliGen.sh)|[code](./model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
||||
|
||||
## 模型推理
|
||||
|
||||
@@ -174,10 +178,13 @@ FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在
|
||||
<summary>推理加速</summary>
|
||||
|
||||
* FP8 量化:根据您的硬件与需求,请选择合适的量化方式
|
||||
* GPU 不支持 FP8 计算(例如 A100、4090 等):FP8 量化仅能降低显存占用,无法加速,代码:[./model_inference_lor_vram/Qwen-Image.py](./model_inference_lor_vram/Qwen-Image.py)
|
||||
* GPU 不支持 FP8 计算(例如 A100、4090 等):FP8 量化仅能降低显存占用,无法加速,代码:[./model_inference_low_vram/Qwen-Image.py](./model_inference_low_vram/Qwen-Image.py)
|
||||
* GPU 支持 FP8 运算(例如 H200 等):请安装 [Flash Attention 3](https://github.com/Dao-AILab/flash-attention),否则 FP8 加速仅对 Linear 层生效
|
||||
* 更快的速度,但更大的显存:请使用 [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
|
||||
* 稍慢的速度,但更小的显存:请使用 [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
|
||||
* 蒸馏加速:我们训练了两个蒸馏加速模型,可以在 `cfg_scale=1` 和 `num_inference_steps=15` 设置下进行快速推理
|
||||
* [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full):全量蒸馏训练版本,更好的生成效果,稍差的 LoRA 兼容性,请使用 [./model_inference/Qwen-Image-Distill-Full.py](./model_inference/Qwen-Image-Distill-Full.py)
|
||||
* [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA):LoRA 蒸馏训练版本,稍差的生成效果,更好的 LoRA 兼容性,请使用 [./model_inference/Qwen-Image-Distill-LoRA.py](./model_inference/Qwen-Image-Distill-LoRA.py)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -231,6 +238,7 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod
|
||||
* `--tokenizer_path`: tokenizer 路径,留空将会自动下载。
|
||||
* 训练
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--weight_decay`:权重衰减大小。
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--output_path`: 保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
|
||||
@@ -241,14 +249,13 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
* `--lora_target_modules`: LoRA 添加到哪一层上。
|
||||
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||
* `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。
|
||||
* 额外模型输入
|
||||
* `--extra_inputs`: 额外的模型输入,以逗号分隔。
|
||||
* 显存管理
|
||||
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||
* 其他
|
||||
* `--align_to_opensource_format`: 是否将 DiT LoRA 的格式与开源版本对齐,仅对 LoRA 训练生效。
|
||||
|
||||
此外,训练框架基于 [`accelerate`](https://huggingface.co/docs/accelerate/index) 构建,在开始训练前运行 `accelerate config` 可配置 GPU 的相关参数。对于部分模型训练(例如 20B 模型的全量训练)脚本,我们提供了建议的 `accelerate` 配置文件,可在对应的训练脚本中查看。
|
||||
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
from PIL import Image
|
||||
import torch
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="canny/image_1.jpg"
|
||||
)
|
||||
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328))
|
||||
|
||||
prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。"
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,32 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
from PIL import Image
|
||||
import torch
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="depth/image_1.jpg"
|
||||
)
|
||||
|
||||
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1328, 1328))
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="inpaint/*.jpg"
|
||||
)
|
||||
prompt = "a cat with sunglasses"
|
||||
controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1328, 1328))
|
||||
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1328, 1328))
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
input_image=controlnet_image, inpaint_mask=inpaint_mask,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],
|
||||
num_inference_steps=40,
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -1,7 +1,8 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from modelscope import snapshot_download
|
||||
import torch
|
||||
|
||||
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-Distill-LoRA", local_dir="models/DiffSynth-Studio/Qwen-Image-Distill-LoRA")
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
@@ -9,10 +10,11 @@ pipe = QwenImagePipeline.from_pretrained(
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig("models/adapter.safetensors")
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-Distill-LoRA/model.safetensors")
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=4, cfg_scale=1)
|
||||
image.save("image.jpg")
|
||||
image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,32 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
from PIL import Image
|
||||
import torch
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", origin_file_pattern="model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="canny/image_1.jpg"
|
||||
)
|
||||
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328))
|
||||
|
||||
prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。"
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,33 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
from PIL import Image
|
||||
import torch
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="depth/image_1.jpg"
|
||||
)
|
||||
|
||||
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1328, 1328))
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="inpaint/*.jpg"
|
||||
)
|
||||
prompt = "a cat with sunglasses"
|
||||
controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1328, 1328))
|
||||
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1328, 1328))
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
input_image=controlnet_image, inpaint_mask=inpaint_mask,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],
|
||||
num_inference_steps=40,
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,22 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from modelscope import snapshot_download
|
||||
import torch
|
||||
|
||||
# Please do not use float8 on this model
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-Distill-LoRA", local_dir="models/DiffSynth-Studio/Qwen-Image-Distill-LoRA")
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-Distill-LoRA/model.safetensors")
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,129 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from modelscope import dataset_snapshot_download, snapshot_download
|
||||
import random
|
||||
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
|
||||
# Save or display the resulting image
|
||||
result.save(output_path)
|
||||
|
||||
return result
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/example_{example_id}/*.png")
|
||||
masks = [Image.open(f"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
negative_prompt = ""
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=4.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=30,
|
||||
seed=seed,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"eligen_example_{example_id}_{seed}.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen", allow_file_pattern="model.safetensors")
|
||||
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen/model.safetensors")
|
||||
|
||||
# example 1
|
||||
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
|
||||
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
|
||||
example(pipe, [0], 1, global_prompt, entity_prompts)
|
||||
|
||||
# example 2
|
||||
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
|
||||
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "yellow belt"]
|
||||
example(pipe, [0], 2, global_prompt, entity_prompts)
|
||||
|
||||
# example 3
|
||||
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
|
||||
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
|
||||
example(pipe, [27], 3, global_prompt, entity_prompts)
|
||||
|
||||
# example 4
|
||||
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
|
||||
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
|
||||
example(pipe, [21], 4, global_prompt, entity_prompts)
|
||||
|
||||
# example 5
|
||||
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
|
||||
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
|
||||
example(pipe, [0], 5, global_prompt, entity_prompts)
|
||||
|
||||
# example 7, same prompt with different seeds
|
||||
seeds = range(5, 9)
|
||||
global_prompt = "A beautiful asia woman wearing white dress, holding a mirror, with a forest background."
|
||||
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
|
||||
example(pipe, seeds, 7, global_prompt, entity_prompts)
|
||||
@@ -0,0 +1,38 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_canny.csv \
|
||||
--data_file_keys "image,blockwise_controlnet_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny:model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
|
||||
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Canny_full" \
|
||||
--trainable_models "blockwise_controlnet" \
|
||||
--extra_inputs "blockwise_controlnet_image" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
|
||||
# If you want to pre-train a Blockwise ControlNet from scratch,
|
||||
# please run the following script to first generate the initialized model weights file,
|
||||
# and then start training with a high learning rate (1e-3).
|
||||
|
||||
# python examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Initialize.py
|
||||
|
||||
# accelerate launch examples/qwen_image/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_canny.csv \
|
||||
# --data_file_keys "image,blockwise_controlnet_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --model_paths '["models/blockwise_controlnet.safetensors"]' \
|
||||
# --learning_rate 1e-3 \
|
||||
# --num_epochs 2 \
|
||||
# --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
|
||||
# --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Canny_full" \
|
||||
# --trainable_models "blockwise_controlnet" \
|
||||
# --extra_inputs "blockwise_controlnet_image" \
|
||||
# --use_gradient_checkpointing \
|
||||
# --find_unused_parameters
|
||||
@@ -0,0 +1,38 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_depth.csv \
|
||||
--data_file_keys "image,blockwise_controlnet_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth:model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
|
||||
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Depth_full" \
|
||||
--trainable_models "blockwise_controlnet" \
|
||||
--extra_inputs "blockwise_controlnet_image" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
|
||||
# If you want to pre-train a Blockwise ControlNet from scratch,
|
||||
# please run the following script to first generate the initialized model weights file,
|
||||
# and then start training with a high learning rate (1e-3).
|
||||
|
||||
# python examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Initialize.py
|
||||
|
||||
# accelerate launch examples/qwen_image/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_depth.csv \
|
||||
# --data_file_keys "image,blockwise_controlnet_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --model_paths '["models/blockwise_controlnet.safetensors"]' \
|
||||
# --learning_rate 1e-3 \
|
||||
# --num_epochs 2 \
|
||||
# --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
|
||||
# --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Depth_full" \
|
||||
# --trainable_models "blockwise_controlnet" \
|
||||
# --extra_inputs "blockwise_controlnet_image" \
|
||||
# --use_gradient_checkpointing \
|
||||
# --find_unused_parameters
|
||||
@@ -0,0 +1,38 @@
|
||||
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config.yaml examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_inpaint.csv \
|
||||
--data_file_keys "image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint:model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
|
||||
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full" \
|
||||
--trainable_models "blockwise_controlnet" \
|
||||
--extra_inputs "blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
|
||||
# If you want to pre-train a Inpaint Blockwise ControlNet from scratch,
|
||||
# please run the following script to first generate the initialized model weights file,
|
||||
# and then start training with a high learning rate (1e-3).
|
||||
|
||||
# python examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Inpaint-Initialize.py
|
||||
|
||||
# accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config.yaml examples/qwen_image/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_inpaint.csv \
|
||||
# --data_file_keys "image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --model_paths '["models/blockwise_controlnet_inpaint.safetensors"]' \
|
||||
# --learning_rate 1e-3 \
|
||||
# --num_epochs 2 \
|
||||
# --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
|
||||
# --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full" \
|
||||
# --trainable_models "blockwise_controlnet" \
|
||||
# --extra_inputs "blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
|
||||
# --use_gradient_checkpointing \
|
||||
# --find_unused_parameters
|
||||
@@ -9,4 +9,5 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Distill-Full_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
|
||||
@@ -9,4 +9,5 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 1
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -1,32 +0,0 @@
|
||||
# This script is for initializing a Qwen-Image-Accelerate-Adapter
|
||||
from diffsynth import load_state_dict, hash_state_dict_keys
|
||||
from diffsynth.pipelines.qwen_image import QwenImageAccelerateAdapter
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
state_dict_dit = {}
|
||||
for i in range(1, 10):
|
||||
state_dict_dit.update(load_state_dict(f"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-0000{i}-of-00009.safetensors", torch_dtype=torch.bfloat16, device="cuda"))
|
||||
|
||||
adapter = QwenImageAccelerateAdapter().to(dtype=torch.bfloat16, device="cuda")
|
||||
state_dict_adapter = adapter.state_dict()
|
||||
|
||||
state_dict_init = {}
|
||||
for k in state_dict_adapter:
|
||||
if k.startswith("transformer_blocks"):
|
||||
name = k.replace("transformer_blocks.0.", "transformer_blocks.59.")
|
||||
param = state_dict_dit[name]
|
||||
if "_mod." in k:
|
||||
param[2*3072: 3*3072] = 0
|
||||
param[5*3072: 6*3072] = 0
|
||||
state_dict_init[k] = param
|
||||
elif k in state_dict_dit:
|
||||
state_dict_init[k] = state_dict_dit[k]
|
||||
else:
|
||||
state_dict_init[k] = torch.zeros_like(state_dict_adapter[k])
|
||||
print("Zero initialized:", k)
|
||||
adapter.load_state_dict(state_dict_init)
|
||||
|
||||
print(hash_state_dict_keys(state_dict_init))
|
||||
save_file(state_dict_init, "models/adapter.safetensors")
|
||||
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_canny.csv \
|
||||
--data_file_keys "image,blockwise_controlnet_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny:model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Canny_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "blockwise_controlnet_image" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_depth.csv \
|
||||
--data_file_keys "image,blockwise_controlnet_image" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth:model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Depth_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "blockwise_controlnet_image" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
@@ -0,0 +1,17 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_inpaint.csv \
|
||||
--data_file_keys "image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint:model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
@@ -11,5 +11,5 @@ accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--align_to_opensource_format \
|
||||
--use_gradient_checkpointing
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
|
||||
@@ -12,7 +12,6 @@ accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--align_to_opensource_format \
|
||||
--extra_inputs "eligen_entity_masks,eligen_entity_prompts" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
|
||||
@@ -11,7 +11,6 @@ accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--align_to_opensource_format \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
# This script is for initializing a Qwen-Image-Blockwise-ControlNet
|
||||
from diffsynth import hash_state_dict_keys
|
||||
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
controlnet = QwenImageBlockWiseControlNet().to(dtype=torch.bfloat16, device="cuda")
|
||||
controlnet.init_weight()
|
||||
state_dict_controlnet = controlnet.state_dict()
|
||||
|
||||
print(hash_state_dict_keys(state_dict_controlnet))
|
||||
save_file(state_dict_controlnet, "models/blockwise_controlnet.safetensors")
|
||||
@@ -0,0 +1,12 @@
|
||||
# This script is for initializing a Inpaint Qwen-Image-ControlNet
|
||||
import torch
|
||||
from diffsynth import hash_state_dict_keys
|
||||
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||
from safetensors.torch import save_file
|
||||
|
||||
controlnet = QwenImageBlockWiseControlNet(additional_in_dim=4).to(dtype=torch.bfloat16, device="cuda")
|
||||
controlnet.init_weight()
|
||||
state_dict_controlnet = controlnet.state_dict()
|
||||
|
||||
print(hash_state_dict_keys(state_dict_controlnet))
|
||||
save_file(state_dict_controlnet, "models/blockwise_controlnet_inpaint.safetensors")
|
||||
@@ -1,7 +1,8 @@
|
||||
import torch, os, json
|
||||
from diffsynth import load_state_dict
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth.pipelines.flux_image_new import ControlNetInput
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
|
||||
from diffsynth.models.lora import QwenImageLoRAConverter
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@@ -12,7 +13,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
model_paths=None, model_id_with_origin_paths=None,
|
||||
tokenizer_path=None,
|
||||
trainable_models=None,
|
||||
lora_base_model=None, lora_target_modules="", lora_rank=32,
|
||||
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
extra_inputs=None,
|
||||
@@ -44,6 +45,12 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
target_modules=lora_target_modules.split(","),
|
||||
lora_rank=lora_rank
|
||||
)
|
||||
if lora_checkpoint is not None:
|
||||
state_dict = load_state_dict(lora_checkpoint)
|
||||
state_dict = self.mapping_lora_state_dict(state_dict)
|
||||
load_result = model.load_state_dict(state_dict, strict=False)
|
||||
if len(load_result[1]) > 0:
|
||||
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
||||
setattr(self.pipe, lora_base_model, model)
|
||||
|
||||
# Store other configs
|
||||
@@ -73,8 +80,18 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
}
|
||||
|
||||
# Extra inputs
|
||||
controlnet_input, blockwise_controlnet_input = {}, {}
|
||||
for extra_input in self.extra_inputs:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
if extra_input.startswith("blockwise_controlnet_"):
|
||||
blockwise_controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input]
|
||||
elif extra_input.startswith("controlnet_"):
|
||||
controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
|
||||
else:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
if len(controlnet_input) > 0:
|
||||
inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
|
||||
if len(blockwise_controlnet_input) > 0:
|
||||
inputs_shared["blockwise_controlnet_inputs"] = [ControlNetInput(**blockwise_controlnet_input)]
|
||||
|
||||
# Pipeline units will automatically process the input parameters.
|
||||
for unit in self.pipe.units:
|
||||
@@ -102,16 +119,13 @@ if __name__ == "__main__":
|
||||
lora_base_model=args.lora_base_model,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_checkpoint=args.lora_checkpoint,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
)
|
||||
model_logger = ModelLogger(
|
||||
args.output_path,
|
||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||
state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
|
||||
)
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
|
||||
model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||
launch_training_task(
|
||||
dataset, model, model_logger, optimizer, scheduler,
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
from PIL import Image
|
||||
import torch
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(path="models/train/Qwen-Image-Blockwise-ControlNet-Canny_full/epoch-1.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="canny/image_1.jpg"
|
||||
)
|
||||
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328))
|
||||
|
||||
prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。"
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,31 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
from PIL import Image
|
||||
import torch
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(path="models/train/Qwen-Image-Blockwise-ControlNet-Depth_full/epoch-1.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="depth/image_1.jpg"
|
||||
)
|
||||
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1328, 1328))
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(path="models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full/epoch-1.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="inpaint/*.jpg"
|
||||
)
|
||||
prompt = "a cat with sunglasses"
|
||||
controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024))
|
||||
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024))
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],
|
||||
height=1024, width=1024,
|
||||
num_inference_steps=40,
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,32 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
from PIL import Image
|
||||
import torch
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Blockwise-ControlNet-Canny_lora/epoch-4.safetensors")
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="canny/image_1.jpg"
|
||||
)
|
||||
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328))
|
||||
|
||||
prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。"
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,33 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
from PIL import Image
|
||||
import torch
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Blockwise-ControlNet-Depth_lora/epoch-4.safetensors")
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="depth/image_1.jpg"
|
||||
)
|
||||
|
||||
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1328, 1328))
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)]
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_lora/epoch-4.safetensors")
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||
local_dir="./data/example_image_dataset",
|
||||
allow_file_pattern="inpaint/*.jpg"
|
||||
)
|
||||
prompt = "a cat with sunglasses"
|
||||
controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024))
|
||||
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024))
|
||||
image = pipe(
|
||||
prompt, seed=0,
|
||||
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],
|
||||
height=1024, width=1024,
|
||||
num_inference_steps=40,
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -288,6 +288,7 @@ The script includes the following parameters:
|
||||
* `--min_timestep_boundary`: Minimum value of the timestep interval, ranging from 0 to 1. Default is 1. This needs to be manually set only when training mixed models with multiple DiTs, for example, [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B).
|
||||
* Training
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--weight_decay`: Weight decay.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--output_path`: Output save path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in ckpt.
|
||||
@@ -298,6 +299,7 @@ The script includes the following parameters:
|
||||
* `--lora_base_model`: Which model LoRA is added to.
|
||||
* `--lora_target_modules`: Which layers LoRA is added to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* `--lora_checkpoint`: Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.
|
||||
* Extra Inputs
|
||||
* `--extra_inputs`: Additional model inputs, comma-separated.
|
||||
* VRAM Management
|
||||
|
||||
@@ -290,6 +290,7 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
|
||||
* `--min_timestep_boundary`: Timestep 区间最小值,范围为 0~1,默认为 1,仅在多 DiT 的混合模型训练中需要手动设置,例如 [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)。
|
||||
* 训练
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--weight_decay`:权重衰减大小。
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--output_path`: 保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
|
||||
@@ -300,6 +301,7 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
* `--lora_target_modules`: LoRA 添加到哪一层上。
|
||||
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||
* `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。
|
||||
* 额外模型输入
|
||||
* `--extra_inputs`: 额外的模型输入,以逗号分隔。
|
||||
* 显存管理
|
||||
|
||||
@@ -28,5 +28,6 @@ video = pipe(
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
switch_DiT_boundary=0.9,
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|
||||
@@ -13,8 +13,9 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "input_image" \
|
||||
--use_gradient_checkpointing_offload \
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.875
|
||||
--max_timestep_boundary 0.358 \
|
||||
--min_timestep_boundary 0
|
||||
# boundary corresponds to timesteps [900, 1000]
|
||||
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
@@ -31,5 +32,6 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "input_image" \
|
||||
--use_gradient_checkpointing_offload \
|
||||
--max_timestep_boundary 0.875 \
|
||||
--min_timestep_boundary 0
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.358
|
||||
# boundary corresponds to timesteps [0, 900)
|
||||
@@ -11,8 +11,9 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-T2V-A14B_high_noise_full" \
|
||||
--trainable_models "dit" \
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.875
|
||||
--max_timestep_boundary 0.417 \
|
||||
--min_timestep_boundary 0
|
||||
# boundary corresponds to timesteps [875, 1000]
|
||||
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
@@ -27,5 +28,6 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.2-T2V-A14B_low_noise_full" \
|
||||
--trainable_models "dit" \
|
||||
--max_timestep_boundary 0.875 \
|
||||
--min_timestep_boundary 0
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.417
|
||||
# boundary corresponds to timesteps [0, 875)
|
||||
@@ -14,8 +14,9 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "input_image" \
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.875
|
||||
--max_timestep_boundary 0.358 \
|
||||
--min_timestep_boundary 0
|
||||
# boundary corresponds to timesteps [900, 1000]
|
||||
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
@@ -33,5 +34,6 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "input_image" \
|
||||
--max_timestep_boundary 0.875 \
|
||||
--min_timestep_boundary 0
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.358
|
||||
# boundary corresponds to timesteps [0, 900)
|
||||
@@ -13,8 +13,9 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.875
|
||||
--max_timestep_boundary 0.417 \
|
||||
--min_timestep_boundary 0
|
||||
# boundary corresponds to timesteps [875, 1000]
|
||||
|
||||
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
@@ -32,5 +33,6 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--max_timestep_boundary 0.875 \
|
||||
--min_timestep_boundary 0
|
||||
--max_timestep_boundary 1 \
|
||||
--min_timestep_boundary 0.417
|
||||
# boundary corresponds to timesteps [0, 875)
|
||||
@@ -1,4 +1,5 @@
|
||||
import torch, os, json
|
||||
from diffsynth import load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, ModelLogger, launch_training_task, wan_parser
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -10,7 +11,7 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
self,
|
||||
model_paths=None, model_id_with_origin_paths=None,
|
||||
trainable_models=None,
|
||||
lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32,
|
||||
lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32, lora_checkpoint=None,
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
extra_inputs=None,
|
||||
@@ -41,6 +42,12 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
target_modules=lora_target_modules.split(","),
|
||||
lora_rank=lora_rank
|
||||
)
|
||||
if lora_checkpoint is not None:
|
||||
state_dict = load_state_dict(lora_checkpoint)
|
||||
state_dict = self.mapping_lora_state_dict(state_dict)
|
||||
load_result = model.load_state_dict(state_dict, strict=False)
|
||||
if len(load_result[1]) > 0:
|
||||
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
||||
setattr(self.pipe, lora_base_model, model)
|
||||
|
||||
# Store other configs
|
||||
@@ -112,6 +119,7 @@ if __name__ == "__main__":
|
||||
lora_base_model=args.lora_base_model,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_checkpoint=args.lora_checkpoint,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
max_timestep_boundary=args.max_timestep_boundary,
|
||||
@@ -121,7 +129,7 @@ if __name__ == "__main__":
|
||||
args.output_path,
|
||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt
|
||||
)
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||
launch_training_task(
|
||||
dataset, model, model_logger, optimizer, scheduler,
|
||||
|
||||
119
test_interpolate.py
Normal file
119
test_interpolate.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, QwenImageUnit_PromptEmbedder, load_state_dict
|
||||
import torch, os
|
||||
from tqdm import tqdm
|
||||
from diffsynth.models.svd_unet import TemporalTimesteps
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
|
||||
class ValueEncoder(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32):
|
||||
super().__init__()
|
||||
self.value_emb = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.positional_emb = torch.nn.Parameter(torch.randn(1, value_emb_length, dim_out))
|
||||
self.proj_value = torch.nn.Linear(dim_in, dim_out)
|
||||
self.proj_out = torch.nn.Linear(dim_out, dim_out)
|
||||
self.value_emb_length = value_emb_length
|
||||
|
||||
def forward(self, value):
|
||||
value = value * 1
|
||||
emb = self.value_emb(value).to(value.dtype)
|
||||
emb = self.proj_value(emb)
|
||||
emb = repeat(emb, "b d -> b s d", s=self.value_emb_length)
|
||||
emb = emb + self.positional_emb.to(dtype=emb.dtype, device=emb.device)
|
||||
emb = torch.nn.functional.silu(emb)
|
||||
emb = self.proj_out(emb)
|
||||
return emb
|
||||
|
||||
|
||||
class TextInterpolationModel(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32, num_heads=32):
|
||||
super().__init__()
|
||||
self.to_q = ValueEncoder(dim_in=dim_in, dim_out=dim_out, value_emb_length=value_emb_length)
|
||||
self.xk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.yk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.xv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.yv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.to_k = torch.nn.Linear(dim_out, dim_out, bias=False)
|
||||
self.to_v = torch.nn.Linear(dim_out, dim_out, bias=False)
|
||||
self.to_out = torch.nn.Linear(dim_out, dim_out)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, value, x, y):
|
||||
q = self.to_q(value)
|
||||
k = self.to_k(torch.concat([x + self.xk_emb, y + self.yk_emb], dim=1))
|
||||
v = self.to_v(torch.concat([x + self.xv_emb, y + self.yv_emb], dim=1))
|
||||
q = rearrange(q, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
k = rearrange(k, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
v = rearrange(v, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
out = rearrange(out, 'b h s d -> b s (h d)')
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
unit = QwenImageUnit_PromptEmbedder()
|
||||
|
||||
dataset_prompt = [
|
||||
(
|
||||
"超级黑暗的画面,整体在黑暗中,暗无天日,暗淡无光,阴森黑暗,几乎全黑",
|
||||
"超级明亮的画面,爆闪,相机过曝,整个画面都是白色的眩光,几乎全是白色",
|
||||
),
|
||||
]
|
||||
dataset_tensors = []
|
||||
for prompt_x, prompt_y in tqdm(dataset_prompt):
|
||||
with torch.no_grad():
|
||||
x = unit.process(pipe, prompt_x)["prompt_emb"]
|
||||
y = unit.process(pipe, prompt_y)["prompt_emb"]
|
||||
dataset_tensors.append((x, y))
|
||||
|
||||
model = TextInterpolationModel().to(dtype=torch.bfloat16, device="cuda")
|
||||
model.load_state_dict(load_state_dict("models/interpolate.pth"))
|
||||
|
||||
def sample_tokens(emb, p):
|
||||
perm = torch.randperm(emb.shape[1])[:max(0, int(emb.shape[1]*p))]
|
||||
return emb[:, perm]
|
||||
|
||||
|
||||
def loss_fn(x, y):
|
||||
s, l = x.shape[1], y.shape[1]
|
||||
x = repeat(x, "b s d -> b s l d", l=l)
|
||||
y = repeat(y, "b l d -> b s l d", s=s)
|
||||
d = torch.square(x - y).mean(dim=-1)
|
||||
loss_x = d.min(dim=1).values.mean()
|
||||
loss_y = d.min(dim=2).values.mean()
|
||||
return loss_x + loss_y
|
||||
|
||||
|
||||
def get_target(x, y, p):
|
||||
x = sample_tokens(x, 1-p)
|
||||
y = sample_tokens(y, p)
|
||||
return torch.concat([x, y], dim=1)
|
||||
|
||||
name = "brightness"
|
||||
for i in range(6):
|
||||
v = i/5
|
||||
with torch.no_grad():
|
||||
data_id = 0
|
||||
x, y = dataset_tensors[data_id]
|
||||
x, y = x.to("cuda"), y.to("cuda")
|
||||
value = torch.tensor([v], dtype=torch.bfloat16, device="cuda")
|
||||
value_emb = model(value, x, y)
|
||||
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40, extra_prompt_emb=value_emb)
|
||||
os.makedirs(f"data/qwen_image_value/{name}", exist_ok=True)
|
||||
image.save(f"data/qwen_image_value/{name}/image_{v}.jpg")
|
||||
121
train_interpolate.py
Normal file
121
train_interpolate.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, QwenImageUnit_PromptEmbedder
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from diffsynth.models.svd_unet import TemporalTimesteps
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
|
||||
class ValueEncoder(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32):
|
||||
super().__init__()
|
||||
self.value_emb = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.positional_emb = torch.nn.Parameter(torch.randn(1, value_emb_length, dim_out))
|
||||
self.proj_value = torch.nn.Linear(dim_in, dim_out)
|
||||
self.proj_out = torch.nn.Linear(dim_out, dim_out)
|
||||
self.value_emb_length = value_emb_length
|
||||
|
||||
def forward(self, value):
|
||||
value = value * 1
|
||||
emb = self.value_emb(value).to(value.dtype)
|
||||
emb = self.proj_value(emb)
|
||||
emb = repeat(emb, "b d -> b s d", s=self.value_emb_length)
|
||||
emb = emb + self.positional_emb.to(dtype=emb.dtype, device=emb.device)
|
||||
emb = torch.nn.functional.silu(emb)
|
||||
emb = self.proj_out(emb)
|
||||
return emb
|
||||
|
||||
|
||||
class TextInterpolationModel(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=3584, value_emb_length=32, num_heads=32):
|
||||
super().__init__()
|
||||
self.to_q = ValueEncoder(dim_in=dim_in, dim_out=dim_out, value_emb_length=value_emb_length)
|
||||
self.xk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.yk_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.xv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.yv_emb = torch.nn.Parameter(torch.randn(1, 1, dim_out))
|
||||
self.to_k = torch.nn.Linear(dim_out, dim_out, bias=False)
|
||||
self.to_v = torch.nn.Linear(dim_out, dim_out, bias=False)
|
||||
self.to_out = torch.nn.Linear(dim_out, dim_out)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, value, x, y):
|
||||
q = self.to_q(value)
|
||||
k = self.to_k(torch.concat([x + self.xk_emb, y + self.yk_emb], dim=1))
|
||||
v = self.to_v(torch.concat([x + self.xv_emb, y + self.yv_emb], dim=1))
|
||||
q = rearrange(q, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
k = rearrange(k, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
v = rearrange(v, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
out = rearrange(out, 'b h s d -> b s (h d)')
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
|
||||
def sample_tokens(emb, p):
|
||||
perm = torch.randperm(emb.shape[1])[:max(0, int(emb.shape[1]*p))]
|
||||
return emb[:, perm]
|
||||
|
||||
|
||||
def loss_fn(x, y):
|
||||
s, l = x.shape[1], y.shape[1]
|
||||
x = repeat(x, "b s d -> b s l d", l=l)
|
||||
y = repeat(y, "b l d -> b s l d", s=s)
|
||||
d = torch.square(x - y).mean(dim=-1)
|
||||
loss_x = d.min(dim=1).values.mean()
|
||||
loss_y = d.min(dim=2).values.mean()
|
||||
return loss_x + loss_y
|
||||
|
||||
|
||||
def get_target(x, y, p):
|
||||
x = sample_tokens(x, 1-p)
|
||||
y = sample_tokens(y, p)
|
||||
return torch.concat([x, y], dim=1)
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
unit = QwenImageUnit_PromptEmbedder()
|
||||
|
||||
|
||||
dataset_prompt = [
|
||||
(
|
||||
"超级黑暗的画面,整体在黑暗中,暗无天日,暗淡无光,阴森黑暗,几乎全黑",
|
||||
"超级明亮的画面,爆闪,相机过曝,整个画面都是白色的眩光,几乎全是白色",
|
||||
),
|
||||
]
|
||||
|
||||
dataset_tensors = []
|
||||
for prompt_x, prompt_y in tqdm(dataset_prompt):
|
||||
with torch.no_grad():
|
||||
x = unit.process(pipe, prompt_x)["prompt_emb"].to(dtype=torch.float32, device="cpu")
|
||||
y = unit.process(pipe, prompt_y)["prompt_emb"].to(dtype=torch.float32, device="cpu")
|
||||
dataset_tensors.append((x, y))
|
||||
|
||||
model = TextInterpolationModel().to(dtype=torch.float32, device="cuda")
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
|
||||
|
||||
for step_id, step in enumerate(tqdm(range(100000))):
|
||||
optimizer.zero_grad()
|
||||
|
||||
data_id = torch.randint(0, len(dataset_tensors), size=(1,)).item()
|
||||
x, y = dataset_tensors[data_id]
|
||||
x, y = x.to("cuda"), y.to("cuda")
|
||||
|
||||
value = torch.rand((1,), dtype=torch.float32, device="cuda")
|
||||
out = model(value, x, y)
|
||||
loss = loss_fn(out, x) * (1 - value) + loss_fn(out, y) * value
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if (step_id + 1) % 1000 == 0:
|
||||
print(loss)
|
||||
|
||||
torch.save(model.state_dict(), f"models/interpolate_{step+1}.pth")
|
||||
Reference in New Issue
Block a user