mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 07:18:14 +00:00
Compare commits
60 Commits
z-image-om
...
diffsynth-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e443e2032f | ||
|
|
ec93f55855 | ||
|
|
0e03797fd1 | ||
|
|
41e2b47e1d | ||
|
|
17600eda0f | ||
|
|
6fbd9e94ec | ||
|
|
b323873bf0 | ||
|
|
7747f38561 | ||
|
|
4a15618080 | ||
|
|
9ecb9d8fe7 | ||
|
|
5c37fdcd8f | ||
|
|
d5a0aab2b2 | ||
|
|
3d4c92ef35 | ||
|
|
f7c2d54ebd | ||
|
|
92a742e0df | ||
|
|
81bcb39e82 | ||
|
|
4a80e9c179 | ||
|
|
5065c9ef6a | ||
|
|
ea1980ec4f | ||
|
|
2379387df2 | ||
|
|
62c94a9927 | ||
|
|
9048d2e9d4 | ||
|
|
20cf2317e0 | ||
|
|
b106458eac | ||
|
|
675ae5e91f | ||
|
|
1a6fd69e6b | ||
|
|
0b72c2b3ba | ||
|
|
fb892bd860 | ||
|
|
a112fb2e10 | ||
|
|
0b527c460f | ||
|
|
0eead33ed7 | ||
|
|
0336551544 | ||
|
|
0b7dd55ff3 | ||
|
|
96daa30bcc | ||
|
|
eeb55a0ce6 | ||
|
|
6ad8d73717 | ||
|
|
453ca89046 | ||
|
|
c119ce7e64 | ||
|
|
ff35fa56c2 | ||
|
|
cc85388d79 | ||
|
|
82378a2815 | ||
|
|
f85af085df | ||
|
|
2d23c897c2 | ||
|
|
3f9e9cad9d | ||
|
|
7b756a518e | ||
|
|
416ff5df74 | ||
|
|
47246060d6 | ||
|
|
ea0a5c5908 | ||
|
|
e3356556ee | ||
|
|
5be5c32fe4 | ||
|
|
cb70126c88 | ||
|
|
5e95a85281 | ||
|
|
eacec13309 | ||
|
|
ceb473efc0 | ||
|
|
bdedd46d4c | ||
|
|
74f8181f93 | ||
|
|
6a6eca7baf | ||
|
|
3afecc65fc | ||
|
|
d27917ad41 | ||
|
|
288fb7604c |
22
README.md
22
README.md
@@ -33,8 +33,6 @@ We believe that a well-developed open-source code framework can lower the thresh
|
|||||||
|
|
||||||
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
|
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
|
||||||
|
|
||||||
- **December 9, 2025** We release a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in terms of generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research. For more details, please refer to our [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l).
|
|
||||||
|
|
||||||
- **December 4, 2025** DiffSynth-Studio 2.0 released! Many new features online
|
- **December 4, 2025** DiffSynth-Studio 2.0 released! Many new features online
|
||||||
- [Documentation](/docs/en/README.md) online: Our documentation is still continuously being optimized and updated
|
- [Documentation](/docs/en/README.md) online: Our documentation is still continuously being optimized and updated
|
||||||
- [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md) module upgraded, supporting layer-level disk offload, releasing both memory and VRAM simultaneously
|
- [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md) module upgraded, supporting layer-level disk offload, releasing both memory and VRAM simultaneously
|
||||||
@@ -189,7 +187,21 @@ cd DiffSynth-Studio
|
|||||||
pip install -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
For more installation methods and instructions for non-NVIDIA GPUs, please refer to the [Installation Guide](/docs/en/Pipeline_Usage/Setup.md).
|
<details>
|
||||||
|
<summary>Other installation methods</summary>
|
||||||
|
|
||||||
|
Install from PyPI (version updates may be delayed; for latest features, install from source)
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install diffsynth
|
||||||
|
```
|
||||||
|
|
||||||
|
If you meet problems during installation, they might be caused by upstream dependencies. Please check the docs of these packages:
|
||||||
|
|
||||||
|
* [torch](https://pytorch.org/get-started/locally/)
|
||||||
|
* [sentencepiece](https://github.com/google/sentencepiece)
|
||||||
|
* [cmake](https://cmake.org)
|
||||||
|
* [cupy](https://docs.cupy.dev/en/stable/install.html)
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@@ -396,11 +408,8 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q
|
|||||||
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||||
|-|-|-|-|-|-|-|
|
|-|-|-|-|-|-|-|
|
||||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
||||||
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|
|
||||||
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||||
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
||||||
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
|
||||||
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
|
||||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
||||||
@@ -411,7 +420,6 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q
|
|||||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|
||||||
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
|
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|||||||
22
README_zh.md
22
README_zh.md
@@ -33,8 +33,6 @@ DiffSynth 目前包括两个开源项目:
|
|||||||
|
|
||||||
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
|
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
|
||||||
|
|
||||||
- **2025年12月9日** 我们基于 DiffSynth-Studio 2.0 训练了一个疯狂的模型:[Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)(Image to LoRA)。这一模型以图像为输入,以 LoRA 为输出。尽管这个版本的模型在泛化能力、细节保持能力等方面还有很大改进空间,我们将这些模型开源,以启发更多创新性的研究工作。更多细节,请参考我们的 [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l)。
|
|
||||||
|
|
||||||
- **2025年12月4日** DiffSynth-Studio 2.0 发布!众多新功能上线
|
- **2025年12月4日** DiffSynth-Studio 2.0 发布!众多新功能上线
|
||||||
- [文档](/docs/zh/README.md)上线:我们的文档还在持续优化更新中
|
- [文档](/docs/zh/README.md)上线:我们的文档还在持续优化更新中
|
||||||
- [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)模块升级,支持 Layer 级别的 Disk Offload,同时释放内存与显存
|
- [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)模块升级,支持 Layer 级别的 Disk Offload,同时释放内存与显存
|
||||||
@@ -189,7 +187,21 @@ cd DiffSynth-Studio
|
|||||||
pip install -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
更多安装方式,以及非 NVIDIA GPU 的安装,请参考[安装文档](/docs/zh/Pipeline_Usage/Setup.md)。
|
<details>
|
||||||
|
<summary>其他安装方式</summary>
|
||||||
|
|
||||||
|
从 pypi 安装(存在版本更新延迟,如需使用最新功能,请从源码安装)
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install diffsynth
|
||||||
|
```
|
||||||
|
|
||||||
|
如果在安装过程中遇到问题,可能是由上游依赖包导致的,请参考这些包的文档:
|
||||||
|
|
||||||
|
* [torch](https://pytorch.org/get-started/locally/)
|
||||||
|
* [sentencepiece](https://github.com/google/sentencepiece)
|
||||||
|
* [cmake](https://cmake.org)
|
||||||
|
* [cupy](https://docs.cupy.dev/en/stable/install.html)
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@@ -396,11 +408,8 @@ Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/
|
|||||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|-|-|-|-|-|-|-|
|
|-|-|-|-|-|-|-|
|
||||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
||||||
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|
|
||||||
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||||
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
||||||
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
|
||||||
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
|
||||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
||||||
@@ -411,7 +420,6 @@ Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/
|
|||||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|
||||||
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
|
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|||||||
@@ -31,52 +31,6 @@ qwen_image_series = [
|
|||||||
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
|
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
|
||||||
"extra_kwargs": {"additional_in_dim": 4},
|
"extra_kwargs": {"additional_in_dim": 4},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
# Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
|
|
||||||
"model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8",
|
|
||||||
"model_name": "siglip2_image_encoder",
|
|
||||||
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
# Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors")
|
|
||||||
"model_hash": "5722b5c873720009de96422993b15682",
|
|
||||||
"model_name": "dinov3_image_encoder",
|
|
||||||
"model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
# Example:
|
|
||||||
"model_hash": "a166c33455cdbd89c0888a3645ca5c0f",
|
|
||||||
"model_name": "qwen_image_image2lora_coarse",
|
|
||||||
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
# Example:
|
|
||||||
"model_hash": "a5476e691767a4da6d3a6634a10f7408",
|
|
||||||
"model_name": "qwen_image_image2lora_fine",
|
|
||||||
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
|
|
||||||
"extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
# Example:
|
|
||||||
"model_hash": "0aad514690602ecaff932c701cb4b0bb",
|
|
||||||
"model_name": "qwen_image_image2lora_style",
|
|
||||||
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
|
|
||||||
"extra_kwargs": {"compress_dim": 64, "use_residual": False}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
|
||||||
"model_hash": "8dc8cda05de16c73afa755e2c1ce2839",
|
|
||||||
"model_name": "qwen_image_dit",
|
|
||||||
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
|
|
||||||
"extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
|
||||||
"model_hash": "44b39ddc499e027cfb24f7878d7416b9",
|
|
||||||
"model_name": "qwen_image_vae",
|
|
||||||
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
|
|
||||||
"extra_kwargs": {"image_channels": 4}
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
wan_series = [
|
wan_series = [
|
||||||
@@ -527,32 +481,6 @@ z_image_series = [
|
|||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
|
||||||
"extra_kwargs": {"use_conv_attention": False},
|
"extra_kwargs": {"use_conv_attention": False},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors")
|
|
||||||
"model_hash": "aa3563718e5c3ecde3dfbb020ca61180",
|
|
||||||
"model_name": "z_image_dit",
|
|
||||||
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
|
|
||||||
"extra_kwargs": {"siglip_feat_dim": 1152},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors")
|
|
||||||
"model_hash": "89d48e420f45cff95115a9f3e698d44a",
|
|
||||||
"model_name": "siglip_vision_model_428m",
|
|
||||||
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
# Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
|
|
||||||
"model_hash": "1677708d40029ab380a95f6c731a57d7",
|
|
||||||
"model_name": "z_image_controlnet",
|
|
||||||
"model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
# Example: ???
|
|
||||||
"model_hash": "9510cb8cd1dd34ee0e4f111c24905510",
|
|
||||||
"model_name": "z_image_image2lora_style",
|
|
||||||
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
|
|
||||||
"extra_kwargs": {"compress_dim": 128},
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
|
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
|||||||
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
|
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
|
||||||
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
},
|
},
|
||||||
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
|
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
|
||||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
@@ -33,25 +32,6 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
|||||||
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
},
|
},
|
||||||
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": {
|
|
||||||
"transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
"transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
|
||||||
},
|
|
||||||
"diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": {
|
|
||||||
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
|
||||||
},
|
|
||||||
"diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": {
|
|
||||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
|
||||||
},
|
|
||||||
"diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": {
|
"diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": {
|
||||||
"diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
|
"diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
"diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule",
|
"diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
@@ -195,19 +175,4 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
|||||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
},
|
},
|
||||||
"diffsynth.models.z_image_controlnet.ZImageControlNet": {
|
|
||||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
|
||||||
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
},
|
|
||||||
"diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": {
|
|
||||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
|
||||||
},
|
|
||||||
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": {
|
|
||||||
"transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
"transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
|
||||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,4 +3,3 @@ from .data import *
|
|||||||
from .gradient import *
|
from .gradient import *
|
||||||
from .loader import *
|
from .loader import *
|
||||||
from .vram import *
|
from .vram import *
|
||||||
from .device import *
|
|
||||||
|
|||||||
@@ -53,14 +53,12 @@ class ToStr(DataProcessingOperator):
|
|||||||
|
|
||||||
|
|
||||||
class LoadImage(DataProcessingOperator):
|
class LoadImage(DataProcessingOperator):
|
||||||
def __init__(self, convert_RGB=True, convert_RGBA=False):
|
def __init__(self, convert_RGB=True):
|
||||||
self.convert_RGB = convert_RGB
|
self.convert_RGB = convert_RGB
|
||||||
self.convert_RGBA = convert_RGBA
|
|
||||||
|
|
||||||
def __call__(self, data: str):
|
def __call__(self, data: str):
|
||||||
image = Image.open(data)
|
image = Image.open(data)
|
||||||
if self.convert_RGB: image = image.convert("RGB")
|
if self.convert_RGB: image = image.convert("RGB")
|
||||||
if self.convert_RGBA: image = image.convert("RGBA")
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type
|
|
||||||
@@ -1,107 +0,0 @@
|
|||||||
import importlib
|
|
||||||
import torch
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
def is_torch_npu_available():
|
|
||||||
return importlib.util.find_spec("torch_npu") is not None
|
|
||||||
|
|
||||||
|
|
||||||
IS_CUDA_AVAILABLE = torch.cuda.is_available()
|
|
||||||
IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available()
|
|
||||||
|
|
||||||
if IS_NPU_AVAILABLE:
|
|
||||||
import torch_npu
|
|
||||||
|
|
||||||
torch.npu.config.allow_internal_format = False
|
|
||||||
|
|
||||||
|
|
||||||
def get_device_type() -> str:
|
|
||||||
"""Get device type based on current machine, currently only support CPU, CUDA, NPU."""
|
|
||||||
if IS_CUDA_AVAILABLE:
|
|
||||||
device = "cuda"
|
|
||||||
elif IS_NPU_AVAILABLE:
|
|
||||||
device = "npu"
|
|
||||||
else:
|
|
||||||
device = "cpu"
|
|
||||||
|
|
||||||
return device
|
|
||||||
|
|
||||||
|
|
||||||
def get_torch_device() -> Any:
|
|
||||||
"""Get torch attribute based on device type, e.g. torch.cuda or torch.npu"""
|
|
||||||
device_name = get_device_type()
|
|
||||||
|
|
||||||
try:
|
|
||||||
return getattr(torch, device_name)
|
|
||||||
except AttributeError:
|
|
||||||
print(f"Device namespace '{device_name}' not found in torch, try to load 'torch.cuda'.")
|
|
||||||
return torch.cuda
|
|
||||||
|
|
||||||
|
|
||||||
def get_device_id() -> int:
|
|
||||||
"""Get current device id based on device type."""
|
|
||||||
return get_torch_device().current_device()
|
|
||||||
|
|
||||||
|
|
||||||
def get_device_name() -> str:
|
|
||||||
"""Get current device name based on device type."""
|
|
||||||
return f"{get_device_type()}:{get_device_id()}"
|
|
||||||
|
|
||||||
|
|
||||||
def synchronize() -> None:
|
|
||||||
"""Execute torch synchronize operation."""
|
|
||||||
get_torch_device().synchronize()
|
|
||||||
|
|
||||||
|
|
||||||
def empty_cache() -> None:
|
|
||||||
"""Execute torch empty cache operation."""
|
|
||||||
get_torch_device().empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
def get_nccl_backend() -> str:
|
|
||||||
"""Return distributed communication backend type based on device type."""
|
|
||||||
if IS_CUDA_AVAILABLE:
|
|
||||||
return "nccl"
|
|
||||||
elif IS_NPU_AVAILABLE:
|
|
||||||
return "hccl"
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")
|
|
||||||
|
|
||||||
|
|
||||||
def enable_high_precision_for_bf16():
|
|
||||||
"""
|
|
||||||
Set high accumulation dtype for matmul and reduction.
|
|
||||||
"""
|
|
||||||
if IS_CUDA_AVAILABLE:
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = False
|
|
||||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
|
|
||||||
|
|
||||||
if IS_NPU_AVAILABLE:
|
|
||||||
torch.npu.matmul.allow_tf32 = False
|
|
||||||
torch.npu.matmul.allow_bf16_reduced_precision_reduction = False
|
|
||||||
|
|
||||||
|
|
||||||
def parse_device_type(device):
|
|
||||||
if isinstance(device, str):
|
|
||||||
if device.startswith("cuda"):
|
|
||||||
return "cuda"
|
|
||||||
elif device.startswith("npu"):
|
|
||||||
return "npu"
|
|
||||||
else:
|
|
||||||
return "cpu"
|
|
||||||
elif isinstance(device, torch.device):
|
|
||||||
return device.type
|
|
||||||
|
|
||||||
|
|
||||||
def parse_nccl_backend(device_type):
|
|
||||||
if device_type == "cuda":
|
|
||||||
return "nccl"
|
|
||||||
elif device_type == "npu":
|
|
||||||
return "hccl"
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.")
|
|
||||||
|
|
||||||
|
|
||||||
def get_available_device_type():
|
|
||||||
return get_device_type()
|
|
||||||
@@ -97,7 +97,6 @@ class ModelConfig:
|
|||||||
self.reset_local_model_path()
|
self.reset_local_model_path()
|
||||||
if self.require_downloading():
|
if self.require_downloading():
|
||||||
self.download()
|
self.download()
|
||||||
if self.path is None:
|
|
||||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||||
self.path = os.path.join(self.local_model_path, self.model_id)
|
self.path = os.path.join(self.local_model_path, self.model_id)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import torch, copy
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
from .initialization import skip_model_initialization
|
from .initialization import skip_model_initialization
|
||||||
from .disk_map import DiskMap
|
from .disk_map import DiskMap
|
||||||
from ..device import parse_device_type
|
|
||||||
|
|
||||||
|
|
||||||
class AutoTorchModule(torch.nn.Module):
|
class AutoTorchModule(torch.nn.Module):
|
||||||
@@ -33,7 +32,6 @@ class AutoTorchModule(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
self.state = 0
|
self.state = 0
|
||||||
self.name = ""
|
self.name = ""
|
||||||
self.computation_device_type = parse_device_type(self.computation_device)
|
|
||||||
|
|
||||||
def set_dtype_and_device(
|
def set_dtype_and_device(
|
||||||
self,
|
self,
|
||||||
@@ -63,8 +61,7 @@ class AutoTorchModule(torch.nn.Module):
|
|||||||
return r
|
return r
|
||||||
|
|
||||||
def check_free_vram(self):
|
def check_free_vram(self):
|
||||||
device = self.computation_device if self.computation_device != "npu" else "npu:0"
|
gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
|
||||||
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
|
|
||||||
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
|
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
|
||||||
return used_memory < self.vram_limit
|
return used_memory < self.vram_limit
|
||||||
|
|
||||||
@@ -310,7 +307,6 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
|||||||
self.lora_B_weights = []
|
self.lora_B_weights = []
|
||||||
self.lora_merger = None
|
self.lora_merger = None
|
||||||
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
|
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
|
||||||
self.computation_device_type = parse_device_type(self.computation_device)
|
|
||||||
|
|
||||||
if offload_dtype == "disk":
|
if offload_dtype == "disk":
|
||||||
self.disk_map = disk_map
|
self.disk_map = disk_map
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from einops import repeat, reduce
|
from einops import repeat, reduce
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
|
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig
|
||||||
from ..utils.lora import GeneralLoRALoader
|
from ..utils.lora import GeneralLoRALoader
|
||||||
from ..models.model_loader import ModelPool
|
from ..models.model_loader import ModelPool
|
||||||
from ..utils.controlnet import ControlNetInput
|
from ..utils.controlnet import ControlNetInput
|
||||||
@@ -68,7 +68,6 @@ class BasePipeline(torch.nn.Module):
|
|||||||
# The device and torch_dtype is used for the storage of intermediate variables, not models.
|
# The device and torch_dtype is used for the storage of intermediate variables, not models.
|
||||||
self.device = device
|
self.device = device
|
||||||
self.torch_dtype = torch_dtype
|
self.torch_dtype = torch_dtype
|
||||||
self.device_type = parse_device_type(device)
|
|
||||||
# The following parameters are used for shape check.
|
# The following parameters are used for shape check.
|
||||||
self.height_division_factor = height_division_factor
|
self.height_division_factor = height_division_factor
|
||||||
self.width_division_factor = width_division_factor
|
self.width_division_factor = width_division_factor
|
||||||
@@ -155,7 +154,7 @@ class BasePipeline(torch.nn.Module):
|
|||||||
for module in model.modules():
|
for module in model.modules():
|
||||||
if hasattr(module, "offload"):
|
if hasattr(module, "offload"):
|
||||||
module.offload()
|
module.offload()
|
||||||
getattr(torch, self.device_type).empty_cache()
|
torch.cuda.empty_cache()
|
||||||
# onload models
|
# onload models
|
||||||
for name, model in self.named_children():
|
for name, model in self.named_children():
|
||||||
if name in model_names:
|
if name in model_names:
|
||||||
@@ -177,8 +176,7 @@ class BasePipeline(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def get_vram(self):
|
def get_vram(self):
|
||||||
device = self.device if self.device != "npu" else "npu:0"
|
return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
|
||||||
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
|
|
||||||
|
|
||||||
def get_module(self, model, name):
|
def get_module(self, model, name):
|
||||||
if "." in name:
|
if "." in name:
|
||||||
@@ -235,7 +233,6 @@ class BasePipeline(torch.nn.Module):
|
|||||||
alpha=1,
|
alpha=1,
|
||||||
hotload=None,
|
hotload=None,
|
||||||
state_dict=None,
|
state_dict=None,
|
||||||
verbose=1,
|
|
||||||
):
|
):
|
||||||
if state_dict is None:
|
if state_dict is None:
|
||||||
if isinstance(lora_config, str):
|
if isinstance(lora_config, str):
|
||||||
@@ -262,13 +259,12 @@ class BasePipeline(torch.nn.Module):
|
|||||||
updated_num += 1
|
updated_num += 1
|
||||||
module.lora_A_weights.append(lora[lora_a_name] * alpha)
|
module.lora_A_weights.append(lora[lora_a_name] * alpha)
|
||||||
module.lora_B_weights.append(lora[lora_b_name])
|
module.lora_B_weights.append(lora[lora_b_name])
|
||||||
if verbose >= 1:
|
print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
|
||||||
print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
|
|
||||||
else:
|
else:
|
||||||
lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)
|
lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)
|
||||||
|
|
||||||
|
|
||||||
def clear_lora(self, verbose=1):
|
def clear_lora(self):
|
||||||
cleared_num = 0
|
cleared_num = 0
|
||||||
for name, module in self.named_modules():
|
for name, module in self.named_modules():
|
||||||
if isinstance(module, AutoWrappedLinear):
|
if isinstance(module, AutoWrappedLinear):
|
||||||
@@ -278,8 +274,7 @@ class BasePipeline(torch.nn.Module):
|
|||||||
module.lora_A_weights.clear()
|
module.lora_A_weights.clear()
|
||||||
if hasattr(module, "lora_B_weights"):
|
if hasattr(module, "lora_B_weights"):
|
||||||
module.lora_B_weights.clear()
|
module.lora_B_weights.clear()
|
||||||
if verbose >= 1:
|
print(f"{cleared_num} LoRA layers are cleared.")
|
||||||
print(f"{cleared_num} LoRA layers are cleared.")
|
|
||||||
|
|
||||||
|
|
||||||
def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
|
def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
|
||||||
@@ -307,13 +302,8 @@ class BasePipeline(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
|
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
|
||||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
|
||||||
self.clear_lora(verbose=0)
|
|
||||||
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
|
|
||||||
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
|
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
|
||||||
self.clear_lora(verbose=0)
|
|
||||||
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
|
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,94 +0,0 @@
|
|||||||
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
|
|
||||||
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class DINOv3ImageEncoder(DINOv3ViTModel):
|
|
||||||
def __init__(self):
|
|
||||||
config = DINOv3ViTConfig(
|
|
||||||
architectures = [
|
|
||||||
"DINOv3ViTModel"
|
|
||||||
],
|
|
||||||
attention_dropout = 0.0,
|
|
||||||
drop_path_rate = 0.0,
|
|
||||||
dtype = "float32",
|
|
||||||
hidden_act = "silu",
|
|
||||||
hidden_size = 4096,
|
|
||||||
image_size = 224,
|
|
||||||
initializer_range = 0.02,
|
|
||||||
intermediate_size = 8192,
|
|
||||||
key_bias = False,
|
|
||||||
layer_norm_eps = 1e-05,
|
|
||||||
layerscale_value = 1.0,
|
|
||||||
mlp_bias = True,
|
|
||||||
model_type = "dinov3_vit",
|
|
||||||
num_attention_heads = 32,
|
|
||||||
num_channels = 3,
|
|
||||||
num_hidden_layers = 40,
|
|
||||||
num_register_tokens = 4,
|
|
||||||
patch_size = 16,
|
|
||||||
pos_embed_jitter = None,
|
|
||||||
pos_embed_rescale = 2.0,
|
|
||||||
pos_embed_shift = None,
|
|
||||||
proj_bias = True,
|
|
||||||
query_bias = False,
|
|
||||||
rope_theta = 100.0,
|
|
||||||
transformers_version = "4.56.1",
|
|
||||||
use_gated_mlp = True,
|
|
||||||
value_bias = False
|
|
||||||
)
|
|
||||||
super().__init__(config)
|
|
||||||
self.processor = DINOv3ViTImageProcessorFast(
|
|
||||||
crop_size = None,
|
|
||||||
data_format = "channels_first",
|
|
||||||
default_to_square = True,
|
|
||||||
device = None,
|
|
||||||
disable_grouping = None,
|
|
||||||
do_center_crop = None,
|
|
||||||
do_convert_rgb = None,
|
|
||||||
do_normalize = True,
|
|
||||||
do_rescale = True,
|
|
||||||
do_resize = True,
|
|
||||||
image_mean = [
|
|
||||||
0.485,
|
|
||||||
0.456,
|
|
||||||
0.406
|
|
||||||
],
|
|
||||||
image_processor_type = "DINOv3ViTImageProcessorFast",
|
|
||||||
image_std = [
|
|
||||||
0.229,
|
|
||||||
0.224,
|
|
||||||
0.225
|
|
||||||
],
|
|
||||||
input_data_format = None,
|
|
||||||
resample = 2,
|
|
||||||
rescale_factor = 0.00392156862745098,
|
|
||||||
return_tensors = None,
|
|
||||||
size = {
|
|
||||||
"height": 224,
|
|
||||||
"width": 224
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
|
|
||||||
inputs = self.processor(images=image, return_tensors="pt")
|
|
||||||
pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
|
|
||||||
bool_masked_pos = None
|
|
||||||
head_mask = None
|
|
||||||
|
|
||||||
pixel_values = pixel_values.to(torch_dtype)
|
|
||||||
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
|
||||||
position_embeddings = self.rope_embeddings(pixel_values)
|
|
||||||
|
|
||||||
for i, layer_module in enumerate(self.layer):
|
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
|
||||||
hidden_states = layer_module(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=layer_head_mask,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
)
|
|
||||||
|
|
||||||
sequence_output = self.norm(hidden_states)
|
|
||||||
pooled_output = sequence_output[:, 0, :]
|
|
||||||
|
|
||||||
return pooled_output
|
|
||||||
@@ -19,7 +19,7 @@ def get_timestep_embedding(
|
|||||||
)
|
)
|
||||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||||
|
|
||||||
emb = torch.exp(exponent)
|
emb = torch.exp(exponent).to(timesteps.device)
|
||||||
if align_dtype_to_timestep:
|
if align_dtype_to_timestep:
|
||||||
emb = emb.to(timesteps.dtype)
|
emb = emb.to(timesteps.dtype)
|
||||||
emb = timesteps[:, None].float() * emb[None, :]
|
emb = timesteps[:, None].float() * emb[None, :]
|
||||||
@@ -78,7 +78,7 @@ class DiffusersCompatibleTimestepProj(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TimestepEmbeddings(torch.nn.Module):
|
class TimestepEmbeddings(torch.nn.Module):
|
||||||
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):
|
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
|
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
|
||||||
if diffusers_compatible_format:
|
if diffusers_compatible_format:
|
||||||
@@ -87,17 +87,10 @@ class TimestepEmbeddings(torch.nn.Module):
|
|||||||
self.timestep_embedder = torch.nn.Sequential(
|
self.timestep_embedder = torch.nn.Sequential(
|
||||||
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||||
)
|
)
|
||||||
self.use_additional_t_cond = use_additional_t_cond
|
|
||||||
if use_additional_t_cond:
|
|
||||||
self.addition_t_embedding = torch.nn.Embedding(2, dim_out)
|
|
||||||
|
|
||||||
def forward(self, timestep, dtype, addition_t_cond=None):
|
def forward(self, timestep, dtype):
|
||||||
time_emb = self.time_proj(timestep).to(dtype)
|
time_emb = self.time_proj(timestep).to(dtype)
|
||||||
time_emb = self.timestep_embedder(time_emb)
|
time_emb = self.timestep_embedder(time_emb)
|
||||||
if addition_t_cond is not None:
|
|
||||||
addition_t_emb = self.addition_t_embedding(addition_t_cond)
|
|
||||||
addition_t_emb = addition_t_emb.to(dtype=dtype)
|
|
||||||
time_emb = time_emb + addition_t_emb
|
|
||||||
return time_emb
|
return time_emb
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import torch, math, functools
|
import torch, math
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import Tuple, Optional, Union, List
|
from typing import Tuple, Optional, Union, List
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
@@ -225,121 +225,6 @@ class QwenEmbedRope(nn.Module):
|
|||||||
return vid_freqs, txt_freqs
|
return vid_freqs, txt_freqs
|
||||||
|
|
||||||
|
|
||||||
class QwenEmbedLayer3DRope(nn.Module):
|
|
||||||
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
|
|
||||||
super().__init__()
|
|
||||||
self.theta = theta
|
|
||||||
self.axes_dim = axes_dim
|
|
||||||
pos_index = torch.arange(4096)
|
|
||||||
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
|
||||||
self.pos_freqs = torch.cat(
|
|
||||||
[
|
|
||||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
|
||||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
|
||||||
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
|
||||||
],
|
|
||||||
dim=1,
|
|
||||||
)
|
|
||||||
self.neg_freqs = torch.cat(
|
|
||||||
[
|
|
||||||
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
|
||||||
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
|
||||||
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
|
||||||
],
|
|
||||||
dim=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.scale_rope = scale_rope
|
|
||||||
|
|
||||||
def rope_params(self, index, dim, theta=10000):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
|
|
||||||
"""
|
|
||||||
assert dim % 2 == 0
|
|
||||||
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
|
|
||||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
|
||||||
return freqs
|
|
||||||
|
|
||||||
def forward(self, video_fhw, txt_seq_lens, device):
|
|
||||||
"""
|
|
||||||
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
|
|
||||||
txt_length: [bs] a list of 1 integers representing the length of the text
|
|
||||||
"""
|
|
||||||
if self.pos_freqs.device != device:
|
|
||||||
self.pos_freqs = self.pos_freqs.to(device)
|
|
||||||
self.neg_freqs = self.neg_freqs.to(device)
|
|
||||||
|
|
||||||
video_fhw = [video_fhw]
|
|
||||||
if isinstance(video_fhw, list):
|
|
||||||
video_fhw = video_fhw[0]
|
|
||||||
if not isinstance(video_fhw, list):
|
|
||||||
video_fhw = [video_fhw]
|
|
||||||
|
|
||||||
vid_freqs = []
|
|
||||||
max_vid_index = 0
|
|
||||||
layer_num = len(video_fhw) - 1
|
|
||||||
for idx, fhw in enumerate(video_fhw):
|
|
||||||
frame, height, width = fhw
|
|
||||||
if idx != layer_num:
|
|
||||||
video_freq = self._compute_video_freqs(frame, height, width, idx)
|
|
||||||
else:
|
|
||||||
### For the condition image, we set the layer index to -1
|
|
||||||
video_freq = self._compute_condition_freqs(frame, height, width)
|
|
||||||
video_freq = video_freq.to(device)
|
|
||||||
vid_freqs.append(video_freq)
|
|
||||||
|
|
||||||
if self.scale_rope:
|
|
||||||
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
|
||||||
else:
|
|
||||||
max_vid_index = max(height, width, max_vid_index)
|
|
||||||
|
|
||||||
max_vid_index = max(max_vid_index, layer_num)
|
|
||||||
max_len = max(txt_seq_lens)
|
|
||||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
|
||||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
|
||||||
|
|
||||||
return vid_freqs, txt_freqs
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=None)
|
|
||||||
def _compute_video_freqs(self, frame, height, width, idx=0):
|
|
||||||
seq_lens = frame * height * width
|
|
||||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
|
||||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
|
||||||
|
|
||||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
|
||||||
if self.scale_rope:
|
|
||||||
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
|
||||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
|
||||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
|
||||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
|
||||||
else:
|
|
||||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
|
||||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
|
||||||
|
|
||||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
|
||||||
return freqs.clone().contiguous()
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=None)
|
|
||||||
def _compute_condition_freqs(self, frame, height, width):
|
|
||||||
seq_lens = frame * height * width
|
|
||||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
|
||||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
|
||||||
|
|
||||||
freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
|
||||||
if self.scale_rope:
|
|
||||||
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
|
||||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
|
||||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
|
||||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
|
||||||
else:
|
|
||||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
|
||||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
|
||||||
|
|
||||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
|
||||||
return freqs.clone().contiguous()
|
|
||||||
|
|
||||||
|
|
||||||
class QwenFeedForward(nn.Module):
|
class QwenFeedForward(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -467,38 +352,9 @@ class QwenImageTransformerBlock(nn.Module):
|
|||||||
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||||
self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)
|
self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)
|
||||||
|
|
||||||
def _modulate(self, x, mod_params, index=None):
|
def _modulate(self, x, mod_params):
|
||||||
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
||||||
if index is not None:
|
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
|
||||||
# Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
|
|
||||||
# So shift, scale, gate have shape [2*actual_batch, d]
|
|
||||||
actual_batch = shift.size(0) // 2
|
|
||||||
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
|
|
||||||
scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
|
|
||||||
gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
|
|
||||||
|
|
||||||
# index: [b, l] where b is actual batch size
|
|
||||||
# Expand to [b, l, 1] to match feature dimension
|
|
||||||
index_expanded = index.unsqueeze(-1) # [b, l, 1]
|
|
||||||
|
|
||||||
# Expand chunks to [b, 1, d] then broadcast to [b, l, d]
|
|
||||||
shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
|
|
||||||
shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
|
|
||||||
scale_0_exp = scale_0.unsqueeze(1)
|
|
||||||
scale_1_exp = scale_1.unsqueeze(1)
|
|
||||||
gate_0_exp = gate_0.unsqueeze(1)
|
|
||||||
gate_1_exp = gate_1.unsqueeze(1)
|
|
||||||
|
|
||||||
# Use torch.where to select based on index
|
|
||||||
shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
|
|
||||||
scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
|
|
||||||
gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
|
|
||||||
else:
|
|
||||||
shift_result = shift.unsqueeze(1)
|
|
||||||
scale_result = scale.unsqueeze(1)
|
|
||||||
gate_result = gate.unsqueeze(1)
|
|
||||||
|
|
||||||
return x * (1 + scale_result) + shift_result, gate_result
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -508,16 +364,13 @@ class QwenImageTransformerBlock(nn.Module):
|
|||||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
enable_fp8_attention = False,
|
enable_fp8_attention = False,
|
||||||
modulate_index: Optional[List[int]] = None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
|
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
|
||||||
if modulate_index is not None:
|
|
||||||
temb = torch.chunk(temb, 2, dim=0)[0]
|
|
||||||
txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
|
txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
|
||||||
|
|
||||||
img_normed = self.img_norm1(image)
|
img_normed = self.img_norm1(image)
|
||||||
img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, index=modulate_index)
|
img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)
|
||||||
|
|
||||||
txt_normed = self.txt_norm1(text)
|
txt_normed = self.txt_norm1(text)
|
||||||
txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
|
txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
|
||||||
@@ -534,7 +387,7 @@ class QwenImageTransformerBlock(nn.Module):
|
|||||||
text = text + txt_gate * txt_attn_out
|
text = text + txt_gate * txt_attn_out
|
||||||
|
|
||||||
img_normed_2 = self.img_norm2(image)
|
img_normed_2 = self.img_norm2(image)
|
||||||
img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, index=modulate_index)
|
img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)
|
||||||
|
|
||||||
txt_normed_2 = self.txt_norm2(text)
|
txt_normed_2 = self.txt_norm2(text)
|
||||||
txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
|
txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
|
||||||
@@ -552,17 +405,12 @@ class QwenImageDiT(torch.nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_layers: int = 60,
|
num_layers: int = 60,
|
||||||
use_layer3d_rope: bool = False,
|
|
||||||
use_additional_t_cond: bool = False,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if not use_layer3d_rope:
|
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
|
||||||
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
|
|
||||||
else:
|
|
||||||
self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
|
|
||||||
|
|
||||||
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=False, use_additional_t_cond=use_additional_t_cond)
|
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
|
||||||
self.txt_norm = RMSNorm(3584, eps=1e-6)
|
self.txt_norm = RMSNorm(3584, eps=1e-6)
|
||||||
|
|
||||||
self.img_in = nn.Linear(64, 3072)
|
self.img_in = nn.Linear(64, 3072)
|
||||||
|
|||||||
@@ -1,128 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class CompressedMLP(torch.nn.Module):
|
|
||||||
def __init__(self, in_dim, mid_dim, out_dim, bias=False):
|
|
||||||
super().__init__()
|
|
||||||
self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)
|
|
||||||
self.proj_out = torch.nn.Linear(mid_dim, out_dim, bias=bias)
|
|
||||||
|
|
||||||
def forward(self, x, residual=None):
|
|
||||||
x = self.proj_in(x)
|
|
||||||
if residual is not None: x = x + residual
|
|
||||||
x = self.proj_out(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class ImageEmbeddingToLoraMatrix(torch.nn.Module):
|
|
||||||
def __init__(self, in_dim, compress_dim, lora_a_dim, lora_b_dim, rank):
|
|
||||||
super().__init__()
|
|
||||||
self.proj_a = CompressedMLP(in_dim, compress_dim, lora_a_dim * rank)
|
|
||||||
self.proj_b = CompressedMLP(in_dim, compress_dim, lora_b_dim * rank)
|
|
||||||
self.lora_a_dim = lora_a_dim
|
|
||||||
self.lora_b_dim = lora_b_dim
|
|
||||||
self.rank = rank
|
|
||||||
|
|
||||||
def forward(self, x, residual=None):
|
|
||||||
lora_a = self.proj_a(x, residual).view(self.rank, self.lora_a_dim)
|
|
||||||
lora_b = self.proj_b(x, residual).view(self.lora_b_dim, self.rank)
|
|
||||||
return lora_a, lora_b
|
|
||||||
|
|
||||||
|
|
||||||
class SequencialMLP(torch.nn.Module):
|
|
||||||
def __init__(self, length, in_dim, mid_dim, out_dim, bias=False):
|
|
||||||
super().__init__()
|
|
||||||
self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)
|
|
||||||
self.proj_out = torch.nn.Linear(length * mid_dim, out_dim, bias=bias)
|
|
||||||
self.length = length
|
|
||||||
self.in_dim = in_dim
|
|
||||||
self.mid_dim = mid_dim
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x.view(self.length, self.in_dim)
|
|
||||||
x = self.proj_in(x)
|
|
||||||
x = x.view(1, self.length * self.mid_dim)
|
|
||||||
x = self.proj_out(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class LoRATrainerBlock(torch.nn.Module):
|
|
||||||
def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024):
|
|
||||||
super().__init__()
|
|
||||||
self.lora_patterns = lora_patterns
|
|
||||||
self.block_id = block_id
|
|
||||||
self.layers = []
|
|
||||||
for name, lora_a_dim, lora_b_dim in self.lora_patterns:
|
|
||||||
self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank))
|
|
||||||
self.layers = torch.nn.ModuleList(self.layers)
|
|
||||||
if use_residual:
|
|
||||||
self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)
|
|
||||||
else:
|
|
||||||
self.proj_residual = None
|
|
||||||
|
|
||||||
def forward(self, x, residual=None):
|
|
||||||
lora = {}
|
|
||||||
if self.proj_residual is not None: residual = self.proj_residual(residual)
|
|
||||||
for lora_pattern, layer in zip(self.lora_patterns, self.layers):
|
|
||||||
name = lora_pattern[0]
|
|
||||||
lora_a, lora_b = layer(x, residual=residual)
|
|
||||||
lora[f"transformer_blocks.{self.block_id}.{name}.lora_A.default.weight"] = lora_a
|
|
||||||
lora[f"transformer_blocks.{self.block_id}.{name}.lora_B.default.weight"] = lora_b
|
|
||||||
return lora
|
|
||||||
|
|
||||||
|
|
||||||
class QwenImageImage2LoRAModel(torch.nn.Module):
|
|
||||||
def __init__(self, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):
|
|
||||||
super().__init__()
|
|
||||||
self.lora_patterns = [
|
|
||||||
[
|
|
||||||
("attn.to_q", 3072, 3072),
|
|
||||||
("attn.to_k", 3072, 3072),
|
|
||||||
("attn.to_v", 3072, 3072),
|
|
||||||
("attn.to_out.0", 3072, 3072),
|
|
||||||
],
|
|
||||||
[
|
|
||||||
("img_mlp.net.2", 3072*4, 3072),
|
|
||||||
("img_mod.1", 3072, 3072*6),
|
|
||||||
],
|
|
||||||
[
|
|
||||||
("attn.add_q_proj", 3072, 3072),
|
|
||||||
("attn.add_k_proj", 3072, 3072),
|
|
||||||
("attn.add_v_proj", 3072, 3072),
|
|
||||||
("attn.to_add_out", 3072, 3072),
|
|
||||||
],
|
|
||||||
[
|
|
||||||
("txt_mlp.net.2", 3072*4, 3072),
|
|
||||||
("txt_mod.1", 3072, 3072*6),
|
|
||||||
],
|
|
||||||
]
|
|
||||||
self.num_blocks = num_blocks
|
|
||||||
self.blocks = []
|
|
||||||
for lora_patterns in self.lora_patterns:
|
|
||||||
for block_id in range(self.num_blocks):
|
|
||||||
self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim))
|
|
||||||
self.blocks = torch.nn.ModuleList(self.blocks)
|
|
||||||
self.residual_scale = 0.05
|
|
||||||
self.use_residual = use_residual
|
|
||||||
|
|
||||||
def forward(self, x, residual=None):
|
|
||||||
if residual is not None:
|
|
||||||
if self.use_residual:
|
|
||||||
residual = residual * self.residual_scale
|
|
||||||
else:
|
|
||||||
residual = None
|
|
||||||
lora = {}
|
|
||||||
for block in self.blocks:
|
|
||||||
lora.update(block(x, residual))
|
|
||||||
return lora
|
|
||||||
|
|
||||||
def initialize_weights(self):
|
|
||||||
state_dict = self.state_dict()
|
|
||||||
for name in state_dict:
|
|
||||||
if ".proj_a." in name:
|
|
||||||
state_dict[name] = state_dict[name] * 0.3
|
|
||||||
elif ".proj_b.proj_out." in name:
|
|
||||||
state_dict[name] = state_dict[name] * 0
|
|
||||||
elif ".proj_residual.proj_out." in name:
|
|
||||||
state_dict[name] = state_dict[name] * 0.3
|
|
||||||
self.load_state_dict(state_dict)
|
|
||||||
@@ -366,7 +366,6 @@ class QwenImageEncoder3d(nn.Module):
|
|||||||
temperal_downsample=[True, True, False],
|
temperal_downsample=[True, True, False],
|
||||||
dropout=0.0,
|
dropout=0.0,
|
||||||
non_linearity: str = "silu",
|
non_linearity: str = "silu",
|
||||||
image_channels=3
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
@@ -382,7 +381,7 @@ class QwenImageEncoder3d(nn.Module):
|
|||||||
scale = 1.0
|
scale = 1.0
|
||||||
|
|
||||||
# init block
|
# init block
|
||||||
self.conv_in = QwenImageCausalConv3d(image_channels, dims[0], 3, padding=1)
|
self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
|
||||||
|
|
||||||
# downsample blocks
|
# downsample blocks
|
||||||
self.down_blocks = torch.nn.ModuleList([])
|
self.down_blocks = torch.nn.ModuleList([])
|
||||||
@@ -545,7 +544,6 @@ class QwenImageDecoder3d(nn.Module):
|
|||||||
temperal_upsample=[False, True, True],
|
temperal_upsample=[False, True, True],
|
||||||
dropout=0.0,
|
dropout=0.0,
|
||||||
non_linearity: str = "silu",
|
non_linearity: str = "silu",
|
||||||
image_channels=3,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
@@ -596,7 +594,7 @@ class QwenImageDecoder3d(nn.Module):
|
|||||||
|
|
||||||
# output blocks
|
# output blocks
|
||||||
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
|
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
|
||||||
self.conv_out = QwenImageCausalConv3d(out_dim, image_channels, 3, padding=1)
|
self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
@@ -649,7 +647,6 @@ class QwenImageVAE(torch.nn.Module):
|
|||||||
attn_scales: List[float] = [],
|
attn_scales: List[float] = [],
|
||||||
temperal_downsample: List[bool] = [False, True, True],
|
temperal_downsample: List[bool] = [False, True, True],
|
||||||
dropout: float = 0.0,
|
dropout: float = 0.0,
|
||||||
image_channels: int = 3,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -658,13 +655,13 @@ class QwenImageVAE(torch.nn.Module):
|
|||||||
self.temperal_upsample = temperal_downsample[::-1]
|
self.temperal_upsample = temperal_downsample[::-1]
|
||||||
|
|
||||||
self.encoder = QwenImageEncoder3d(
|
self.encoder = QwenImageEncoder3d(
|
||||||
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, image_channels=image_channels,
|
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
|
||||||
)
|
)
|
||||||
self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||||
self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
|
self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
|
||||||
|
|
||||||
self.decoder = QwenImageDecoder3d(
|
self.decoder = QwenImageDecoder3d(
|
||||||
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, image_channels=image_channels,
|
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
|
||||||
)
|
)
|
||||||
|
|
||||||
mean = [
|
mean = [
|
||||||
|
|||||||
@@ -1,132 +0,0 @@
|
|||||||
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
|
|
||||||
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class Siglip2ImageEncoder(SiglipVisionTransformer):
|
|
||||||
def __init__(self):
|
|
||||||
config = SiglipVisionConfig(
|
|
||||||
attention_dropout = 0.0,
|
|
||||||
dtype = "float32",
|
|
||||||
hidden_act = "gelu_pytorch_tanh",
|
|
||||||
hidden_size = 1536,
|
|
||||||
image_size = 384,
|
|
||||||
intermediate_size = 6144,
|
|
||||||
layer_norm_eps = 1e-06,
|
|
||||||
model_type = "siglip_vision_model",
|
|
||||||
num_attention_heads = 16,
|
|
||||||
num_channels = 3,
|
|
||||||
num_hidden_layers = 40,
|
|
||||||
patch_size = 16,
|
|
||||||
transformers_version = "4.56.1",
|
|
||||||
_attn_implementation = "sdpa"
|
|
||||||
)
|
|
||||||
super().__init__(config)
|
|
||||||
self.processor = SiglipImageProcessor(
|
|
||||||
do_convert_rgb = None,
|
|
||||||
do_normalize = True,
|
|
||||||
do_rescale = True,
|
|
||||||
do_resize = True,
|
|
||||||
image_mean = [
|
|
||||||
0.5,
|
|
||||||
0.5,
|
|
||||||
0.5
|
|
||||||
],
|
|
||||||
image_processor_type = "SiglipImageProcessor",
|
|
||||||
image_std = [
|
|
||||||
0.5,
|
|
||||||
0.5,
|
|
||||||
0.5
|
|
||||||
],
|
|
||||||
processor_class = "SiglipProcessor",
|
|
||||||
resample = 2,
|
|
||||||
rescale_factor = 0.00392156862745098,
|
|
||||||
size = {
|
|
||||||
"height": 384,
|
|
||||||
"width": 384
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
|
|
||||||
pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
|
|
||||||
pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
|
|
||||||
output_attentions = False
|
|
||||||
output_hidden_states = False
|
|
||||||
interpolate_pos_encoding = False
|
|
||||||
|
|
||||||
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
|
||||||
inputs_embeds=hidden_states,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
)
|
|
||||||
|
|
||||||
last_hidden_state = encoder_outputs.last_hidden_state
|
|
||||||
last_hidden_state = self.post_layernorm(last_hidden_state)
|
|
||||||
|
|
||||||
pooler_output = self.head(last_hidden_state) if self.use_head else None
|
|
||||||
|
|
||||||
return pooler_output
|
|
||||||
|
|
||||||
|
|
||||||
class Siglip2ImageEncoder428M(Siglip2VisionModel):
|
|
||||||
def __init__(self):
|
|
||||||
config = Siglip2VisionConfig(
|
|
||||||
attention_dropout = 0.0,
|
|
||||||
dtype = "bfloat16",
|
|
||||||
hidden_act = "gelu_pytorch_tanh",
|
|
||||||
hidden_size = 1152,
|
|
||||||
intermediate_size = 4304,
|
|
||||||
layer_norm_eps = 1e-06,
|
|
||||||
model_type = "siglip2_vision_model",
|
|
||||||
num_attention_heads = 16,
|
|
||||||
num_channels = 3,
|
|
||||||
num_hidden_layers = 27,
|
|
||||||
num_patches = 256,
|
|
||||||
patch_size = 16,
|
|
||||||
transformers_version = "4.57.1"
|
|
||||||
)
|
|
||||||
super().__init__(config)
|
|
||||||
self.processor = Siglip2ImageProcessorFast(
|
|
||||||
**{
|
|
||||||
"data_format": "channels_first",
|
|
||||||
"default_to_square": True,
|
|
||||||
"device": None,
|
|
||||||
"disable_grouping": None,
|
|
||||||
"do_convert_rgb": None,
|
|
||||||
"do_normalize": True,
|
|
||||||
"do_pad": None,
|
|
||||||
"do_rescale": True,
|
|
||||||
"do_resize": True,
|
|
||||||
"image_mean": [
|
|
||||||
0.5,
|
|
||||||
0.5,
|
|
||||||
0.5
|
|
||||||
],
|
|
||||||
"image_processor_type": "Siglip2ImageProcessorFast",
|
|
||||||
"image_std": [
|
|
||||||
0.5,
|
|
||||||
0.5,
|
|
||||||
0.5
|
|
||||||
],
|
|
||||||
"input_data_format": None,
|
|
||||||
"max_num_patches": 256,
|
|
||||||
"pad_size": None,
|
|
||||||
"patch_size": 16,
|
|
||||||
"processor_class": "Siglip2Processor",
|
|
||||||
"resample": 2,
|
|
||||||
"rescale_factor": 0.00392156862745098,
|
|
||||||
"return_tensors": None,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
|
|
||||||
siglip_inputs = self.processor(images=[image], return_tensors="pt").to(device)
|
|
||||||
shape = siglip_inputs.spatial_shapes[0]
|
|
||||||
hidden_state = super().forward(**siglip_inputs).last_hidden_state
|
|
||||||
B, N, C = hidden_state.shape
|
|
||||||
hidden_state = hidden_state[:, : shape[0] * shape[1]]
|
|
||||||
hidden_state = hidden_state.view(shape[0], shape[1], C)
|
|
||||||
hidden_state = hidden_state.to(torch_dtype)
|
|
||||||
return hidden_state
|
|
||||||
@@ -1,154 +0,0 @@
|
|||||||
from .z_image_dit import ZImageTransformerBlock
|
|
||||||
from ..core.gradient import gradient_checkpoint_forward
|
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
|
|
||||||
class ZImageControlTransformerBlock(ZImageTransformerBlock):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_id: int = 1000,
|
|
||||||
dim: int = 3840,
|
|
||||||
n_heads: int = 30,
|
|
||||||
n_kv_heads: int = 30,
|
|
||||||
norm_eps: float = 1e-5,
|
|
||||||
qk_norm: bool = True,
|
|
||||||
modulation = True,
|
|
||||||
block_id = 0
|
|
||||||
):
|
|
||||||
super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
|
|
||||||
self.block_id = block_id
|
|
||||||
if block_id == 0:
|
|
||||||
self.before_proj = nn.Linear(self.dim, self.dim)
|
|
||||||
self.after_proj = nn.Linear(self.dim, self.dim)
|
|
||||||
|
|
||||||
def forward(self, c, x, **kwargs):
|
|
||||||
if self.block_id == 0:
|
|
||||||
c = self.before_proj(c) + x
|
|
||||||
all_c = []
|
|
||||||
else:
|
|
||||||
all_c = list(torch.unbind(c))
|
|
||||||
c = all_c.pop(-1)
|
|
||||||
|
|
||||||
c = super().forward(c, **kwargs)
|
|
||||||
c_skip = self.after_proj(c)
|
|
||||||
all_c += [c_skip, c]
|
|
||||||
c = torch.stack(all_c)
|
|
||||||
return c
|
|
||||||
|
|
||||||
|
|
||||||
class ZImageControlNet(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
control_layers_places=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
|
|
||||||
control_in_dim=33,
|
|
||||||
dim=3840,
|
|
||||||
n_refiner_layers=2,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.control_layers = nn.ModuleList([ZImageControlTransformerBlock(layer_id=i, block_id=i) for i in control_layers_places])
|
|
||||||
self.control_all_x_embedder = nn.ModuleDict({"2-1": nn.Linear(1 * 2 * 2 * control_in_dim, dim, bias=True)})
|
|
||||||
self.control_noise_refiner = nn.ModuleList([ZImageControlTransformerBlock(block_id=layer_id) for layer_id in range(n_refiner_layers)])
|
|
||||||
self.control_layers_mapping = {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 7, 16: 8, 18: 9, 20: 10, 22: 11, 24: 12, 26: 13, 28: 14}
|
|
||||||
|
|
||||||
def forward_layers(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
cap_feats,
|
|
||||||
control_context,
|
|
||||||
control_context_item_seqlens,
|
|
||||||
kwargs,
|
|
||||||
use_gradient_checkpointing=False,
|
|
||||||
use_gradient_checkpointing_offload=False,
|
|
||||||
):
|
|
||||||
bsz = len(control_context)
|
|
||||||
# unified
|
|
||||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
|
||||||
control_context_unified = []
|
|
||||||
for i in range(bsz):
|
|
||||||
control_context_len = control_context_item_seqlens[i]
|
|
||||||
cap_len = cap_item_seqlens[i]
|
|
||||||
control_context_unified.append(torch.cat([control_context[i][:control_context_len], cap_feats[i][:cap_len]]))
|
|
||||||
c = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
|
|
||||||
|
|
||||||
# arguments
|
|
||||||
new_kwargs = dict(x=x)
|
|
||||||
new_kwargs.update(kwargs)
|
|
||||||
|
|
||||||
for layer in self.control_layers:
|
|
||||||
c = gradient_checkpoint_forward(
|
|
||||||
layer,
|
|
||||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
|
||||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
|
||||||
c=c, **new_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
hints = torch.unbind(c)[:-1]
|
|
||||||
return hints
|
|
||||||
|
|
||||||
def forward_refiner(
|
|
||||||
self,
|
|
||||||
dit,
|
|
||||||
x,
|
|
||||||
cap_feats,
|
|
||||||
control_context,
|
|
||||||
kwargs,
|
|
||||||
t=None,
|
|
||||||
patch_size=2,
|
|
||||||
f_patch_size=1,
|
|
||||||
use_gradient_checkpointing=False,
|
|
||||||
use_gradient_checkpointing_offload=False,
|
|
||||||
):
|
|
||||||
# embeddings
|
|
||||||
bsz = len(control_context)
|
|
||||||
device = control_context[0].device
|
|
||||||
(
|
|
||||||
control_context,
|
|
||||||
control_context_size,
|
|
||||||
control_context_pos_ids,
|
|
||||||
control_context_inner_pad_mask,
|
|
||||||
) = dit.patchify_controlnet(control_context, patch_size, f_patch_size, cap_feats[0].size(0))
|
|
||||||
|
|
||||||
# control_context embed & refine
|
|
||||||
control_context_item_seqlens = [len(_) for _ in control_context]
|
|
||||||
assert all(_ % 2 == 0 for _ in control_context_item_seqlens)
|
|
||||||
control_context_max_item_seqlen = max(control_context_item_seqlens)
|
|
||||||
|
|
||||||
control_context = torch.cat(control_context, dim=0)
|
|
||||||
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)
|
|
||||||
|
|
||||||
# Match t_embedder output dtype to control_context for layerwise casting compatibility
|
|
||||||
adaln_input = t.type_as(control_context)
|
|
||||||
control_context[torch.cat(control_context_inner_pad_mask)] = dit.x_pad_token.to(dtype=control_context.dtype, device=control_context.device)
|
|
||||||
control_context = list(control_context.split(control_context_item_seqlens, dim=0))
|
|
||||||
control_context_freqs_cis = list(dit.rope_embedder(torch.cat(control_context_pos_ids, dim=0)).split(control_context_item_seqlens, dim=0))
|
|
||||||
|
|
||||||
control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
|
|
||||||
control_context_freqs_cis = pad_sequence(control_context_freqs_cis, batch_first=True, padding_value=0.0)
|
|
||||||
control_context_attn_mask = torch.zeros((bsz, control_context_max_item_seqlen), dtype=torch.bool, device=device)
|
|
||||||
for i, seq_len in enumerate(control_context_item_seqlens):
|
|
||||||
control_context_attn_mask[i, :seq_len] = 1
|
|
||||||
c = control_context
|
|
||||||
|
|
||||||
# arguments
|
|
||||||
new_kwargs = dict(
|
|
||||||
x=x,
|
|
||||||
attn_mask=control_context_attn_mask,
|
|
||||||
freqs_cis=control_context_freqs_cis,
|
|
||||||
adaln_input=adaln_input,
|
|
||||||
)
|
|
||||||
new_kwargs.update(kwargs)
|
|
||||||
|
|
||||||
for layer in self.control_noise_refiner:
|
|
||||||
c = gradient_checkpoint_forward(
|
|
||||||
layer,
|
|
||||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
|
||||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
|
||||||
c=c, **new_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
hints = torch.unbind(c)[:-1]
|
|
||||||
control_context = torch.unbind(c)[-1]
|
|
||||||
|
|
||||||
return hints, control_context, control_context_item_seqlens
|
|
||||||
@@ -13,7 +13,6 @@ from ..core.gradient import gradient_checkpoint_forward
|
|||||||
|
|
||||||
ADALN_EMBED_DIM = 256
|
ADALN_EMBED_DIM = 256
|
||||||
SEQ_MULTI_OF = 32
|
SEQ_MULTI_OF = 32
|
||||||
X_PAD_DIM = 64
|
|
||||||
|
|
||||||
|
|
||||||
class TimestepEmbedder(nn.Module):
|
class TimestepEmbedder(nn.Module):
|
||||||
@@ -87,7 +86,7 @@ class Attention(torch.nn.Module):
|
|||||||
self.norm_q = RMSNorm(head_dim, eps=1e-5)
|
self.norm_q = RMSNorm(head_dim, eps=1e-5)
|
||||||
self.norm_k = RMSNorm(head_dim, eps=1e-5)
|
self.norm_k = RMSNorm(head_dim, eps=1e-5)
|
||||||
|
|
||||||
def forward(self, hidden_states, freqs_cis, attention_mask):
|
def forward(self, hidden_states, freqs_cis):
|
||||||
query = self.to_q(hidden_states)
|
query = self.to_q(hidden_states)
|
||||||
key = self.to_k(hidden_states)
|
key = self.to_k(hidden_states)
|
||||||
value = self.to_v(hidden_states)
|
value = self.to_v(hidden_states)
|
||||||
@@ -124,7 +123,6 @@ class Attention(torch.nn.Module):
|
|||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
|
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
|
||||||
attn_mask=attention_mask,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reshape back
|
# Reshape back
|
||||||
@@ -138,20 +136,6 @@ class Attention(torch.nn.Module):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def select_per_token(
|
|
||||||
value_noisy: torch.Tensor,
|
|
||||||
value_clean: torch.Tensor,
|
|
||||||
noise_mask: torch.Tensor,
|
|
||||||
seq_len: int,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1)
|
|
||||||
return torch.where(
|
|
||||||
noise_mask_expanded == 1,
|
|
||||||
value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
|
|
||||||
value_clean.unsqueeze(1).expand(-1, seq_len, -1),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ZImageTransformerBlock(nn.Module):
|
class ZImageTransformerBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -196,53 +180,40 @@ class ZImageTransformerBlock(nn.Module):
|
|||||||
attn_mask: torch.Tensor,
|
attn_mask: torch.Tensor,
|
||||||
freqs_cis: torch.Tensor,
|
freqs_cis: torch.Tensor,
|
||||||
adaln_input: Optional[torch.Tensor] = None,
|
adaln_input: Optional[torch.Tensor] = None,
|
||||||
noise_mask: Optional[torch.Tensor] = None,
|
|
||||||
adaln_noisy: Optional[torch.Tensor] = None,
|
|
||||||
adaln_clean: Optional[torch.Tensor] = None,
|
|
||||||
):
|
):
|
||||||
if self.modulation:
|
if self.modulation:
|
||||||
seq_len = x.shape[1]
|
assert adaln_input is not None
|
||||||
|
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
|
||||||
if noise_mask is not None:
|
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||||
# Per-token modulation: different modulation for noisy/clean tokens
|
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||||
mod_noisy = self.adaLN_modulation(adaln_noisy)
|
|
||||||
mod_clean = self.adaLN_modulation(adaln_clean)
|
|
||||||
|
|
||||||
scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
|
|
||||||
scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
|
|
||||||
|
|
||||||
gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
|
|
||||||
gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
|
|
||||||
|
|
||||||
scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
|
|
||||||
scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
|
|
||||||
|
|
||||||
scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
|
|
||||||
scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
|
|
||||||
gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
|
|
||||||
gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
|
|
||||||
else:
|
|
||||||
# Global modulation: same modulation for all tokens (avoid double select)
|
|
||||||
mod = self.adaLN_modulation(adaln_input)
|
|
||||||
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
|
|
||||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
|
||||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
|
||||||
|
|
||||||
# Attention block
|
# Attention block
|
||||||
attn_out = self.attention(
|
attn_out = self.attention(
|
||||||
self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
|
self.attention_norm1(x) * scale_msa,
|
||||||
|
freqs_cis=freqs_cis,
|
||||||
)
|
)
|
||||||
x = x + gate_msa * self.attention_norm2(attn_out)
|
x = x + gate_msa * self.attention_norm2(attn_out)
|
||||||
|
|
||||||
# FFN block
|
# FFN block
|
||||||
x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
|
x = x + gate_mlp * self.ffn_norm2(
|
||||||
|
self.feed_forward(
|
||||||
|
self.ffn_norm1(x) * scale_mlp,
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Attention block
|
# Attention block
|
||||||
attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
|
attn_out = self.attention(
|
||||||
|
self.attention_norm1(x),
|
||||||
|
freqs_cis=freqs_cis,
|
||||||
|
)
|
||||||
x = x + self.attention_norm2(attn_out)
|
x = x + self.attention_norm2(attn_out)
|
||||||
|
|
||||||
# FFN block
|
# FFN block
|
||||||
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
|
x = x + self.ffn_norm2(
|
||||||
|
self.feed_forward(
|
||||||
|
self.ffn_norm1(x),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@@ -258,21 +229,9 @@ class FinalLayer(nn.Module):
|
|||||||
nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
|
nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
|
def forward(self, x, c):
|
||||||
seq_len = x.shape[1]
|
scale = 1.0 + self.adaLN_modulation(c)
|
||||||
|
x = self.norm_final(x) * scale.unsqueeze(1)
|
||||||
if noise_mask is not None:
|
|
||||||
# Per-token modulation
|
|
||||||
scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)
|
|
||||||
scale_clean = 1.0 + self.adaLN_modulation(c_clean)
|
|
||||||
scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len)
|
|
||||||
else:
|
|
||||||
# Original global modulation
|
|
||||||
assert c is not None, "Either c or (c_noisy, c_clean) must be provided"
|
|
||||||
scale = 1.0 + self.adaLN_modulation(c)
|
|
||||||
scale = scale.unsqueeze(1)
|
|
||||||
|
|
||||||
x = self.norm_final(x) * scale
|
|
||||||
x = self.linear(x)
|
x = self.linear(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@@ -340,7 +299,6 @@ class ZImageDiT(nn.Module):
|
|||||||
t_scale=1000.0,
|
t_scale=1000.0,
|
||||||
axes_dims=[32, 48, 48],
|
axes_dims=[32, 48, 48],
|
||||||
axes_lens=[1024, 512, 512],
|
axes_lens=[1024, 512, 512],
|
||||||
siglip_feat_dim=None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
@@ -401,32 +359,6 @@ class ZImageDiT(nn.Module):
|
|||||||
nn.Linear(cap_feat_dim, dim, bias=True),
|
nn.Linear(cap_feat_dim, dim, bias=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optional SigLIP components (for Omni variant)
|
|
||||||
self.siglip_feat_dim = siglip_feat_dim
|
|
||||||
if siglip_feat_dim is not None:
|
|
||||||
self.siglip_embedder = nn.Sequential(
|
|
||||||
RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True)
|
|
||||||
)
|
|
||||||
self.siglip_refiner = nn.ModuleList(
|
|
||||||
[
|
|
||||||
ZImageTransformerBlock(
|
|
||||||
2000 + layer_id,
|
|
||||||
dim,
|
|
||||||
n_heads,
|
|
||||||
n_kv_heads,
|
|
||||||
norm_eps,
|
|
||||||
qk_norm,
|
|
||||||
modulation=False,
|
|
||||||
)
|
|
||||||
for layer_id in range(n_refiner_layers)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim)))
|
|
||||||
else:
|
|
||||||
self.siglip_embedder = None
|
|
||||||
self.siglip_refiner = None
|
|
||||||
self.siglip_pad_token = None
|
|
||||||
|
|
||||||
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
|
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||||
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
|
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||||
|
|
||||||
@@ -443,57 +375,22 @@ class ZImageDiT(nn.Module):
|
|||||||
|
|
||||||
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
|
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
|
||||||
|
|
||||||
def unpatchify(
|
def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
|
||||||
self,
|
|
||||||
x: List[torch.Tensor],
|
|
||||||
size: List[Tuple],
|
|
||||||
patch_size = 2,
|
|
||||||
f_patch_size = 1,
|
|
||||||
x_pos_offsets: Optional[List[Tuple[int, int]]] = None,
|
|
||||||
) -> List[torch.Tensor]:
|
|
||||||
pH = pW = patch_size
|
pH = pW = patch_size
|
||||||
pF = f_patch_size
|
pF = f_patch_size
|
||||||
bsz = len(x)
|
bsz = len(x)
|
||||||
assert len(size) == bsz
|
assert len(size) == bsz
|
||||||
|
for i in range(bsz):
|
||||||
if x_pos_offsets is not None:
|
F, H, W = size[i]
|
||||||
# Omni: extract target image from unified sequence (cond_images + target)
|
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||||
result = []
|
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
|
||||||
for i in range(bsz):
|
x[i] = (
|
||||||
unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]]
|
x[i][:ori_len]
|
||||||
cu_len = 0
|
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||||
x_item = None
|
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||||
for j in range(len(size[i])):
|
.reshape(self.out_channels, F, H, W)
|
||||||
if size[i][j] is None:
|
)
|
||||||
ori_len = 0
|
return x
|
||||||
pad_len = SEQ_MULTI_OF
|
|
||||||
cu_len += pad_len + ori_len
|
|
||||||
else:
|
|
||||||
F, H, W = size[i][j]
|
|
||||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
|
||||||
pad_len = (-ori_len) % SEQ_MULTI_OF
|
|
||||||
x_item = (
|
|
||||||
unified_x[cu_len : cu_len + ori_len]
|
|
||||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
|
||||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
|
||||||
.reshape(self.out_channels, F, H, W)
|
|
||||||
)
|
|
||||||
cu_len += ori_len + pad_len
|
|
||||||
result.append(x_item) # Return only the last (target) image
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
# Original mode: simple unpatchify
|
|
||||||
for i in range(bsz):
|
|
||||||
F, H, W = size[i]
|
|
||||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
|
||||||
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
|
|
||||||
x[i] = (
|
|
||||||
x[i][:ori_len]
|
|
||||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
|
||||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
|
||||||
.reshape(self.out_channels, F, H, W)
|
|
||||||
)
|
|
||||||
return x
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_coordinate_grid(size, start=None, device=None):
|
def create_coordinate_grid(size, start=None, device=None):
|
||||||
@@ -508,8 +405,8 @@ class ZImageDiT(nn.Module):
|
|||||||
self,
|
self,
|
||||||
all_image: List[torch.Tensor],
|
all_image: List[torch.Tensor],
|
||||||
all_cap_feats: List[torch.Tensor],
|
all_cap_feats: List[torch.Tensor],
|
||||||
patch_size: int = 2,
|
patch_size: int,
|
||||||
f_patch_size: int = 1,
|
f_patch_size: int,
|
||||||
):
|
):
|
||||||
pH = pW = patch_size
|
pH = pW = patch_size
|
||||||
pF = f_patch_size
|
pF = f_patch_size
|
||||||
@@ -593,487 +490,90 @@ class ZImageDiT(nn.Module):
|
|||||||
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
||||||
all_image_out.append(image_padded_feat)
|
all_image_out.append(image_padded_feat)
|
||||||
|
|
||||||
return all_image_out, all_cap_feats_out, {
|
|
||||||
"x_size": all_image_size,
|
|
||||||
"x_pos_ids": all_image_pos_ids,
|
|
||||||
"cap_pos_ids": all_cap_pos_ids,
|
|
||||||
"x_pad_mask": all_image_pad_mask,
|
|
||||||
"cap_pad_mask": all_cap_pad_mask
|
|
||||||
}
|
|
||||||
# (
|
|
||||||
# all_img_out,
|
|
||||||
# all_cap_out,
|
|
||||||
# all_img_size,
|
|
||||||
# all_img_pos_ids,
|
|
||||||
# all_cap_pos_ids,
|
|
||||||
# all_img_pad_mask,
|
|
||||||
# all_cap_pad_mask,
|
|
||||||
# )
|
|
||||||
|
|
||||||
def patchify_controlnet(
|
|
||||||
self,
|
|
||||||
all_image: List[torch.Tensor],
|
|
||||||
patch_size: int = 2,
|
|
||||||
f_patch_size: int = 1,
|
|
||||||
cap_padding_len: int = None,
|
|
||||||
):
|
|
||||||
pH = pW = patch_size
|
|
||||||
pF = f_patch_size
|
|
||||||
device = all_image[0].device
|
|
||||||
|
|
||||||
all_image_out = []
|
|
||||||
all_image_size = []
|
|
||||||
all_image_pos_ids = []
|
|
||||||
all_image_pad_mask = []
|
|
||||||
|
|
||||||
for i, image in enumerate(all_image):
|
|
||||||
### Process Image
|
|
||||||
C, F, H, W = image.size()
|
|
||||||
all_image_size.append((F, H, W))
|
|
||||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
|
||||||
|
|
||||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
|
||||||
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
|
|
||||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
|
||||||
|
|
||||||
image_ori_len = len(image)
|
|
||||||
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
|
||||||
|
|
||||||
image_ori_pos_ids = self.create_coordinate_grid(
|
|
||||||
size=(F_tokens, H_tokens, W_tokens),
|
|
||||||
start=(cap_padding_len + 1, 0, 0),
|
|
||||||
device=device,
|
|
||||||
).flatten(0, 2)
|
|
||||||
image_padding_pos_ids = (
|
|
||||||
self.create_coordinate_grid(
|
|
||||||
size=(1, 1, 1),
|
|
||||||
start=(0, 0, 0),
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
.flatten(0, 2)
|
|
||||||
.repeat(image_padding_len, 1)
|
|
||||||
)
|
|
||||||
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
|
|
||||||
all_image_pos_ids.append(image_padded_pos_ids)
|
|
||||||
# pad mask
|
|
||||||
all_image_pad_mask.append(
|
|
||||||
torch.cat(
|
|
||||||
[
|
|
||||||
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
|
|
||||||
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
|
|
||||||
],
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# padded feature
|
|
||||||
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
|
||||||
all_image_out.append(image_padded_feat)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
all_image_out,
|
all_image_out,
|
||||||
|
all_cap_feats_out,
|
||||||
all_image_size,
|
all_image_size,
|
||||||
all_image_pos_ids,
|
all_image_pos_ids,
|
||||||
all_image_pad_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _prepare_sequence(
|
|
||||||
self,
|
|
||||||
feats: List[torch.Tensor],
|
|
||||||
pos_ids: List[torch.Tensor],
|
|
||||||
inner_pad_mask: List[torch.Tensor],
|
|
||||||
pad_token: torch.nn.Parameter,
|
|
||||||
noise_mask: Optional[List[List[int]]] = None,
|
|
||||||
device: torch.device = None,
|
|
||||||
):
|
|
||||||
"""Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask."""
|
|
||||||
item_seqlens = [len(f) for f in feats]
|
|
||||||
max_seqlen = max(item_seqlens)
|
|
||||||
bsz = len(feats)
|
|
||||||
|
|
||||||
# Pad token
|
|
||||||
feats_cat = torch.cat(feats, dim=0)
|
|
||||||
feats_cat[torch.cat(inner_pad_mask)] = pad_token.to(dtype=feats_cat.dtype, device=feats_cat.device)
|
|
||||||
feats = list(feats_cat.split(item_seqlens, dim=0))
|
|
||||||
|
|
||||||
# RoPE
|
|
||||||
freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0))
|
|
||||||
|
|
||||||
# Pad to batch
|
|
||||||
feats = pad_sequence(feats, batch_first=True, padding_value=0.0)
|
|
||||||
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
|
|
||||||
|
|
||||||
# Attention mask
|
|
||||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
|
||||||
for i, seq_len in enumerate(item_seqlens):
|
|
||||||
attn_mask[i, :seq_len] = 1
|
|
||||||
|
|
||||||
# Noise mask
|
|
||||||
noise_mask_tensor = None
|
|
||||||
if noise_mask is not None:
|
|
||||||
noise_mask_tensor = pad_sequence(
|
|
||||||
[torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask],
|
|
||||||
batch_first=True,
|
|
||||||
padding_value=0,
|
|
||||||
)[:, : feats.shape[1]]
|
|
||||||
|
|
||||||
return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor
|
|
||||||
|
|
||||||
def _build_unified_sequence(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
x_freqs: torch.Tensor,
|
|
||||||
x_seqlens: List[int],
|
|
||||||
x_noise_mask: Optional[List[List[int]]],
|
|
||||||
cap: torch.Tensor,
|
|
||||||
cap_freqs: torch.Tensor,
|
|
||||||
cap_seqlens: List[int],
|
|
||||||
cap_noise_mask: Optional[List[List[int]]],
|
|
||||||
siglip: Optional[torch.Tensor],
|
|
||||||
siglip_freqs: Optional[torch.Tensor],
|
|
||||||
siglip_seqlens: Optional[List[int]],
|
|
||||||
siglip_noise_mask: Optional[List[List[int]]],
|
|
||||||
omni_mode: bool,
|
|
||||||
device: torch.device,
|
|
||||||
):
|
|
||||||
"""Build unified sequence: x, cap, and optionally siglip.
|
|
||||||
Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip]
|
|
||||||
"""
|
|
||||||
bsz = len(x_seqlens)
|
|
||||||
unified = []
|
|
||||||
unified_freqs = []
|
|
||||||
unified_noise_mask = []
|
|
||||||
|
|
||||||
for i in range(bsz):
|
|
||||||
x_len, cap_len = x_seqlens[i], cap_seqlens[i]
|
|
||||||
|
|
||||||
if omni_mode:
|
|
||||||
# Omni: [cap, x, siglip]
|
|
||||||
if siglip is not None and siglip_seqlens is not None:
|
|
||||||
sig_len = siglip_seqlens[i]
|
|
||||||
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]]))
|
|
||||||
unified_freqs.append(
|
|
||||||
torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]])
|
|
||||||
)
|
|
||||||
unified_noise_mask.append(
|
|
||||||
torch.tensor(
|
|
||||||
cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]]))
|
|
||||||
unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]]))
|
|
||||||
unified_noise_mask.append(
|
|
||||||
torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Basic: [x, cap]
|
|
||||||
unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]]))
|
|
||||||
unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]]))
|
|
||||||
|
|
||||||
# Compute unified seqlens
|
|
||||||
if omni_mode:
|
|
||||||
if siglip is not None and siglip_seqlens is not None:
|
|
||||||
unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)]
|
|
||||||
else:
|
|
||||||
unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)]
|
|
||||||
else:
|
|
||||||
unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)]
|
|
||||||
|
|
||||||
max_seqlen = max(unified_seqlens)
|
|
||||||
|
|
||||||
# Pad to batch
|
|
||||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
|
||||||
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
|
|
||||||
|
|
||||||
# Attention mask
|
|
||||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
|
||||||
for i, seq_len in enumerate(unified_seqlens):
|
|
||||||
attn_mask[i, :seq_len] = 1
|
|
||||||
|
|
||||||
# Noise mask
|
|
||||||
noise_mask_tensor = None
|
|
||||||
if omni_mode:
|
|
||||||
noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[
|
|
||||||
:, : unified.shape[1]
|
|
||||||
]
|
|
||||||
|
|
||||||
return unified, unified_freqs, attn_mask, noise_mask_tensor
|
|
||||||
|
|
||||||
def _pad_with_ids(
|
|
||||||
self,
|
|
||||||
feat: torch.Tensor,
|
|
||||||
pos_grid_size: Tuple,
|
|
||||||
pos_start: Tuple,
|
|
||||||
device: torch.device,
|
|
||||||
noise_mask_val: Optional[int] = None,
|
|
||||||
):
|
|
||||||
"""Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
|
|
||||||
ori_len = len(feat)
|
|
||||||
pad_len = (-ori_len) % SEQ_MULTI_OF
|
|
||||||
total_len = ori_len + pad_len
|
|
||||||
|
|
||||||
# Pos IDs
|
|
||||||
ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
|
|
||||||
if pad_len > 0:
|
|
||||||
pad_pos_ids = (
|
|
||||||
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
|
|
||||||
.flatten(0, 2)
|
|
||||||
.repeat(pad_len, 1)
|
|
||||||
)
|
|
||||||
pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
|
|
||||||
padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
|
|
||||||
pad_mask = torch.cat(
|
|
||||||
[
|
|
||||||
torch.zeros(ori_len, dtype=torch.bool, device=device),
|
|
||||||
torch.ones(pad_len, dtype=torch.bool, device=device),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
pos_ids = ori_pos_ids
|
|
||||||
padded_feat = feat
|
|
||||||
pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
|
|
||||||
|
|
||||||
noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
|
|
||||||
return padded_feat, pos_ids, pad_mask, total_len, noise_mask
|
|
||||||
|
|
||||||
def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
|
|
||||||
"""Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
|
|
||||||
pH, pW, pF = patch_size, patch_size, f_patch_size
|
|
||||||
C, F, H, W = image.size()
|
|
||||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
|
||||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
|
||||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
|
||||||
return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
|
|
||||||
|
|
||||||
def patchify_and_embed_omni(
|
|
||||||
self,
|
|
||||||
all_x: List[List[torch.Tensor]],
|
|
||||||
all_cap_feats: List[List[torch.Tensor]],
|
|
||||||
all_siglip_feats: List[List[torch.Tensor]],
|
|
||||||
patch_size: int = 2,
|
|
||||||
f_patch_size: int = 1,
|
|
||||||
images_noise_mask: List[List[int]] = None,
|
|
||||||
):
|
|
||||||
"""Patchify for omni mode: multiple images per batch item with noise masks."""
|
|
||||||
bsz = len(all_x)
|
|
||||||
device = all_x[0][-1].device
|
|
||||||
dtype = all_x[0][-1].dtype
|
|
||||||
|
|
||||||
all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], []
|
|
||||||
all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], []
|
|
||||||
all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], []
|
|
||||||
|
|
||||||
for i in range(bsz):
|
|
||||||
num_images = len(all_x[i])
|
|
||||||
cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], []
|
|
||||||
cap_end_pos = []
|
|
||||||
cap_cu_len = 1
|
|
||||||
|
|
||||||
# Process captions
|
|
||||||
for j, cap_item in enumerate(all_cap_feats[i]):
|
|
||||||
noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1
|
|
||||||
cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids(
|
|
||||||
cap_item,
|
|
||||||
(len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1),
|
|
||||||
(cap_cu_len, 0, 0),
|
|
||||||
device,
|
|
||||||
noise_val,
|
|
||||||
)
|
|
||||||
cap_feats_list.append(cap_out)
|
|
||||||
cap_pos_list.append(cap_pos)
|
|
||||||
cap_mask_list.append(cap_mask)
|
|
||||||
cap_lens.append(cap_len)
|
|
||||||
cap_noise.extend(cap_nm)
|
|
||||||
cap_cu_len += len(cap_item)
|
|
||||||
cap_end_pos.append(cap_cu_len)
|
|
||||||
cap_cu_len += 2 # for image vae and siglip tokens
|
|
||||||
|
|
||||||
all_cap_out.append(torch.cat(cap_feats_list, dim=0))
|
|
||||||
all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0))
|
|
||||||
all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0))
|
|
||||||
all_cap_len.append(cap_lens)
|
|
||||||
all_cap_noise_mask.append(cap_noise)
|
|
||||||
|
|
||||||
# Process images
|
|
||||||
x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], []
|
|
||||||
for j, x_item in enumerate(all_x[i]):
|
|
||||||
noise_val = images_noise_mask[i][j]
|
|
||||||
if x_item is not None:
|
|
||||||
x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size)
|
|
||||||
x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids(
|
|
||||||
x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val
|
|
||||||
)
|
|
||||||
x_size.append(size)
|
|
||||||
else:
|
|
||||||
x_len = SEQ_MULTI_OF
|
|
||||||
x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device)
|
|
||||||
x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1)
|
|
||||||
x_mask = torch.ones(x_len, dtype=torch.bool, device=device)
|
|
||||||
x_nm = [noise_val] * x_len
|
|
||||||
x_size.append(None)
|
|
||||||
x_feats_list.append(x_out)
|
|
||||||
x_pos_list.append(x_pos)
|
|
||||||
x_mask_list.append(x_mask)
|
|
||||||
x_lens.append(x_len)
|
|
||||||
x_noise.extend(x_nm)
|
|
||||||
|
|
||||||
all_x_out.append(torch.cat(x_feats_list, dim=0))
|
|
||||||
all_x_pos_ids.append(torch.cat(x_pos_list, dim=0))
|
|
||||||
all_x_pad_mask.append(torch.cat(x_mask_list, dim=0))
|
|
||||||
all_x_size.append(x_size)
|
|
||||||
all_x_len.append(x_lens)
|
|
||||||
all_x_noise_mask.append(x_noise)
|
|
||||||
|
|
||||||
# Process siglip
|
|
||||||
if all_siglip_feats[i] is None:
|
|
||||||
all_sig_len.append([0] * num_images)
|
|
||||||
all_sig_out.append(None)
|
|
||||||
else:
|
|
||||||
sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], []
|
|
||||||
for j, sig_item in enumerate(all_siglip_feats[i]):
|
|
||||||
noise_val = images_noise_mask[i][j]
|
|
||||||
if sig_item is not None:
|
|
||||||
sig_H, sig_W, sig_C = sig_item.size()
|
|
||||||
sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C)
|
|
||||||
sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids(
|
|
||||||
sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val
|
|
||||||
)
|
|
||||||
# Scale position IDs to match x resolution
|
|
||||||
if x_size[j] is not None:
|
|
||||||
sig_pos = sig_pos.float()
|
|
||||||
sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1)
|
|
||||||
sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1)
|
|
||||||
sig_pos = sig_pos.to(torch.int32)
|
|
||||||
else:
|
|
||||||
sig_len = SEQ_MULTI_OF
|
|
||||||
sig_out = torch.zeros((sig_len, self.siglip_feat_dim), dtype=dtype, device=device)
|
|
||||||
sig_pos = (
|
|
||||||
self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1)
|
|
||||||
)
|
|
||||||
sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device)
|
|
||||||
sig_nm = [noise_val] * sig_len
|
|
||||||
sig_feats_list.append(sig_out)
|
|
||||||
sig_pos_list.append(sig_pos)
|
|
||||||
sig_mask_list.append(sig_mask)
|
|
||||||
sig_lens.append(sig_len)
|
|
||||||
sig_noise.extend(sig_nm)
|
|
||||||
|
|
||||||
all_sig_out.append(torch.cat(sig_feats_list, dim=0))
|
|
||||||
all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0))
|
|
||||||
all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0))
|
|
||||||
all_sig_len.append(sig_lens)
|
|
||||||
all_sig_noise_mask.append(sig_noise)
|
|
||||||
|
|
||||||
# Compute x position offsets
|
|
||||||
all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)]
|
|
||||||
|
|
||||||
return (
|
|
||||||
all_x_out,
|
|
||||||
all_cap_out,
|
|
||||||
all_sig_out,
|
|
||||||
all_x_size,
|
|
||||||
all_x_pos_ids,
|
|
||||||
all_cap_pos_ids,
|
all_cap_pos_ids,
|
||||||
all_sig_pos_ids,
|
all_image_pad_mask,
|
||||||
all_x_pad_mask,
|
|
||||||
all_cap_pad_mask,
|
all_cap_pad_mask,
|
||||||
all_sig_pad_mask,
|
|
||||||
all_x_pos_offsets,
|
|
||||||
all_x_noise_mask,
|
|
||||||
all_cap_noise_mask,
|
|
||||||
all_sig_noise_mask,
|
|
||||||
)
|
)
|
||||||
return all_x_out, all_cap_out, all_sig_out, {
|
|
||||||
"x_size": x_size,
|
|
||||||
"x_pos_ids": all_x_pos_ids,
|
|
||||||
"cap_pos_ids": all_cap_pos_ids,
|
|
||||||
"sig_pos_ids": all_sig_pos_ids,
|
|
||||||
"x_pad_mask": all_x_pad_mask,
|
|
||||||
"cap_pad_mask": all_cap_pad_mask,
|
|
||||||
"sig_pad_mask": all_sig_pad_mask,
|
|
||||||
"x_pos_offsets": all_x_pos_offsets,
|
|
||||||
"x_noise_mask": all_x_noise_mask,
|
|
||||||
"cap_noise_mask": all_cap_noise_mask,
|
|
||||||
"sig_noise_mask": all_sig_noise_mask,
|
|
||||||
}
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: List[torch.Tensor],
|
x: List[torch.Tensor],
|
||||||
t,
|
t,
|
||||||
cap_feats: List[torch.Tensor],
|
cap_feats: List[torch.Tensor],
|
||||||
siglip_feats = None,
|
|
||||||
image_noise_mask = None,
|
|
||||||
patch_size=2,
|
patch_size=2,
|
||||||
f_patch_size=1,
|
f_patch_size=1,
|
||||||
use_gradient_checkpointing=False,
|
use_gradient_checkpointing=False,
|
||||||
use_gradient_checkpointing_offload=False,
|
use_gradient_checkpointing_offload=False,
|
||||||
):
|
):
|
||||||
assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size
|
assert patch_size in self.all_patch_size
|
||||||
omni_mode = isinstance(x[0], list)
|
assert f_patch_size in self.all_f_patch_size
|
||||||
device = x[0][-1].device if omni_mode else x[0].device
|
|
||||||
|
|
||||||
if omni_mode:
|
bsz = len(x)
|
||||||
# Dual embeddings: noisy (t) and clean (t=1)
|
device = x[0].device
|
||||||
t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1])
|
t = t * self.t_scale
|
||||||
t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1])
|
t = self.t_embedder(t)
|
||||||
adaln_input = None
|
|
||||||
else:
|
|
||||||
# Single embedding for all tokens
|
|
||||||
adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0])
|
|
||||||
t_noisy = t_clean = None
|
|
||||||
|
|
||||||
# Patchify
|
adaln_input = t
|
||||||
if omni_mode:
|
|
||||||
(
|
(
|
||||||
x,
|
x,
|
||||||
cap_feats,
|
cap_feats,
|
||||||
siglip_feats,
|
x_size,
|
||||||
x_size,
|
x_pos_ids,
|
||||||
x_pos_ids,
|
cap_pos_ids,
|
||||||
cap_pos_ids,
|
x_inner_pad_mask,
|
||||||
siglip_pos_ids,
|
cap_inner_pad_mask,
|
||||||
x_pad_mask,
|
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||||
cap_pad_mask,
|
|
||||||
siglip_pad_mask,
|
|
||||||
x_pos_offsets,
|
|
||||||
x_noise_mask,
|
|
||||||
cap_noise_mask,
|
|
||||||
siglip_noise_mask,
|
|
||||||
) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask)
|
|
||||||
else:
|
|
||||||
(
|
|
||||||
x,
|
|
||||||
cap_feats,
|
|
||||||
x_size,
|
|
||||||
x_pos_ids,
|
|
||||||
cap_pos_ids,
|
|
||||||
x_pad_mask,
|
|
||||||
cap_pad_mask,
|
|
||||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
|
||||||
x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None
|
|
||||||
|
|
||||||
# x embed & refine
|
# x embed & refine
|
||||||
x_seqlens = [len(xi) for xi in x]
|
x_item_seqlens = [len(_) for _ in x]
|
||||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed
|
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
|
||||||
x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence(
|
x_max_item_seqlen = max(x_item_seqlens)
|
||||||
list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device
|
|
||||||
)
|
x = torch.cat(x, dim=0)
|
||||||
|
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
|
||||||
|
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token.to(dtype=x.dtype, device=x.device)
|
||||||
|
x = list(x.split(x_item_seqlens, dim=0))
|
||||||
|
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
|
||||||
|
|
||||||
|
x = pad_sequence(x, batch_first=True, padding_value=0.0)
|
||||||
|
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
|
||||||
|
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
||||||
|
for i, seq_len in enumerate(x_item_seqlens):
|
||||||
|
x_attn_mask[i, :seq_len] = 1
|
||||||
|
|
||||||
for layer in self.noise_refiner:
|
for layer in self.noise_refiner:
|
||||||
x = gradient_checkpoint_forward(
|
x = gradient_checkpoint_forward(
|
||||||
layer,
|
layer,
|
||||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
x=x, attn_mask=x_mask, freqs_cis=x_freqs, adaln_input=adaln_input, noise_mask=x_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean,
|
x=x,
|
||||||
|
attn_mask=x_attn_mask,
|
||||||
|
freqs_cis=x_freqs_cis,
|
||||||
|
adaln_input=adaln_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Cap embed & refine
|
# cap embed & refine
|
||||||
cap_seqlens = [len(ci) for ci in cap_feats]
|
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||||
cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed
|
assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
|
||||||
cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence(
|
cap_max_item_seqlen = max(cap_item_seqlens)
|
||||||
list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device
|
|
||||||
)
|
cap_feats = torch.cat(cap_feats, dim=0)
|
||||||
|
cap_feats = self.cap_embedder(cap_feats)
|
||||||
|
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token.to(dtype=x.dtype, device=x.device)
|
||||||
|
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
|
||||||
|
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
|
||||||
|
|
||||||
|
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
|
||||||
|
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
|
||||||
|
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
||||||
|
for i, seq_len in enumerate(cap_item_seqlens):
|
||||||
|
cap_attn_mask[i, :seq_len] = 1
|
||||||
|
|
||||||
for layer in self.context_refiner:
|
for layer in self.context_refiner:
|
||||||
cap_feats = gradient_checkpoint_forward(
|
cap_feats = gradient_checkpoint_forward(
|
||||||
@@ -1081,68 +581,41 @@ class ZImageDiT(nn.Module):
|
|||||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
x=cap_feats,
|
x=cap_feats,
|
||||||
attn_mask=cap_mask,
|
attn_mask=cap_attn_mask,
|
||||||
freqs_cis=cap_freqs,
|
freqs_cis=cap_freqs_cis,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Siglip embed & refine
|
# unified
|
||||||
siglip_seqlens = siglip_freqs = None
|
unified = []
|
||||||
if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None:
|
unified_freqs_cis = []
|
||||||
siglip_seqlens = [len(si) for si in siglip_feats]
|
for i in range(bsz):
|
||||||
siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed
|
x_len = x_item_seqlens[i]
|
||||||
siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence(
|
cap_len = cap_item_seqlens[i]
|
||||||
list(siglip_feats.split(siglip_seqlens, dim=0)),
|
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
|
||||||
siglip_pos_ids,
|
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
|
||||||
siglip_pad_mask,
|
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
|
||||||
self.siglip_pad_token,
|
assert unified_item_seqlens == [len(_) for _ in unified]
|
||||||
None,
|
unified_max_item_seqlen = max(unified_item_seqlens)
|
||||||
device,
|
|
||||||
)
|
|
||||||
|
|
||||||
for layer in self.siglip_refiner:
|
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||||
siglip_feats = gradient_checkpoint_forward(
|
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
|
||||||
layer,
|
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
||||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
for i, seq_len in enumerate(unified_item_seqlens):
|
||||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
unified_attn_mask[i, :seq_len] = 1
|
||||||
x=siglip_feats, attn_mask=siglip_mask, freqs_cis=siglip_freqs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Unified sequence
|
for layer in self.layers:
|
||||||
unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence(
|
|
||||||
x,
|
|
||||||
x_freqs,
|
|
||||||
x_seqlens,
|
|
||||||
x_noise_mask,
|
|
||||||
cap_feats,
|
|
||||||
cap_freqs,
|
|
||||||
cap_seqlens,
|
|
||||||
cap_noise_mask,
|
|
||||||
siglip_feats,
|
|
||||||
siglip_freqs,
|
|
||||||
siglip_seqlens,
|
|
||||||
siglip_noise_mask,
|
|
||||||
omni_mode,
|
|
||||||
device,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Main transformer layers
|
|
||||||
for layer_idx, layer in enumerate(self.layers):
|
|
||||||
unified = gradient_checkpoint_forward(
|
unified = gradient_checkpoint_forward(
|
||||||
layer,
|
layer,
|
||||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
x=unified, attn_mask=unified_mask, freqs_cis=unified_freqs, adaln_input=adaln_input, noise_mask=unified_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean
|
x=unified,
|
||||||
|
attn_mask=unified_attn_mask,
|
||||||
|
freqs_cis=unified_freqs_cis,
|
||||||
|
adaln_input=adaln_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
unified = (
|
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
|
||||||
self.all_final_layer[f"{patch_size}-{f_patch_size}"](
|
unified = list(unified.unbind(dim=0))
|
||||||
unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean
|
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
|
||||||
)
|
|
||||||
if omni_mode
|
|
||||||
else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Unpatchify
|
return x, {}
|
||||||
x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|||||||
@@ -1,189 +0,0 @@
|
|||||||
import torch
|
|
||||||
from .qwen_image_image2lora import ImageEmbeddingToLoraMatrix, SequencialMLP
|
|
||||||
|
|
||||||
|
|
||||||
class LoRATrainerBlock(torch.nn.Module):
|
|
||||||
def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024, prefix="transformer_blocks"):
|
|
||||||
super().__init__()
|
|
||||||
self.prefix = prefix
|
|
||||||
self.lora_patterns = lora_patterns
|
|
||||||
self.block_id = block_id
|
|
||||||
self.layers = []
|
|
||||||
for name, lora_a_dim, lora_b_dim in self.lora_patterns:
|
|
||||||
self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank))
|
|
||||||
self.layers = torch.nn.ModuleList(self.layers)
|
|
||||||
if use_residual:
|
|
||||||
self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)
|
|
||||||
else:
|
|
||||||
self.proj_residual = None
|
|
||||||
|
|
||||||
def forward(self, x, residual=None):
|
|
||||||
lora = {}
|
|
||||||
if self.proj_residual is not None: residual = self.proj_residual(residual)
|
|
||||||
for lora_pattern, layer in zip(self.lora_patterns, self.layers):
|
|
||||||
name = lora_pattern[0]
|
|
||||||
lora_a, lora_b = layer(x, residual=residual)
|
|
||||||
lora[f"{self.prefix}.{self.block_id}.{name}.lora_A.default.weight"] = lora_a
|
|
||||||
lora[f"{self.prefix}.{self.block_id}.{name}.lora_B.default.weight"] = lora_b
|
|
||||||
return lora
|
|
||||||
|
|
||||||
|
|
||||||
class ZImageImage2LoRAComponent(torch.nn.Module):
|
|
||||||
def __init__(self, lora_patterns, prefix, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):
|
|
||||||
super().__init__()
|
|
||||||
self.lora_patterns = lora_patterns
|
|
||||||
self.num_blocks = num_blocks
|
|
||||||
self.blocks = []
|
|
||||||
for lora_patterns in self.lora_patterns:
|
|
||||||
for block_id in range(self.num_blocks):
|
|
||||||
self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim, prefix=prefix))
|
|
||||||
self.blocks = torch.nn.ModuleList(self.blocks)
|
|
||||||
self.residual_scale = 0.05
|
|
||||||
self.use_residual = use_residual
|
|
||||||
|
|
||||||
def forward(self, x, residual=None):
|
|
||||||
if residual is not None:
|
|
||||||
if self.use_residual:
|
|
||||||
residual = residual * self.residual_scale
|
|
||||||
else:
|
|
||||||
residual = None
|
|
||||||
lora = {}
|
|
||||||
for block in self.blocks:
|
|
||||||
lora.update(block(x, residual))
|
|
||||||
return lora
|
|
||||||
|
|
||||||
|
|
||||||
class ZImageImage2LoRAModel(torch.nn.Module):
|
|
||||||
def __init__(self, use_residual=False, compress_dim=64, rank=4, residual_length=64+7, residual_mid_dim=1024):
|
|
||||||
super().__init__()
|
|
||||||
lora_patterns = [
|
|
||||||
[
|
|
||||||
("attention.to_q", 3840, 3840),
|
|
||||||
("attention.to_k", 3840, 3840),
|
|
||||||
("attention.to_v", 3840, 3840),
|
|
||||||
("attention.to_out.0", 3840, 3840),
|
|
||||||
],
|
|
||||||
[
|
|
||||||
("feed_forward.w1", 3840, 10240),
|
|
||||||
("feed_forward.w2", 10240, 3840),
|
|
||||||
("feed_forward.w3", 3840, 10240),
|
|
||||||
],
|
|
||||||
]
|
|
||||||
config = {
|
|
||||||
"lora_patterns": lora_patterns,
|
|
||||||
"use_residual": use_residual,
|
|
||||||
"compress_dim": compress_dim,
|
|
||||||
"rank": rank,
|
|
||||||
"residual_length": residual_length,
|
|
||||||
"residual_mid_dim": residual_mid_dim,
|
|
||||||
}
|
|
||||||
self.layers_lora = ZImageImage2LoRAComponent(
|
|
||||||
prefix="layers",
|
|
||||||
num_blocks=30,
|
|
||||||
**config,
|
|
||||||
)
|
|
||||||
self.context_refiner_lora = ZImageImage2LoRAComponent(
|
|
||||||
prefix="context_refiner",
|
|
||||||
num_blocks=2,
|
|
||||||
**config,
|
|
||||||
)
|
|
||||||
self.noise_refiner_lora = ZImageImage2LoRAComponent(
|
|
||||||
prefix="noise_refiner",
|
|
||||||
num_blocks=2,
|
|
||||||
**config,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x, residual=None):
|
|
||||||
lora = {}
|
|
||||||
lora.update(self.layers_lora(x, residual=residual))
|
|
||||||
lora.update(self.context_refiner_lora(x, residual=residual))
|
|
||||||
lora.update(self.noise_refiner_lora(x, residual=residual))
|
|
||||||
return lora
|
|
||||||
|
|
||||||
def initialize_weights(self):
|
|
||||||
state_dict = self.state_dict()
|
|
||||||
for name in state_dict:
|
|
||||||
if ".proj_a." in name:
|
|
||||||
state_dict[name] = state_dict[name] * 0.3
|
|
||||||
elif ".proj_b.proj_out." in name:
|
|
||||||
state_dict[name] = state_dict[name] * 0
|
|
||||||
elif ".proj_residual.proj_out." in name:
|
|
||||||
state_dict[name] = state_dict[name] * 0.3
|
|
||||||
self.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageEmb2LoRAWeightCompressed(torch.nn.Module):
|
|
||||||
def __init__(self, in_dim, out_dim, emb_dim, rank):
|
|
||||||
super().__init__()
|
|
||||||
self.lora_a = torch.nn.Parameter(torch.randn((rank, in_dim)))
|
|
||||||
self.lora_b = torch.nn.Parameter(torch.randn((out_dim, rank)))
|
|
||||||
self.proj = torch.nn.Linear(emb_dim, rank * rank, bias=True)
|
|
||||||
self.rank = rank
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.proj(x).view(self.rank, self.rank)
|
|
||||||
lora_a = x @ self.lora_a
|
|
||||||
lora_b = self.lora_b
|
|
||||||
return lora_a, lora_b
|
|
||||||
|
|
||||||
|
|
||||||
class ZImageImage2LoRAModelCompressed(torch.nn.Module):
|
|
||||||
def __init__(self, emb_dim=1536+4096, rank=32):
|
|
||||||
super().__init__()
|
|
||||||
target_layers = [
|
|
||||||
("attention.to_q", 3840, 3840),
|
|
||||||
("attention.to_k", 3840, 3840),
|
|
||||||
("attention.to_v", 3840, 3840),
|
|
||||||
("attention.to_out.0", 3840, 3840),
|
|
||||||
("feed_forward.w1", 3840, 10240),
|
|
||||||
("feed_forward.w2", 10240, 3840),
|
|
||||||
("feed_forward.w3", 3840, 10240),
|
|
||||||
]
|
|
||||||
self.lora_patterns = [
|
|
||||||
{
|
|
||||||
"prefix": "layers",
|
|
||||||
"num_layers": 30,
|
|
||||||
"target_layers": target_layers,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"prefix": "context_refiner",
|
|
||||||
"num_layers": 2,
|
|
||||||
"target_layers": target_layers,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"prefix": "noise_refiner",
|
|
||||||
"num_layers": 2,
|
|
||||||
"target_layers": target_layers,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
module_dict = {}
|
|
||||||
for lora_pattern in self.lora_patterns:
|
|
||||||
prefix, num_layers, target_layers = lora_pattern["prefix"], lora_pattern["num_layers"], lora_pattern["target_layers"]
|
|
||||||
for layer_id in range(num_layers):
|
|
||||||
for layer_name, in_dim, out_dim in target_layers:
|
|
||||||
name = f"{prefix}.{layer_id}.{layer_name}".replace(".", "___")
|
|
||||||
model = ImageEmb2LoRAWeightCompressed(in_dim, out_dim, emb_dim, rank)
|
|
||||||
module_dict[name] = model
|
|
||||||
self.module_dict = torch.nn.ModuleDict(module_dict)
|
|
||||||
|
|
||||||
def forward(self, x, residual=None):
|
|
||||||
lora = {}
|
|
||||||
for name, module in self.module_dict.items():
|
|
||||||
name = name.replace("___", ".")
|
|
||||||
name_a, name_b = f"{name}.lora_A.default.weight", f"{name}.lora_B.default.weight"
|
|
||||||
lora_a, lora_b = module(x)
|
|
||||||
lora[name_a] = lora_a
|
|
||||||
lora[name_b] = lora_b
|
|
||||||
return lora
|
|
||||||
|
|
||||||
def initialize_weights(self):
|
|
||||||
state_dict = self.state_dict()
|
|
||||||
for name in state_dict:
|
|
||||||
if "lora_b" in name:
|
|
||||||
state_dict[name] = state_dict[name] * 0
|
|
||||||
elif "lora_a" in name:
|
|
||||||
state_dict[name] = state_dict[name] * 0.2
|
|
||||||
elif "proj.weight" in name:
|
|
||||||
print(name)
|
|
||||||
state_dict[name] = state_dict[name] * 0.2
|
|
||||||
self.load_state_dict(state_dict)
|
|
||||||
@@ -4,20 +4,15 @@ from typing import Union
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from math import prod
|
|
||||||
|
|
||||||
from ..diffusion import FlowMatchScheduler
|
from ..diffusion import FlowMatchScheduler
|
||||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||||
from ..utils.lora.merge import merge_lora
|
|
||||||
|
|
||||||
from ..models.qwen_image_dit import QwenImageDiT
|
from ..models.qwen_image_dit import QwenImageDiT
|
||||||
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||||
from ..models.qwen_image_vae import QwenImageVAE
|
from ..models.qwen_image_vae import QwenImageVAE
|
||||||
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||||
from ..models.siglip2_image_encoder import Siglip2ImageEncoder
|
|
||||||
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
|
|
||||||
from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel
|
|
||||||
|
|
||||||
|
|
||||||
class QwenImagePipeline(BasePipeline):
|
class QwenImagePipeline(BasePipeline):
|
||||||
@@ -35,11 +30,6 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
self.vae: QwenImageVAE = None
|
self.vae: QwenImageVAE = None
|
||||||
self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None
|
self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None
|
||||||
self.tokenizer: Qwen2Tokenizer = None
|
self.tokenizer: Qwen2Tokenizer = None
|
||||||
self.siglip2_image_encoder: Siglip2ImageEncoder = None
|
|
||||||
self.dinov3_image_encoder: DINOv3ImageEncoder = None
|
|
||||||
self.image2lora_style: QwenImageImage2LoRAModel = None
|
|
||||||
self.image2lora_coarse: QwenImageImage2LoRAModel = None
|
|
||||||
self.image2lora_fine: QwenImageImage2LoRAModel = None
|
|
||||||
self.processor: Qwen2VLProcessor = None
|
self.processor: Qwen2VLProcessor = None
|
||||||
self.in_iteration_models = ("dit", "blockwise_controlnet")
|
self.in_iteration_models = ("dit", "blockwise_controlnet")
|
||||||
self.units = [
|
self.units = [
|
||||||
@@ -48,7 +38,6 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
QwenImageUnit_InputImageEmbedder(),
|
QwenImageUnit_InputImageEmbedder(),
|
||||||
QwenImageUnit_Inpaint(),
|
QwenImageUnit_Inpaint(),
|
||||||
QwenImageUnit_EditImageEmbedder(),
|
QwenImageUnit_EditImageEmbedder(),
|
||||||
QwenImageUnit_LayerInputImageEmbedder(),
|
|
||||||
QwenImageUnit_ContextImageEmbedder(),
|
QwenImageUnit_ContextImageEmbedder(),
|
||||||
QwenImageUnit_PromptEmbedder(),
|
QwenImageUnit_PromptEmbedder(),
|
||||||
QwenImageUnit_EntityControl(),
|
QwenImageUnit_EntityControl(),
|
||||||
@@ -83,11 +72,6 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
processor_config.download_if_necessary()
|
processor_config.download_if_necessary()
|
||||||
from transformers import Qwen2VLProcessor
|
from transformers import Qwen2VLProcessor
|
||||||
pipe.processor = Qwen2VLProcessor.from_pretrained(processor_config.path)
|
pipe.processor = Qwen2VLProcessor.from_pretrained(processor_config.path)
|
||||||
pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder")
|
|
||||||
pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder")
|
|
||||||
pipe.image2lora_style = model_pool.fetch_model("qwen_image_image2lora_style")
|
|
||||||
pipe.image2lora_coarse = model_pool.fetch_model("qwen_image_image2lora_coarse")
|
|
||||||
pipe.image2lora_fine = model_pool.fetch_model("qwen_image_image2lora_fine")
|
|
||||||
|
|
||||||
# VRAM Management
|
# VRAM Management
|
||||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||||
@@ -127,11 +111,6 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
edit_image: Image.Image = None,
|
edit_image: Image.Image = None,
|
||||||
edit_image_auto_resize: bool = True,
|
edit_image_auto_resize: bool = True,
|
||||||
edit_rope_interpolation: bool = False,
|
edit_rope_interpolation: bool = False,
|
||||||
# Qwen-Image-Edit-2511
|
|
||||||
zero_cond_t: bool = False,
|
|
||||||
# Qwen-Image-Layered
|
|
||||||
layer_input_image: Image.Image = None,
|
|
||||||
layer_num: int = None,
|
|
||||||
# In-context control
|
# In-context control
|
||||||
context_image: Image.Image = None,
|
context_image: Image.Image = None,
|
||||||
# Tile
|
# Tile
|
||||||
@@ -163,9 +142,6 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
|
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
|
||||||
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation,
|
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation,
|
||||||
"context_image": context_image,
|
"context_image": context_image,
|
||||||
"zero_cond_t": zero_cond_t,
|
|
||||||
"layer_input_image": layer_input_image,
|
|
||||||
"layer_num": layer_num,
|
|
||||||
}
|
}
|
||||||
for unit in self.units:
|
for unit in self.units:
|
||||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||||
@@ -185,10 +161,7 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
# Decode
|
# Decode
|
||||||
self.load_models_to_device(['vae'])
|
self.load_models_to_device(['vae'])
|
||||||
image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
if layer_num is None:
|
image = self.vae_output_to_image(image)
|
||||||
image = self.vae_output_to_image(image)
|
|
||||||
else:
|
|
||||||
image = [self.vae_output_to_image(i, pattern="C H W") for i in image]
|
|
||||||
self.load_models_to_device([])
|
self.load_models_to_device([])
|
||||||
|
|
||||||
return image
|
return image
|
||||||
@@ -239,15 +212,12 @@ class QwenImageUnit_ShapeChecker(PipelineUnit):
|
|||||||
class QwenImageUnit_NoiseInitializer(PipelineUnit):
|
class QwenImageUnit_NoiseInitializer(PipelineUnit):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
input_params=("height", "width", "seed", "rand_device", "layer_num"),
|
input_params=("height", "width", "seed", "rand_device"),
|
||||||
output_params=("noise",),
|
output_params=("noise",),
|
||||||
)
|
)
|
||||||
|
|
||||||
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device, layer_num):
|
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device):
|
||||||
if layer_num is None:
|
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||||
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
|
||||||
else:
|
|
||||||
noise = pipe.generate_noise((layer_num + 1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
|
||||||
return {"noise": noise}
|
return {"noise": noise}
|
||||||
|
|
||||||
|
|
||||||
@@ -264,15 +234,8 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
|
|||||||
if input_image is None:
|
if input_image is None:
|
||||||
return {"latents": noise, "input_latents": None}
|
return {"latents": noise, "input_latents": None}
|
||||||
pipe.load_models_to_device(['vae'])
|
pipe.load_models_to_device(['vae'])
|
||||||
if isinstance(input_image, list):
|
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||||
input_latents = []
|
input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
for image in input_image:
|
|
||||||
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
|
||||||
input_latents.append(pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride))
|
|
||||||
input_latents = torch.concat(input_latents, dim=0)
|
|
||||||
else:
|
|
||||||
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
|
||||||
input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
if pipe.scheduler.training:
|
if pipe.scheduler.training:
|
||||||
return {"latents": noise, "input_latents": input_latents}
|
return {"latents": noise, "input_latents": input_latents}
|
||||||
else:
|
else:
|
||||||
@@ -280,22 +243,6 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
|
|||||||
return {"latents": latents, "input_latents": input_latents}
|
return {"latents": latents, "input_latents": input_latents}
|
||||||
|
|
||||||
|
|
||||||
class QwenImageUnit_LayerInputImageEmbedder(PipelineUnit):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
input_params=("layer_input_image", "tiled", "tile_size", "tile_stride"),
|
|
||||||
output_params=("layer_input_latents",),
|
|
||||||
onload_model_names=("vae",)
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, pipe: QwenImagePipeline, layer_input_image, tiled, tile_size, tile_stride):
|
|
||||||
if layer_input_image is None:
|
|
||||||
return {}
|
|
||||||
pipe.load_models_to_device(['vae'])
|
|
||||||
image = pipe.preprocess_image(layer_input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
|
||||||
latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
return {"layer_input_latents": latents}
|
|
||||||
|
|
||||||
|
|
||||||
class QwenImageUnit_Inpaint(PipelineUnit):
|
class QwenImageUnit_Inpaint(PipelineUnit):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -568,116 +515,6 @@ class QwenImageUnit_EditImageEmbedder(PipelineUnit):
|
|||||||
return {"edit_latents": edit_latents, "edit_image": resized_edit_image}
|
return {"edit_latents": edit_latents, "edit_image": resized_edit_image}
|
||||||
|
|
||||||
|
|
||||||
class QwenImageUnit_Image2LoRAEncode(PipelineUnit):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
input_params=("image2lora_images",),
|
|
||||||
output_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"),
|
|
||||||
onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder", "text_encoder"),
|
|
||||||
)
|
|
||||||
from ..core.data.operators import ImageCropAndResize
|
|
||||||
self.processor_lowres = ImageCropAndResize(height=28*8, width=28*8)
|
|
||||||
self.processor_highres = ImageCropAndResize(height=1024, width=1024)
|
|
||||||
|
|
||||||
def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
|
||||||
bool_mask = mask.bool()
|
|
||||||
valid_lengths = bool_mask.sum(dim=1)
|
|
||||||
selected = hidden_states[bool_mask]
|
|
||||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
|
||||||
return split_result
|
|
||||||
|
|
||||||
def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image):
|
|
||||||
prompt = [prompt]
|
|
||||||
template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
|
||||||
drop_idx = 64
|
|
||||||
txt = [template.format(e) for e in prompt]
|
|
||||||
model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
|
|
||||||
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
|
|
||||||
split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
|
||||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
|
||||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
|
||||||
prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
|
|
||||||
prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)
|
|
||||||
return prompt_embeds.view(1, -1)
|
|
||||||
|
|
||||||
def encode_images_using_siglip2(self, pipe: QwenImagePipeline, images: list[Image.Image]):
|
|
||||||
pipe.load_models_to_device(["siglip2_image_encoder"])
|
|
||||||
embs = []
|
|
||||||
for image in images:
|
|
||||||
image = self.processor_highres(image)
|
|
||||||
embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype))
|
|
||||||
embs = torch.stack(embs)
|
|
||||||
return embs
|
|
||||||
|
|
||||||
def encode_images_using_dinov3(self, pipe: QwenImagePipeline, images: list[Image.Image]):
|
|
||||||
pipe.load_models_to_device(["dinov3_image_encoder"])
|
|
||||||
embs = []
|
|
||||||
for image in images:
|
|
||||||
image = self.processor_highres(image)
|
|
||||||
embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype))
|
|
||||||
embs = torch.stack(embs)
|
|
||||||
return embs
|
|
||||||
|
|
||||||
def encode_images_using_qwenvl(self, pipe: QwenImagePipeline, images: list[Image.Image], highres=False):
|
|
||||||
pipe.load_models_to_device(["text_encoder"])
|
|
||||||
embs = []
|
|
||||||
for image in images:
|
|
||||||
image = self.processor_highres(image) if highres else self.processor_lowres(image)
|
|
||||||
embs.append(self.encode_prompt_edit(pipe, prompt="", edit_image=image))
|
|
||||||
embs = torch.stack(embs)
|
|
||||||
return embs
|
|
||||||
|
|
||||||
def encode_images(self, pipe: QwenImagePipeline, images: list[Image.Image]):
|
|
||||||
if images is None:
|
|
||||||
return {}
|
|
||||||
if not isinstance(images, list):
|
|
||||||
images = [images]
|
|
||||||
embs_siglip2 = self.encode_images_using_siglip2(pipe, images)
|
|
||||||
embs_dinov3 = self.encode_images_using_dinov3(pipe, images)
|
|
||||||
x = torch.concat([embs_siglip2, embs_dinov3], dim=-1)
|
|
||||||
residual = None
|
|
||||||
residual_highres = None
|
|
||||||
if pipe.image2lora_coarse is not None:
|
|
||||||
residual = self.encode_images_using_qwenvl(pipe, images, highres=False)
|
|
||||||
if pipe.image2lora_fine is not None:
|
|
||||||
residual_highres = self.encode_images_using_qwenvl(pipe, images, highres=True)
|
|
||||||
return x, residual, residual_highres
|
|
||||||
|
|
||||||
def process(self, pipe: QwenImagePipeline, image2lora_images):
|
|
||||||
if image2lora_images is None:
|
|
||||||
return {}
|
|
||||||
x, residual, residual_highres = self.encode_images(pipe, image2lora_images)
|
|
||||||
return {"image2lora_x": x, "image2lora_residual": residual, "image2lora_residual_highres": residual_highres}
|
|
||||||
|
|
||||||
|
|
||||||
class QwenImageUnit_Image2LoRADecode(PipelineUnit):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
input_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"),
|
|
||||||
output_params=("lora",),
|
|
||||||
onload_model_names=("image2lora_coarse", "image2lora_fine", "image2lora_style"),
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, pipe: QwenImagePipeline, image2lora_x, image2lora_residual, image2lora_residual_highres):
|
|
||||||
if image2lora_x is None:
|
|
||||||
return {}
|
|
||||||
loras = []
|
|
||||||
if pipe.image2lora_style is not None:
|
|
||||||
pipe.load_models_to_device(["image2lora_style"])
|
|
||||||
for x in image2lora_x:
|
|
||||||
loras.append(pipe.image2lora_style(x=x, residual=None))
|
|
||||||
if pipe.image2lora_coarse is not None:
|
|
||||||
pipe.load_models_to_device(["image2lora_coarse"])
|
|
||||||
for x, residual in zip(image2lora_x, image2lora_residual):
|
|
||||||
loras.append(pipe.image2lora_coarse(x=x, residual=residual))
|
|
||||||
if pipe.image2lora_fine is not None:
|
|
||||||
pipe.load_models_to_device(["image2lora_fine"])
|
|
||||||
for x, residual in zip(image2lora_x, image2lora_residual_highres):
|
|
||||||
loras.append(pipe.image2lora_fine(x=x, residual=residual))
|
|
||||||
lora = merge_lora(loras, alpha=1 / len(image2lora_x))
|
|
||||||
return {"lora": lora}
|
|
||||||
|
|
||||||
|
|
||||||
class QwenImageUnit_ContextImageEmbedder(PipelineUnit):
|
class QwenImageUnit_ContextImageEmbedder(PipelineUnit):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -712,26 +549,18 @@ def model_fn_qwen_image(
|
|||||||
entity_prompt_emb_mask=None,
|
entity_prompt_emb_mask=None,
|
||||||
entity_masks=None,
|
entity_masks=None,
|
||||||
edit_latents=None,
|
edit_latents=None,
|
||||||
layer_input_latents=None,
|
|
||||||
layer_num=None,
|
|
||||||
context_latents=None,
|
context_latents=None,
|
||||||
enable_fp8_attention=False,
|
enable_fp8_attention=False,
|
||||||
use_gradient_checkpointing=False,
|
use_gradient_checkpointing=False,
|
||||||
use_gradient_checkpointing_offload=False,
|
use_gradient_checkpointing_offload=False,
|
||||||
edit_rope_interpolation=False,
|
edit_rope_interpolation=False,
|
||||||
zero_cond_t=False,
|
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if layer_num is None:
|
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
|
||||||
layer_num = 1
|
|
||||||
img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)]
|
|
||||||
else:
|
|
||||||
layer_num = layer_num + 1
|
|
||||||
img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)] * layer_num
|
|
||||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
||||||
timestep = timestep / 1000
|
timestep = timestep / 1000
|
||||||
|
|
||||||
image = rearrange(latents, "(B N) C (H P) (W Q) -> B (N H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2, N=layer_num)
|
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||||
image_seq_len = image.shape[1]
|
image_seq_len = image.shape[1]
|
||||||
|
|
||||||
if context_latents is not None:
|
if context_latents is not None:
|
||||||
@@ -743,27 +572,9 @@ def model_fn_qwen_image(
|
|||||||
img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list]
|
img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list]
|
||||||
edit_image = [rearrange(e, "B C (H P) (W Q) -> B (H W) (C P Q)", H=e.shape[2]//2, W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list]
|
edit_image = [rearrange(e, "B C (H P) (W Q) -> B (H W) (C P Q)", H=e.shape[2]//2, W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list]
|
||||||
image = torch.cat([image] + edit_image, dim=1)
|
image = torch.cat([image] + edit_image, dim=1)
|
||||||
if layer_input_latents is not None:
|
|
||||||
layer_num = layer_num + 1
|
|
||||||
img_shapes += [(layer_input_latents.shape[0], layer_input_latents.shape[2]//2, layer_input_latents.shape[3]//2)]
|
|
||||||
layer_input_latents = rearrange(layer_input_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
|
||||||
image = torch.cat([image, layer_input_latents], dim=1)
|
|
||||||
|
|
||||||
image = dit.img_in(image)
|
image = dit.img_in(image)
|
||||||
if zero_cond_t:
|
conditioning = dit.time_text_embed(timestep, image.dtype)
|
||||||
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
|
||||||
modulate_index = torch.tensor(
|
|
||||||
[[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in [img_shapes]],
|
|
||||||
device=timestep.device,
|
|
||||||
dtype=torch.int,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
modulate_index = None
|
|
||||||
conditioning = dit.time_text_embed(
|
|
||||||
timestep,
|
|
||||||
image.dtype,
|
|
||||||
addition_t_cond=None if not dit.time_text_embed.use_additional_t_cond else torch.tensor([0]).to(device=image.device, dtype=torch.long)
|
|
||||||
)
|
|
||||||
|
|
||||||
if entity_prompt_emb is not None:
|
if entity_prompt_emb is not None:
|
||||||
text, image_rotary_emb, attention_mask = dit.process_entity_masks(
|
text, image_rotary_emb, attention_mask = dit.process_entity_masks(
|
||||||
@@ -793,7 +604,6 @@ def model_fn_qwen_image(
|
|||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
enable_fp8_attention=enable_fp8_attention,
|
enable_fp8_attention=enable_fp8_attention,
|
||||||
modulate_index=modulate_index,
|
|
||||||
)
|
)
|
||||||
if blockwise_controlnet_conditioning is not None:
|
if blockwise_controlnet_conditioning is not None:
|
||||||
image_slice = image[:, :image_seq_len].clone()
|
image_slice = image[:, :image_seq_len].clone()
|
||||||
@@ -804,11 +614,9 @@ def model_fn_qwen_image(
|
|||||||
)
|
)
|
||||||
image[:, :image_seq_len] = image_slice + controlnet_output
|
image[:, :image_seq_len] = image_slice + controlnet_output
|
||||||
|
|
||||||
if zero_cond_t:
|
|
||||||
conditioning = conditioning.chunk(2, dim=0)[0]
|
|
||||||
image = dit.norm_out(image, conditioning)
|
image = dit.norm_out(image, conditioning)
|
||||||
image = dit.proj_out(image)
|
image = dit.proj_out(image)
|
||||||
image = image[:, :image_seq_len]
|
image = image[:, :image_seq_len]
|
||||||
|
|
||||||
latents = rearrange(image, "B (N H W) (C P Q) -> (B N) C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2, B=1)
|
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||||
return latents
|
return latents
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||||
if use_usp:
|
if use_usp:
|
||||||
from ..utils.xfuser import initialize_usp
|
from ..utils.xfuser import initialize_usp
|
||||||
initialize_usp(device)
|
initialize_usp()
|
||||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||||
|
|
||||||
# Fetch models
|
# Fetch models
|
||||||
@@ -241,7 +241,6 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
tea_cache_model_id: Optional[str] = "",
|
tea_cache_model_id: Optional[str] = "",
|
||||||
# progress_bar
|
# progress_bar
|
||||||
progress_bar_cmd=tqdm,
|
progress_bar_cmd=tqdm,
|
||||||
output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
|
|
||||||
):
|
):
|
||||||
# Scheduler
|
# Scheduler
|
||||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
||||||
@@ -321,11 +320,9 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
# Decode
|
# Decode
|
||||||
self.load_models_to_device(['vae'])
|
self.load_models_to_device(['vae'])
|
||||||
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
if output_type == "quantized":
|
video = self.vae_output_to_video(video)
|
||||||
video = self.vae_output_to_video(video)
|
|
||||||
elif output_type == "floatpoint":
|
|
||||||
pass
|
|
||||||
self.load_models_to_device([])
|
self.load_models_to_device([])
|
||||||
|
|
||||||
return video
|
return video
|
||||||
|
|
||||||
|
|
||||||
@@ -826,9 +823,9 @@ class WanVideoUnit_S2V(PipelineUnit):
|
|||||||
pipe.load_models_to_device(["vae"])
|
pipe.load_models_to_device(["vae"])
|
||||||
motion_frames = 73
|
motion_frames = 73
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if motion_video is not None:
|
if motion_video is not None and len(motion_video) > 0:
|
||||||
assert motion_video.shape[2] == motion_frames, f"motion video must have {motion_frames} frames, but got {motion_video.shape[2]}"
|
assert len(motion_video) == motion_frames, f"motion video must have {motion_frames} frames, but got {len(motion_video)}"
|
||||||
motion_latents = motion_video
|
motion_latents = pipe.preprocess_video(motion_video)
|
||||||
kwargs["drop_motion_frames"] = False
|
kwargs["drop_motion_frames"] = False
|
||||||
else:
|
else:
|
||||||
motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
|
motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
|||||||
@@ -4,23 +4,16 @@ from typing import Union
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Union, List, Optional, Tuple, Iterable, Dict
|
from typing import Union, List, Optional, Tuple
|
||||||
|
|
||||||
from ..diffusion import FlowMatchScheduler
|
from ..diffusion import FlowMatchScheduler
|
||||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||||
from ..core.data.operators import ImageCropAndResize
|
|
||||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||||
from ..utils.lora import merge_lora
|
|
||||||
|
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
from ..models.z_image_text_encoder import ZImageTextEncoder
|
from ..models.z_image_text_encoder import ZImageTextEncoder
|
||||||
from ..models.z_image_dit import ZImageDiT
|
from ..models.z_image_dit import ZImageDiT
|
||||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||||
from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M
|
|
||||||
from ..models.z_image_controlnet import ZImageControlNet
|
|
||||||
from ..models.siglip2_image_encoder import Siglip2ImageEncoder
|
|
||||||
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
|
|
||||||
from ..models.z_image_image2lora import ZImageImage2LoRAModel
|
|
||||||
|
|
||||||
|
|
||||||
class ZImagePipeline(BasePipeline):
|
class ZImagePipeline(BasePipeline):
|
||||||
@@ -35,22 +28,13 @@ class ZImagePipeline(BasePipeline):
|
|||||||
self.dit: ZImageDiT = None
|
self.dit: ZImageDiT = None
|
||||||
self.vae_encoder: FluxVAEEncoder = None
|
self.vae_encoder: FluxVAEEncoder = None
|
||||||
self.vae_decoder: FluxVAEDecoder = None
|
self.vae_decoder: FluxVAEDecoder = None
|
||||||
self.image_encoder: Siglip2ImageEncoder428M = None
|
|
||||||
self.controlnet: ZImageControlNet = None
|
|
||||||
self.siglip2_image_encoder: Siglip2ImageEncoder = None
|
|
||||||
self.dinov3_image_encoder: DINOv3ImageEncoder = None
|
|
||||||
self.image2lora_style: ZImageImage2LoRAModel = None
|
|
||||||
self.tokenizer: AutoTokenizer = None
|
self.tokenizer: AutoTokenizer = None
|
||||||
self.in_iteration_models = ("dit", "controlnet")
|
self.in_iteration_models = ("dit",)
|
||||||
self.units = [
|
self.units = [
|
||||||
ZImageUnit_ShapeChecker(),
|
ZImageUnit_ShapeChecker(),
|
||||||
ZImageUnit_PromptEmbedder(),
|
ZImageUnit_PromptEmbedder(),
|
||||||
ZImageUnit_NoiseInitializer(),
|
ZImageUnit_NoiseInitializer(),
|
||||||
ZImageUnit_InputImageEmbedder(),
|
ZImageUnit_InputImageEmbedder(),
|
||||||
ZImageUnit_EditImageAutoResize(),
|
|
||||||
ZImageUnit_EditImageEmbedderVAE(),
|
|
||||||
ZImageUnit_EditImageEmbedderSiglip(),
|
|
||||||
ZImageUnit_PAIControlNet(),
|
|
||||||
]
|
]
|
||||||
self.model_fn = model_fn_z_image
|
self.model_fn = model_fn_z_image
|
||||||
|
|
||||||
@@ -72,11 +56,6 @@ class ZImagePipeline(BasePipeline):
|
|||||||
pipe.dit = model_pool.fetch_model("z_image_dit")
|
pipe.dit = model_pool.fetch_model("z_image_dit")
|
||||||
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
|
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
|
||||||
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
|
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
|
||||||
pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m")
|
|
||||||
pipe.controlnet = model_pool.fetch_model("z_image_controlnet")
|
|
||||||
pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder")
|
|
||||||
pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder")
|
|
||||||
pipe.image2lora_style = model_pool.fetch_model("z_image_image2lora_style")
|
|
||||||
if tokenizer_config is not None:
|
if tokenizer_config is not None:
|
||||||
tokenizer_config.download_if_necessary()
|
tokenizer_config.download_if_necessary()
|
||||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||||
@@ -96,9 +75,6 @@ class ZImagePipeline(BasePipeline):
|
|||||||
# Image
|
# Image
|
||||||
input_image: Image.Image = None,
|
input_image: Image.Image = None,
|
||||||
denoising_strength: float = 1.0,
|
denoising_strength: float = 1.0,
|
||||||
# Edit
|
|
||||||
edit_image: Image.Image = None,
|
|
||||||
edit_image_auto_resize: bool = True,
|
|
||||||
# Shape
|
# Shape
|
||||||
height: int = 1024,
|
height: int = 1024,
|
||||||
width: int = 1024,
|
width: int = 1024,
|
||||||
@@ -107,17 +83,11 @@ class ZImagePipeline(BasePipeline):
|
|||||||
rand_device: str = "cpu",
|
rand_device: str = "cpu",
|
||||||
# Steps
|
# Steps
|
||||||
num_inference_steps: int = 8,
|
num_inference_steps: int = 8,
|
||||||
sigma_shift: float = None,
|
|
||||||
# ControlNet
|
|
||||||
controlnet_inputs: List[ControlNetInput] = None,
|
|
||||||
# Image to LoRA
|
|
||||||
image2lora_images: List[Image.Image] = None,
|
|
||||||
positive_only_lora: Dict[str, torch.Tensor] = None,
|
|
||||||
# Progress bar
|
# Progress bar
|
||||||
progress_bar_cmd = tqdm,
|
progress_bar_cmd = tqdm,
|
||||||
):
|
):
|
||||||
# Scheduler
|
# Scheduler
|
||||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
inputs_posi = {
|
inputs_posi = {
|
||||||
@@ -132,9 +102,6 @@ class ZImagePipeline(BasePipeline):
|
|||||||
"height": height, "width": width,
|
"height": height, "width": width,
|
||||||
"seed": seed, "rand_device": rand_device,
|
"seed": seed, "rand_device": rand_device,
|
||||||
"num_inference_steps": num_inference_steps,
|
"num_inference_steps": num_inference_steps,
|
||||||
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
|
|
||||||
"controlnet_inputs": controlnet_inputs,
|
|
||||||
"image2lora_images": image2lora_images, "positive_only_lora": positive_only_lora,
|
|
||||||
}
|
}
|
||||||
for unit in self.units:
|
for unit in self.units:
|
||||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||||
@@ -176,7 +143,6 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
seperate_cfg=True,
|
seperate_cfg=True,
|
||||||
input_params=("edit_image",),
|
|
||||||
input_params_posi={"prompt": "prompt"},
|
input_params_posi={"prompt": "prompt"},
|
||||||
input_params_nega={"prompt": "negative_prompt"},
|
input_params_nega={"prompt": "negative_prompt"},
|
||||||
output_params=("prompt_embeds",),
|
output_params=("prompt_embeds",),
|
||||||
@@ -229,80 +195,9 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
|
|||||||
|
|
||||||
return embeddings_list
|
return embeddings_list
|
||||||
|
|
||||||
def encode_prompt_omni(
|
def process(self, pipe: ZImagePipeline, prompt):
|
||||||
self,
|
|
||||||
pipe,
|
|
||||||
prompt: Union[str, List[str]],
|
|
||||||
edit_image=None,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
max_sequence_length: int = 512,
|
|
||||||
) -> List[torch.FloatTensor]:
|
|
||||||
if isinstance(prompt, str):
|
|
||||||
prompt = [prompt]
|
|
||||||
|
|
||||||
if edit_image is None:
|
|
||||||
num_condition_images = 0
|
|
||||||
elif isinstance(edit_image, list):
|
|
||||||
num_condition_images = len(edit_image)
|
|
||||||
else:
|
|
||||||
num_condition_images = 1
|
|
||||||
|
|
||||||
for i, prompt_item in enumerate(prompt):
|
|
||||||
if num_condition_images == 0:
|
|
||||||
prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
|
|
||||||
elif num_condition_images > 0:
|
|
||||||
prompt_list = ["<|im_start|>user\n<|vision_start|>"]
|
|
||||||
prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1)
|
|
||||||
prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"]
|
|
||||||
prompt_list += ["<|vision_end|><|im_end|>"]
|
|
||||||
prompt[i] = prompt_list
|
|
||||||
|
|
||||||
flattened_prompt = []
|
|
||||||
prompt_list_lengths = []
|
|
||||||
|
|
||||||
for i in range(len(prompt)):
|
|
||||||
prompt_list_lengths.append(len(prompt[i]))
|
|
||||||
flattened_prompt.extend(prompt[i])
|
|
||||||
|
|
||||||
text_inputs = pipe.tokenizer(
|
|
||||||
flattened_prompt,
|
|
||||||
padding="max_length",
|
|
||||||
max_length=max_sequence_length,
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
|
|
||||||
text_input_ids = text_inputs.input_ids.to(device)
|
|
||||||
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
|
||||||
|
|
||||||
prompt_embeds = pipe.text_encoder(
|
|
||||||
input_ids=text_input_ids,
|
|
||||||
attention_mask=prompt_masks,
|
|
||||||
output_hidden_states=True,
|
|
||||||
).hidden_states[-2]
|
|
||||||
|
|
||||||
embeddings_list = []
|
|
||||||
start_idx = 0
|
|
||||||
for i in range(len(prompt_list_lengths)):
|
|
||||||
batch_embeddings = []
|
|
||||||
end_idx = start_idx + prompt_list_lengths[i]
|
|
||||||
for j in range(start_idx, end_idx):
|
|
||||||
batch_embeddings.append(prompt_embeds[j][prompt_masks[j]])
|
|
||||||
embeddings_list.append(batch_embeddings)
|
|
||||||
start_idx = end_idx
|
|
||||||
|
|
||||||
return embeddings_list
|
|
||||||
|
|
||||||
def process(self, pipe: ZImagePipeline, prompt, edit_image):
|
|
||||||
pipe.load_models_to_device(self.onload_model_names)
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
if hasattr(pipe, "dit") and pipe.dit.siglip_embedder is not None:
|
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
|
||||||
# Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods.
|
|
||||||
# We determine which encoding method to use based on the model architecture.
|
|
||||||
# If you are using two-stage split training,
|
|
||||||
# please use `--offload_models` instead of skipping the DiT model loading.
|
|
||||||
prompt_embeds = self.encode_prompt_omni(pipe, prompt, edit_image, pipe.device)
|
|
||||||
else:
|
|
||||||
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
|
|
||||||
return {"prompt_embeds": prompt_embeds}
|
return {"prompt_embeds": prompt_embeds}
|
||||||
|
|
||||||
|
|
||||||
@@ -339,330 +234,24 @@ class ZImageUnit_InputImageEmbedder(PipelineUnit):
|
|||||||
return {"latents": latents, "input_latents": input_latents}
|
return {"latents": latents, "input_latents": input_latents}
|
||||||
|
|
||||||
|
|
||||||
class ZImageUnit_EditImageAutoResize(PipelineUnit):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
input_params=("edit_image", "edit_image_auto_resize"),
|
|
||||||
output_params=("edit_image",),
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, pipe: ZImagePipeline, edit_image, edit_image_auto_resize):
|
|
||||||
if edit_image is None:
|
|
||||||
return {}
|
|
||||||
if edit_image_auto_resize is None or not edit_image_auto_resize:
|
|
||||||
return {}
|
|
||||||
operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)
|
|
||||||
if not isinstance(edit_image, list):
|
|
||||||
edit_image = [edit_image]
|
|
||||||
edit_image = [operator(i) for i in edit_image]
|
|
||||||
return {"edit_image": edit_image}
|
|
||||||
|
|
||||||
|
|
||||||
class ZImageUnit_EditImageEmbedderSiglip(PipelineUnit):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
input_params=("edit_image",),
|
|
||||||
output_params=("image_embeds",),
|
|
||||||
onload_model_names=("image_encoder",)
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, pipe: ZImagePipeline, edit_image):
|
|
||||||
if edit_image is None:
|
|
||||||
return {}
|
|
||||||
pipe.load_models_to_device(self.onload_model_names)
|
|
||||||
if not isinstance(edit_image, list):
|
|
||||||
edit_image = [edit_image]
|
|
||||||
image_emb = []
|
|
||||||
for image_ in edit_image:
|
|
||||||
image_emb.append(pipe.image_encoder(image_, device=pipe.device))
|
|
||||||
return {"image_embeds": image_emb}
|
|
||||||
|
|
||||||
|
|
||||||
class ZImageUnit_EditImageEmbedderVAE(PipelineUnit):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
input_params=("edit_image",),
|
|
||||||
output_params=("image_latents",),
|
|
||||||
onload_model_names=("vae_encoder",)
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, pipe: ZImagePipeline, edit_image):
|
|
||||||
if edit_image is None:
|
|
||||||
return {}
|
|
||||||
pipe.load_models_to_device(self.onload_model_names)
|
|
||||||
if not isinstance(edit_image, list):
|
|
||||||
edit_image = [edit_image]
|
|
||||||
image_latents = []
|
|
||||||
for image_ in edit_image:
|
|
||||||
image_ = pipe.preprocess_image(image_)
|
|
||||||
image_latents.append(pipe.vae_encoder(image_))
|
|
||||||
return {"image_latents": image_latents}
|
|
||||||
|
|
||||||
|
|
||||||
class ZImageUnit_PAIControlNet(PipelineUnit):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
input_params=("controlnet_inputs", "height", "width"),
|
|
||||||
output_params=("control_context", "control_scale"),
|
|
||||||
onload_model_names=("vae_encoder",)
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width):
|
|
||||||
if controlnet_inputs is None:
|
|
||||||
return {}
|
|
||||||
if len(controlnet_inputs) != 1:
|
|
||||||
print("Z-Image ControlNet doesn't support multi-ControlNet. Only one image will be used.")
|
|
||||||
controlnet_input = controlnet_inputs[0]
|
|
||||||
pipe.load_models_to_device(self.onload_model_names)
|
|
||||||
|
|
||||||
control_image = controlnet_input.image
|
|
||||||
if control_image is not None:
|
|
||||||
control_image = pipe.preprocess_image(control_image)
|
|
||||||
control_latents = pipe.vae_encoder(control_image)
|
|
||||||
else:
|
|
||||||
control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1
|
|
||||||
|
|
||||||
inpaint_mask = controlnet_input.inpaint_mask
|
|
||||||
if inpaint_mask is not None:
|
|
||||||
inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1)
|
|
||||||
inpaint_image = controlnet_input.inpaint_image
|
|
||||||
inpaint_image = pipe.preprocess_image(inpaint_image)
|
|
||||||
inpaint_image = inpaint_image * (inpaint_mask < 0.5)
|
|
||||||
inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1]
|
|
||||||
else:
|
|
||||||
inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device)
|
|
||||||
inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device)
|
|
||||||
inpaint_latent = pipe.vae_encoder(inpaint_image)
|
|
||||||
|
|
||||||
control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1)
|
|
||||||
control_context = rearrange(control_context, "B C H W -> B C 1 H W")
|
|
||||||
return {"control_context": control_context, "control_scale": controlnet_input.scale}
|
|
||||||
|
|
||||||
|
|
||||||
def model_fn_z_image(
|
def model_fn_z_image(
|
||||||
dit: ZImageDiT,
|
dit: ZImageDiT,
|
||||||
controlnet: ZImageControlNet = None,
|
|
||||||
latents=None,
|
latents=None,
|
||||||
timestep=None,
|
timestep=None,
|
||||||
prompt_embeds=None,
|
prompt_embeds=None,
|
||||||
image_embeds=None,
|
|
||||||
image_latents=None,
|
|
||||||
use_gradient_checkpointing=False,
|
use_gradient_checkpointing=False,
|
||||||
use_gradient_checkpointing_offload=False,
|
use_gradient_checkpointing_offload=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# Due to the complex and verbose codebase of Z-Image,
|
|
||||||
# we are temporarily using this inelegant structure.
|
|
||||||
# We will refactor this part in the future (if time permits).
|
|
||||||
if dit.siglip_embedder is None:
|
|
||||||
return model_fn_z_image_turbo(
|
|
||||||
dit,
|
|
||||||
controlnet=controlnet,
|
|
||||||
latents=latents,
|
|
||||||
timestep=timestep,
|
|
||||||
prompt_embeds=prompt_embeds,
|
|
||||||
image_embeds=image_embeds,
|
|
||||||
image_latents=image_latents,
|
|
||||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
|
||||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
latents = [rearrange(latents, "B C H W -> C B H W")]
|
latents = [rearrange(latents, "B C H W -> C B H W")]
|
||||||
if dit.siglip_embedder is not None:
|
|
||||||
if image_latents is not None:
|
|
||||||
image_latents = [rearrange(image_latent, "B C H W -> C B H W") for image_latent in image_latents]
|
|
||||||
latents = [image_latents + latents]
|
|
||||||
image_noise_mask = [[0] * len(image_latents) + [1]]
|
|
||||||
else:
|
|
||||||
latents = [latents]
|
|
||||||
image_noise_mask = [[1]]
|
|
||||||
image_embeds = [image_embeds]
|
|
||||||
else:
|
|
||||||
image_noise_mask = None
|
|
||||||
timestep = (1000 - timestep) / 1000
|
timestep = (1000 - timestep) / 1000
|
||||||
model_output = dit(
|
model_output = dit(
|
||||||
latents,
|
latents,
|
||||||
timestep,
|
timestep,
|
||||||
prompt_embeds,
|
prompt_embeds,
|
||||||
siglip_feats=image_embeds,
|
|
||||||
image_noise_mask=image_noise_mask,
|
|
||||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
)[0]
|
)[0][0]
|
||||||
model_output = -model_output
|
model_output = -model_output
|
||||||
model_output = rearrange(model_output, "C B H W -> B C H W")
|
model_output = rearrange(model_output, "C B H W -> B C H W")
|
||||||
return model_output
|
return model_output
|
||||||
|
|
||||||
|
|
||||||
class ZImageUnit_Image2LoRAEncode(PipelineUnit):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
input_params=("image2lora_images",),
|
|
||||||
output_params=("image2lora_x",),
|
|
||||||
onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder",),
|
|
||||||
)
|
|
||||||
from ..core.data.operators import ImageCropAndResize
|
|
||||||
self.processor_highres = ImageCropAndResize(height=1024, width=1024)
|
|
||||||
|
|
||||||
def encode_images_using_siglip2(self, pipe: ZImagePipeline, images: list[Image.Image]):
|
|
||||||
pipe.load_models_to_device(["siglip2_image_encoder"])
|
|
||||||
embs = []
|
|
||||||
for image in images:
|
|
||||||
image = self.processor_highres(image)
|
|
||||||
embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype))
|
|
||||||
embs = torch.stack(embs)
|
|
||||||
return embs
|
|
||||||
|
|
||||||
def encode_images_using_dinov3(self, pipe: ZImagePipeline, images: list[Image.Image]):
|
|
||||||
pipe.load_models_to_device(["dinov3_image_encoder"])
|
|
||||||
embs = []
|
|
||||||
for image in images:
|
|
||||||
image = self.processor_highres(image)
|
|
||||||
embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype))
|
|
||||||
embs = torch.stack(embs)
|
|
||||||
return embs
|
|
||||||
|
|
||||||
def encode_images(self, pipe: ZImagePipeline, images: list[Image.Image]):
|
|
||||||
if images is None:
|
|
||||||
return {}
|
|
||||||
if not isinstance(images, list):
|
|
||||||
images = [images]
|
|
||||||
embs_siglip2 = self.encode_images_using_siglip2(pipe, images)
|
|
||||||
embs_dinov3 = self.encode_images_using_dinov3(pipe, images)
|
|
||||||
x = torch.concat([embs_siglip2, embs_dinov3], dim=-1)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def process(self, pipe: ZImagePipeline, image2lora_images):
|
|
||||||
if image2lora_images is None:
|
|
||||||
return {}
|
|
||||||
x = self.encode_images(pipe, image2lora_images)
|
|
||||||
return {"image2lora_x": x}
|
|
||||||
|
|
||||||
|
|
||||||
class ZImageUnit_Image2LoRADecode(PipelineUnit):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
input_params=("image2lora_x",),
|
|
||||||
output_params=("lora",),
|
|
||||||
onload_model_names=("image2lora_style",),
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, pipe: ZImagePipeline, image2lora_x):
|
|
||||||
if image2lora_x is None:
|
|
||||||
return {}
|
|
||||||
loras = []
|
|
||||||
if pipe.image2lora_style is not None:
|
|
||||||
pipe.load_models_to_device(["image2lora_style"])
|
|
||||||
for x in image2lora_x:
|
|
||||||
loras.append(pipe.image2lora_style(x=x, residual=None))
|
|
||||||
lora = merge_lora(loras, alpha=1 / len(image2lora_x))
|
|
||||||
return {"lora": lora}
|
|
||||||
|
|
||||||
|
|
||||||
def model_fn_z_image_turbo(
|
|
||||||
dit: ZImageDiT,
|
|
||||||
controlnet: ZImageControlNet = None,
|
|
||||||
latents=None,
|
|
||||||
timestep=None,
|
|
||||||
prompt_embeds=None,
|
|
||||||
image_embeds=None,
|
|
||||||
image_latents=None,
|
|
||||||
control_context=None,
|
|
||||||
control_scale=None,
|
|
||||||
use_gradient_checkpointing=False,
|
|
||||||
use_gradient_checkpointing_offload=False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
while isinstance(prompt_embeds, list):
|
|
||||||
prompt_embeds = prompt_embeds[0]
|
|
||||||
while isinstance(latents, list):
|
|
||||||
latents = latents[0]
|
|
||||||
while isinstance(image_embeds, list):
|
|
||||||
image_embeds = image_embeds[0]
|
|
||||||
|
|
||||||
# Timestep
|
|
||||||
timestep = 1000 - timestep
|
|
||||||
t_noisy = dit.t_embedder(timestep)
|
|
||||||
t_clean = dit.t_embedder(torch.ones_like(timestep) * 1000)
|
|
||||||
|
|
||||||
# Patchify
|
|
||||||
latents = rearrange(latents, "B C H W -> C B H W")
|
|
||||||
x, cap_feats, patch_metadata = dit.patchify_and_embed([latents], [prompt_embeds])
|
|
||||||
x = x[0]
|
|
||||||
cap_feats = cap_feats[0]
|
|
||||||
|
|
||||||
# Noise refine
|
|
||||||
x = dit.all_x_embedder["2-1"](x)
|
|
||||||
x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device)
|
|
||||||
x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0))
|
|
||||||
x = rearrange(x, "L C -> 1 L C")
|
|
||||||
x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C")
|
|
||||||
|
|
||||||
if control_context is not None:
|
|
||||||
kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)
|
|
||||||
refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(
|
|
||||||
dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1,
|
|
||||||
use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
|
||||||
)
|
|
||||||
|
|
||||||
for layer_id, layer in enumerate(dit.noise_refiner):
|
|
||||||
x = gradient_checkpoint_forward(
|
|
||||||
layer,
|
|
||||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
|
||||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
|
||||||
x=x,
|
|
||||||
attn_mask=None,
|
|
||||||
freqs_cis=x_freqs_cis,
|
|
||||||
adaln_input=t_noisy,
|
|
||||||
)
|
|
||||||
if control_context is not None:
|
|
||||||
x = x + refiner_hints[layer_id] * control_scale
|
|
||||||
|
|
||||||
# Prompt refine
|
|
||||||
cap_feats = dit.cap_embedder(cap_feats)
|
|
||||||
cap_feats[torch.cat(patch_metadata.get("cap_pad_mask"))] = dit.cap_pad_token.to(dtype=x.dtype, device=x.device)
|
|
||||||
cap_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("cap_pos_ids"), dim=0))
|
|
||||||
cap_feats = rearrange(cap_feats, "L C -> 1 L C")
|
|
||||||
cap_freqs_cis = rearrange(cap_freqs_cis, "L C -> 1 L C")
|
|
||||||
|
|
||||||
for layer in dit.context_refiner:
|
|
||||||
cap_feats = gradient_checkpoint_forward(
|
|
||||||
layer,
|
|
||||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
|
||||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
|
||||||
x=cap_feats,
|
|
||||||
attn_mask=None,
|
|
||||||
freqs_cis=cap_freqs_cis,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Unified
|
|
||||||
unified = torch.cat([x, cap_feats], dim=1)
|
|
||||||
unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1)
|
|
||||||
|
|
||||||
if control_context is not None:
|
|
||||||
kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)
|
|
||||||
hints = controlnet.forward_layers(
|
|
||||||
unified, cap_feats, control_context, control_context_item_seqlens, kwargs,
|
|
||||||
use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
|
||||||
)
|
|
||||||
|
|
||||||
for layer_id, layer in enumerate(dit.layers):
|
|
||||||
unified = gradient_checkpoint_forward(
|
|
||||||
layer,
|
|
||||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
|
||||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
|
||||||
x=unified,
|
|
||||||
attn_mask=None,
|
|
||||||
freqs_cis=unified_freqs_cis,
|
|
||||||
adaln_input=t_noisy,
|
|
||||||
)
|
|
||||||
if control_context is not None:
|
|
||||||
if layer_id in controlnet.control_layers_mapping:
|
|
||||||
unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale
|
|
||||||
|
|
||||||
# Output
|
|
||||||
unified = dit.all_final_layer["2-1"](unified, t_noisy)
|
|
||||||
x = dit.unpatchify([unified[0]], patch_metadata.get("x_size"))[0]
|
|
||||||
x = rearrange(x, "C B H W -> B C H W")
|
|
||||||
x = -x
|
|
||||||
return x
|
|
||||||
|
|||||||
@@ -9,6 +9,5 @@ class ControlNetInput:
|
|||||||
start: float = 1.0
|
start: float = 1.0
|
||||||
end: float = 0.0
|
end: float = 0.0
|
||||||
image: Image.Image = None
|
image: Image.Image = None
|
||||||
inpaint_image: Image.Image = None
|
|
||||||
inpaint_mask: Image.Image = None
|
inpaint_mask: Image.Image = None
|
||||||
processor_id: str = None
|
processor_id: str = None
|
||||||
|
|||||||
@@ -1,3 +1 @@
|
|||||||
from .general import GeneralLoRALoader
|
from .general import GeneralLoRALoader
|
||||||
from .merge import merge_lora
|
|
||||||
from .reset_rank import reset_lora_rank
|
|
||||||
@@ -202,99 +202,3 @@ class FluxLoRALoader(GeneralLoRALoader):
|
|||||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||||
return state_dict_
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
class FluxLoRAConverter:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def align_to_opensource_format(state_dict, alpha=None):
|
|
||||||
prefix_rename_dict = {
|
|
||||||
"single_blocks": "lora_unet_single_blocks",
|
|
||||||
"blocks": "lora_unet_double_blocks",
|
|
||||||
}
|
|
||||||
middle_rename_dict = {
|
|
||||||
"norm.linear": "modulation_lin",
|
|
||||||
"to_qkv_mlp": "linear1",
|
|
||||||
"proj_out": "linear2",
|
|
||||||
|
|
||||||
"norm1_a.linear": "img_mod_lin",
|
|
||||||
"norm1_b.linear": "txt_mod_lin",
|
|
||||||
"attn.a_to_qkv": "img_attn_qkv",
|
|
||||||
"attn.b_to_qkv": "txt_attn_qkv",
|
|
||||||
"attn.a_to_out": "img_attn_proj",
|
|
||||||
"attn.b_to_out": "txt_attn_proj",
|
|
||||||
"ff_a.0": "img_mlp_0",
|
|
||||||
"ff_a.2": "img_mlp_2",
|
|
||||||
"ff_b.0": "txt_mlp_0",
|
|
||||||
"ff_b.2": "txt_mlp_2",
|
|
||||||
}
|
|
||||||
suffix_rename_dict = {
|
|
||||||
"lora_B.weight": "lora_up.weight",
|
|
||||||
"lora_A.weight": "lora_down.weight",
|
|
||||||
}
|
|
||||||
state_dict_ = {}
|
|
||||||
for name, param in state_dict.items():
|
|
||||||
names = name.split(".")
|
|
||||||
if names[-2] != "lora_A" and names[-2] != "lora_B":
|
|
||||||
names.pop(-2)
|
|
||||||
prefix = names[0]
|
|
||||||
middle = ".".join(names[2:-2])
|
|
||||||
suffix = ".".join(names[-2:])
|
|
||||||
block_id = names[1]
|
|
||||||
if middle not in middle_rename_dict:
|
|
||||||
continue
|
|
||||||
rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix]
|
|
||||||
state_dict_[rename] = param
|
|
||||||
if rename.endswith("lora_up.weight"):
|
|
||||||
lora_alpha = alpha if alpha is not None else param.shape[-1]
|
|
||||||
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((lora_alpha,))[0]
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def align_to_diffsynth_format(state_dict):
|
|
||||||
rename_dict = {
|
|
||||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
|
|
||||||
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
|
|
||||||
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
|
|
||||||
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
|
|
||||||
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
|
|
||||||
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
|
|
||||||
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
|
|
||||||
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
|
|
||||||
}
|
|
||||||
def guess_block_id(name):
|
|
||||||
names = name.split("_")
|
|
||||||
for i in names:
|
|
||||||
if i.isdigit():
|
|
||||||
return i, name.replace(f"_{i}_", "_blockid_")
|
|
||||||
return None, None
|
|
||||||
state_dict_ = {}
|
|
||||||
for name, param in state_dict.items():
|
|
||||||
block_id, source_name = guess_block_id(name)
|
|
||||||
if source_name in rename_dict:
|
|
||||||
target_name = rename_dict[source_name]
|
|
||||||
target_name = target_name.replace(".blockid.", f".{block_id}.")
|
|
||||||
state_dict_[target_name] = param
|
|
||||||
else:
|
|
||||||
state_dict_[name] = param
|
|
||||||
return state_dict_
|
|
||||||
|
|||||||
@@ -1,20 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
def decomposite(tensor_A, tensor_B, rank):
|
|
||||||
dtype, device = tensor_A.dtype, tensor_A.device
|
|
||||||
weight = tensor_B @ tensor_A
|
|
||||||
U, S, V = torch.pca_lowrank(weight.float(), q=rank)
|
|
||||||
tensor_A = (V.T).to(dtype=dtype, device=device).contiguous()
|
|
||||||
tensor_B = (U @ torch.diag(S)).to(dtype=dtype, device=device).contiguous()
|
|
||||||
return tensor_A, tensor_B
|
|
||||||
|
|
||||||
def reset_lora_rank(lora, rank):
|
|
||||||
lora_merged = {}
|
|
||||||
keys = [i for i in lora.keys() if ".lora_A." in i]
|
|
||||||
for key in keys:
|
|
||||||
tensor_A = lora[key]
|
|
||||||
tensor_B = lora[key.replace(".lora_A.", ".lora_B.")]
|
|
||||||
tensor_A, tensor_B = decomposite(tensor_A, tensor_B, rank)
|
|
||||||
lora_merged[key] = tensor_A
|
|
||||||
lora_merged[key.replace(".lora_A.", ".lora_B.")] = tensor_B
|
|
||||||
return lora_merged
|
|
||||||
@@ -5,20 +5,19 @@ from xfuser.core.distributed import (get_sequence_parallel_rank,
|
|||||||
get_sequence_parallel_world_size,
|
get_sequence_parallel_world_size,
|
||||||
get_sp_group)
|
get_sp_group)
|
||||||
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||||
from ...core.device import parse_nccl_backend, parse_device_type
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_usp(device_type):
|
def initialize_usp():
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
|
from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
|
||||||
dist.init_process_group(backend=parse_nccl_backend(device_type), init_method="env://")
|
dist.init_process_group(backend="nccl", init_method="env://")
|
||||||
init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
|
init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
|
||||||
initialize_model_parallel(
|
initialize_model_parallel(
|
||||||
sequence_parallel_degree=dist.get_world_size(),
|
sequence_parallel_degree=dist.get_world_size(),
|
||||||
ring_degree=1,
|
ring_degree=1,
|
||||||
ulysses_degree=dist.get_world_size(),
|
ulysses_degree=dist.get_world_size(),
|
||||||
)
|
)
|
||||||
getattr(torch, device_type).set_device(dist.get_rank())
|
torch.cuda.set_device(dist.get_rank())
|
||||||
|
|
||||||
|
|
||||||
def sinusoidal_embedding_1d(dim, position):
|
def sinusoidal_embedding_1d(dim, position):
|
||||||
@@ -142,5 +141,5 @@ def usp_attn_forward(self, x, freqs):
|
|||||||
x = x.flatten(2)
|
x = x.flatten(2)
|
||||||
|
|
||||||
del q, k, v
|
del q, k, v
|
||||||
getattr(torch, parse_device_type(x.device)).empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return self.o(x)
|
return self.o(x)
|
||||||
@@ -81,11 +81,8 @@ graph LR;
|
|||||||
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||||
| - | - | - | - | - | - | - |
|
| - | - | - | - | - | - | - |
|
||||||
| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
|
| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
|
||||||
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|
|
||||||
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
|
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
|
||||||
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|
||||||
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
|
||||||
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
|
||||||
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
|
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
|
||||||
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
|
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
|
||||||
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
|
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
|
||||||
@@ -96,7 +93,6 @@ graph LR;
|
|||||||
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |
|
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |
|
||||||
| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |
|
| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |
|
||||||
| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |
|
| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |
|
||||||
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
|
|
||||||
|
|
||||||
Special Training Scripts:
|
Special Training Scripts:
|
||||||
|
|
||||||
|
|||||||
@@ -138,4 +138,4 @@ Training Tips:
|
|||||||
* Differential LoRA Training ([code](/examples/z_image/model_training/special/differential_training/)) + Acceleration Configuration Inference
|
* Differential LoRA Training ([code](/examples/z_image/model_training/special/differential_training/)) + Acceleration Configuration Inference
|
||||||
* An additional LoRA needs to be loaded in differential LoRA training, e.g., [ostris/zimage_turbo_training_adapter](https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter)
|
* An additional LoRA needs to be loaded in differential LoRA training, e.g., [ostris/zimage_turbo_training_adapter](https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter)
|
||||||
* Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Trajectory Imitation Distillation Training ([code](/examples/z_image/model_training/special/trajectory_imitation/)) + Acceleration Configuration Inference
|
* Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Trajectory Imitation Distillation Training ([code](/examples/z_image/model_training/special/trajectory_imitation/)) + Acceleration Configuration Inference
|
||||||
* Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Load Distillation Acceleration LoRA During Inference ([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillPatch)) + Acceleration Configuration Inference
|
* Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Load Distillation Acceleration LoRA During Inference ([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillFix)) + Acceleration Configuration Inference
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
# GPU/NPU Support
|
|
||||||
|
|
||||||
`DiffSynth-Studio` supports various GPUs and NPUs. This document explains how to run model inference and training on these devices.
|
|
||||||
|
|
||||||
Before you begin, please follow the [Installation Guide](/docs/en/Pipeline_Usage/Setup.md) to install the required GPU/NPU dependencies.
|
|
||||||
|
|
||||||
## NVIDIA GPU
|
|
||||||
|
|
||||||
All sample code provided by this project supports NVIDIA GPUs by default, requiring no additional modifications.
|
|
||||||
|
|
||||||
## AMD GPU
|
|
||||||
|
|
||||||
AMD provides PyTorch packages based on ROCm, so most models can run without code changes. A small number of models may not be compatible due to their reliance on CUDA-specific instructions.
|
|
||||||
|
|
||||||
## Ascend NPU
|
|
||||||
|
|
||||||
When using Ascend NPU, you need to replace `"cuda"` with `"npu"` in your code.
|
|
||||||
|
|
||||||
For example, here is the inference code for **Wan2.1-T2V-1.3B**, modified for Ascend NPU:
|
|
||||||
|
|
||||||
```diff
|
|
||||||
import torch
|
|
||||||
from diffsynth.utils.data import save_video, VideoData
|
|
||||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
|
||||||
|
|
||||||
vram_config = {
|
|
||||||
"offload_dtype": "disk",
|
|
||||||
"offload_device": "disk",
|
|
||||||
"onload_dtype": torch.bfloat16,
|
|
||||||
"onload_device": "cpu",
|
|
||||||
"preparing_dtype": torch.bfloat16,
|
|
||||||
- "preparing_device": "cuda",
|
|
||||||
+ "preparing_device": "npu",
|
|
||||||
"computation_dtype": torch.bfloat16,
|
|
||||||
- "computation_device": "cuda",
|
|
||||||
+ "computation_device": "npu",
|
|
||||||
}
|
|
||||||
pipe = WanVideoPipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
- device="cuda",
|
|
||||||
+ device="npu",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
|
|
||||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
|
||||||
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
|
|
||||||
+ vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
|
|
||||||
)
|
|
||||||
|
|
||||||
video = pipe(
|
|
||||||
prompt="Documentary-style photography: a lively puppy running swiftly across lush green grass. The puppy has brownish-yellow fur, upright ears, and an alert, joyful expression. Sunlight bathes its body, making the fur appear exceptionally soft and shiny. The background is an open field with occasional wildflowers, and faint blue sky with scattered white clouds in the distance. Strong perspective captures the motion of the running puppy and the vitality of the surrounding grass. Mid-shot, side-moving viewpoint.",
|
|
||||||
negative_prompt="Overly vibrant colors, overexposed, static, blurry details, subtitles, artistic style, painting, still image, overall grayish tone, worst quality, low quality, JPEG artifacts, ugly, distorted, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, fused fingers, motionless scene, cluttered background, three legs, many people in background, walking backward",
|
|
||||||
seed=0, tiled=True,
|
|
||||||
)
|
|
||||||
save_video(video, "video.mp4", fps=15, quality=5)
|
|
||||||
```
|
|
||||||
@@ -14,35 +14,8 @@ Install from PyPI (there may be delays in version updates; for latest features,
|
|||||||
pip install diffsynth
|
pip install diffsynth
|
||||||
```
|
```
|
||||||
|
|
||||||
## GPU/NPU Support
|
If you encounter issues during installation, they may be caused by upstream dependency packages. Please refer to the documentation for these packages:
|
||||||
|
|
||||||
* **NVIDIA GPU**
|
|
||||||
|
|
||||||
Install as described above.
|
|
||||||
|
|
||||||
* **AMD GPU**
|
|
||||||
|
|
||||||
You need to install the `torch` package with ROCm support. Taking ROCm 6.4 (as of the article update date: December 15, 2025) on Linux as an example, run the following command:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.4
|
|
||||||
```
|
|
||||||
|
|
||||||
* **Ascend NPU**
|
|
||||||
|
|
||||||
Ascend NPU support is provided via the `torch-npu` package. Taking version `2.1.0.post17` (as of the article update date: December 15, 2025) as an example, run the following command:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
pip install torch-npu==2.1.0.post17
|
|
||||||
```
|
|
||||||
|
|
||||||
When using Ascend NPU, please replace `"cuda"` with `"npu"` in your Python code. For details, see [NPU Support](/docs/en/Pipeline_Usage/GPU_support.md#ascend-npu).
|
|
||||||
|
|
||||||
## Other Installation Issues
|
|
||||||
|
|
||||||
If you encounter issues during installation, they may be caused by upstream dependencies. Please refer to the documentation for these packages:
|
|
||||||
|
|
||||||
* [torch](https://pytorch.org/get-started/locally/)
|
* [torch](https://pytorch.org/get-started/locally/)
|
||||||
* [Ascend/pytorch](https://github.com/Ascend/pytorch)
|
|
||||||
* [sentencepiece](https://github.com/google/sentencepiece)
|
* [sentencepiece](https://github.com/google/sentencepiece)
|
||||||
* [cmake](https://cmake.org)
|
* [cmake](https://cmake.org)
|
||||||
@@ -31,7 +31,6 @@ This section introduces the basic usage of `DiffSynth-Studio`, including how to
|
|||||||
* [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md)
|
* [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md)
|
||||||
* [Model Training](/docs/en/Pipeline_Usage/Model_Training.md)
|
* [Model Training](/docs/en/Pipeline_Usage/Model_Training.md)
|
||||||
* [Environment Variables](/docs/en/Pipeline_Usage/Environment_Variables.md)
|
* [Environment Variables](/docs/en/Pipeline_Usage/Environment_Variables.md)
|
||||||
* [GPU/NPU Support](/docs/en/Pipeline_Usage/GPU_support.md)
|
|
||||||
|
|
||||||
## Section 2: Model Details
|
## Section 2: Model Details
|
||||||
|
|
||||||
|
|||||||
@@ -81,11 +81,8 @@ graph LR;
|
|||||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|-|-|-|-|-|-|-|
|
|-|-|-|-|-|-|-|
|
||||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
||||||
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|
|
||||||
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||||
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
||||||
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
|
||||||
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
|
||||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
||||||
@@ -96,7 +93,6 @@ graph LR;
|
|||||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|
||||||
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
|
|
||||||
|
|
||||||
特殊训练脚本:
|
特殊训练脚本:
|
||||||
|
|
||||||
|
|||||||
@@ -138,4 +138,4 @@ modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir
|
|||||||
* 差分 LoRA 训练([code](/examples/z_image/model_training/special/differential_training/)) + 加速配置推理
|
* 差分 LoRA 训练([code](/examples/z_image/model_training/special/differential_training/)) + 加速配置推理
|
||||||
* 差分 LoRA 训练中需加载一个额外的 LoRA,例如 [ostris/zimage_turbo_training_adapter](https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter)
|
* 差分 LoRA 训练中需加载一个额外的 LoRA,例如 [ostris/zimage_turbo_training_adapter](https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter)
|
||||||
* 标准 SFT 训练([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh))+ 轨迹模仿蒸馏训练([code](/examples/z_image/model_training/special/trajectory_imitation/))+ 加速配置推理
|
* 标准 SFT 训练([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh))+ 轨迹模仿蒸馏训练([code](/examples/z_image/model_training/special/trajectory_imitation/))+ 加速配置推理
|
||||||
* 标准 SFT 训练([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh))+ 推理时加载蒸馏加速 LoRA([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillPatch)) + 加速配置推理
|
* 标准 SFT 训练([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh))+ 推理时加载蒸馏加速 LoRA([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillFix)) + 加速配置推理
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
# GPU/NPU 支持
|
|
||||||
|
|
||||||
`DiffSynth-Studio` 支持多种 GPU/NPU,本文介绍如何在这些设备上运行模型推理和训练。
|
|
||||||
|
|
||||||
在开始前,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)安装好 GPU/NPU 相关的依赖包。
|
|
||||||
|
|
||||||
## NVIDIA GPU
|
|
||||||
|
|
||||||
本项目提供的所有样例代码默认支持 NVIDIA GPU,无需额外修改。
|
|
||||||
|
|
||||||
## AMD GPU
|
|
||||||
|
|
||||||
AMD 提供了基于 ROCm 的 torch 包,所以大多数模型无需修改代码即可运行,少数模型由于依赖特定的 cuda 指令无法运行。
|
|
||||||
|
|
||||||
## Ascend NPU
|
|
||||||
|
|
||||||
使用 Ascend NPU 时,需把代码中的 `"cuda"` 改为 `"npu"`。
|
|
||||||
|
|
||||||
例如,Wan2.1-T2V-1.3B 的推理代码:
|
|
||||||
|
|
||||||
```diff
|
|
||||||
import torch
|
|
||||||
from diffsynth.utils.data import save_video, VideoData
|
|
||||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
|
||||||
|
|
||||||
vram_config = {
|
|
||||||
"offload_dtype": "disk",
|
|
||||||
"offload_device": "disk",
|
|
||||||
"onload_dtype": torch.bfloat16,
|
|
||||||
"onload_device": "cpu",
|
|
||||||
"preparing_dtype": torch.bfloat16,
|
|
||||||
- "preparing_device": "cuda",
|
|
||||||
+ "preparing_device": "npu",
|
|
||||||
"computation_dtype": torch.bfloat16,
|
|
||||||
- "computation_device": "cuda",
|
|
||||||
+ "preparing_device": "npu",
|
|
||||||
}
|
|
||||||
pipe = WanVideoPipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
- device="cuda",
|
|
||||||
+ device="npu",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
|
|
||||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
|
||||||
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
|
|
||||||
+ vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
|
|
||||||
)
|
|
||||||
|
|
||||||
video = pipe(
|
|
||||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
|
||||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
|
||||||
seed=0, tiled=True,
|
|
||||||
)
|
|
||||||
save_video(video, "video.mp4", fps=15, quality=5)
|
|
||||||
```
|
|
||||||
@@ -14,35 +14,8 @@ pip install -e .
|
|||||||
pip install diffsynth
|
pip install diffsynth
|
||||||
```
|
```
|
||||||
|
|
||||||
## GPU/NPU 支持
|
|
||||||
|
|
||||||
* NVIDIA GPU
|
|
||||||
|
|
||||||
按照以上方式安装即可。
|
|
||||||
|
|
||||||
* AMD GPU
|
|
||||||
|
|
||||||
需安装支持 ROCm 的 `torch` 包,以 ROCm 6.4(本文更新于 2025 年 12 月 15 日)、Linux 系统为例,请运行以下命令
|
|
||||||
|
|
||||||
```shell
|
|
||||||
pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.4
|
|
||||||
```
|
|
||||||
|
|
||||||
* Ascend NPU
|
|
||||||
|
|
||||||
Ascend NPU 通过 `torch-npu` 包提供支持,以 `2.1.0.post17` 版本(本文更新于 2025 年 12 月 15 日)为例,请运行以下命令
|
|
||||||
|
|
||||||
```shell
|
|
||||||
pip install torch-npu==2.1.0.post17
|
|
||||||
```
|
|
||||||
|
|
||||||
使用 Ascend NPU 时,请将 Python 代码中的 `"cuda"` 改为 `"npu"`,详见[NPU 支持](/docs/zh/Pipeline_Usage/GPU_support.md#ascend-npu)。
|
|
||||||
|
|
||||||
## 其他安装问题
|
|
||||||
|
|
||||||
如果在安装过程中遇到问题,可能是由上游依赖包导致的,请参考这些包的文档:
|
如果在安装过程中遇到问题,可能是由上游依赖包导致的,请参考这些包的文档:
|
||||||
|
|
||||||
* [torch](https://pytorch.org/get-started/locally/)
|
* [torch](https://pytorch.org/get-started/locally/)
|
||||||
* [Ascend/pytorch](https://github.com/Ascend/pytorch)
|
|
||||||
* [sentencepiece](https://github.com/google/sentencepiece)
|
* [sentencepiece](https://github.com/google/sentencepiece)
|
||||||
* [cmake](https://cmake.org)
|
* [cmake](https://cmake.org)
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ graph LR;
|
|||||||
* [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)
|
* [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)
|
||||||
* [模型训练](/docs/zh/Pipeline_Usage/Model_Training.md)
|
* [模型训练](/docs/zh/Pipeline_Usage/Model_Training.md)
|
||||||
* [环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md)
|
* [环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md)
|
||||||
* [GPU/NPU 支持](/docs/zh/Pipeline_Usage/GPU_support.md)
|
|
||||||
|
|
||||||
## Section 2: 模型详解
|
## Section 2: 模型详解
|
||||||
|
|
||||||
|
|||||||
@@ -108,14 +108,7 @@ def test_flux():
|
|||||||
run_inference("examples/flux/model_training/validate_lora")
|
run_inference("examples/flux/model_training/validate_lora")
|
||||||
|
|
||||||
|
|
||||||
def test_z_image():
|
|
||||||
run_inference("examples/z_image/model_inference")
|
|
||||||
run_inference("examples/z_image/model_inference_low_vram")
|
|
||||||
run_train_multi_GPU("examples/z_image/model_training/full")
|
|
||||||
run_inference("examples/z_image/model_training/validate_full")
|
|
||||||
run_train_single_GPU("examples/z_image/model_training/lora")
|
|
||||||
run_inference("examples/z_image/model_training/validate_lora")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_z_image()
|
test_qwen_image()
|
||||||
|
test_flux()
|
||||||
|
test_wan()
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image-2512", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
|
||||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
|
||||||
image.save("image.jpg")
|
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
|
||||||
from modelscope import dataset_snapshot_download
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset_snapshot_download(
|
|
||||||
"DiffSynth-Studio/example_image_dataset",
|
|
||||||
allow_file_pattern="qwen_image_edit/*",
|
|
||||||
local_dir="data/example_image_dataset",
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = "生成这两个人的合影"
|
|
||||||
edit_image = [
|
|
||||||
Image.open("data/example_image_dataset/qwen_image_edit/image1.jpg"),
|
|
||||||
Image.open("data/example_image_dataset/qwen_image_edit/image2.jpg"),
|
|
||||||
]
|
|
||||||
image = pipe(
|
|
||||||
prompt,
|
|
||||||
edit_image=edit_image,
|
|
||||||
seed=1,
|
|
||||||
num_inference_steps=40,
|
|
||||||
height=1152,
|
|
||||||
width=896,
|
|
||||||
edit_image_auto_resize=True,
|
|
||||||
zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511
|
|
||||||
)
|
|
||||||
image.save("image.jpg")
|
|
||||||
|
|
||||||
# Qwen-Image-Edit-2511 is a multi-image editing model.
|
|
||||||
# Please use a list to input `edit_image`, even if the input contains only one image.
|
|
||||||
# edit_image = [Image.open("image.jpg")]
|
|
||||||
# Please do not input the image directly.
|
|
||||||
# edit_image = Image.open("image.jpg")
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
|
||||||
from modelscope import dataset_snapshot_download
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset_snapshot_download(
|
|
||||||
"DiffSynth-Studio/example_image_dataset",
|
|
||||||
allow_patterns="layer/image.png",
|
|
||||||
local_dir="data/example_image_dataset"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prompt should be provided to the pipeline. Our pipeline will not generate the prompt.
|
|
||||||
prompt = 'A cheerful child with brown hair is waving enthusiastically under a bright blue sky filled with colorful confetti and balloons. The word "HELLO!" is prominently displayed in bold red letters above the child, while "Have a Great Day!" appears in elegant cursive at the bottom right corner. The scene is vibrant and festive, with a mix of pastel colors and dynamic shapes creating a joyful atmosphere.'
|
|
||||||
# Height and width should be consistent with input_image and be divided evenly by 16
|
|
||||||
input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480))
|
|
||||||
images = pipe(
|
|
||||||
prompt,
|
|
||||||
seed=1, num_inference_steps=50,
|
|
||||||
height=480, width=864,
|
|
||||||
layer_input_image=input_image, layer_num=3,
|
|
||||||
)
|
|
||||||
for i, image in enumerate(images):
|
|
||||||
if i == 0: continue # The first image is the input image.
|
|
||||||
image.save(f"image_{i}.png")
|
|
||||||
@@ -1,110 +0,0 @@
|
|||||||
from diffsynth.pipelines.qwen_image import (
|
|
||||||
QwenImagePipeline, ModelConfig,
|
|
||||||
QwenImageUnit_Image2LoRAEncode, QwenImageUnit_Image2LoRADecode
|
|
||||||
)
|
|
||||||
from diffsynth.utils.lora import merge_lora
|
|
||||||
from diffsynth import load_state_dict
|
|
||||||
from modelscope import snapshot_download
|
|
||||||
from safetensors.torch import save_file
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
def demo_style():
|
|
||||||
# Load models
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Style.safetensors"),
|
|
||||||
],
|
|
||||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load images
|
|
||||||
snapshot_download(
|
|
||||||
model_id="DiffSynth-Studio/Qwen-Image-i2L",
|
|
||||||
allow_file_pattern="assets/style/1/*",
|
|
||||||
local_dir="data/examples"
|
|
||||||
)
|
|
||||||
images = [
|
|
||||||
Image.open("data/examples/assets/style/1/0.jpg"),
|
|
||||||
Image.open("data/examples/assets/style/1/1.jpg"),
|
|
||||||
Image.open("data/examples/assets/style/1/2.jpg"),
|
|
||||||
Image.open("data/examples/assets/style/1/3.jpg"),
|
|
||||||
Image.open("data/examples/assets/style/1/4.jpg"),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Model inference
|
|
||||||
with torch.no_grad():
|
|
||||||
embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
|
|
||||||
lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
|
|
||||||
save_file(lora, "model_style.safetensors")
|
|
||||||
|
|
||||||
|
|
||||||
def demo_coarse_fine_bias():
|
|
||||||
# Load models
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Coarse.safetensors"),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Fine.safetensors"),
|
|
||||||
],
|
|
||||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load images
|
|
||||||
snapshot_download(
|
|
||||||
model_id="DiffSynth-Studio/Qwen-Image-i2L",
|
|
||||||
allow_file_pattern="assets/lora/3/*",
|
|
||||||
local_dir="data/examples"
|
|
||||||
)
|
|
||||||
images = [
|
|
||||||
Image.open("data/examples/assets/lora/3/0.jpg"),
|
|
||||||
Image.open("data/examples/assets/lora/3/1.jpg"),
|
|
||||||
Image.open("data/examples/assets/lora/3/2.jpg"),
|
|
||||||
Image.open("data/examples/assets/lora/3/3.jpg"),
|
|
||||||
Image.open("data/examples/assets/lora/3/4.jpg"),
|
|
||||||
Image.open("data/examples/assets/lora/3/5.jpg"),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Model inference
|
|
||||||
with torch.no_grad():
|
|
||||||
embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
|
|
||||||
lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
|
|
||||||
lora_bias = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Bias.safetensors")
|
|
||||||
lora_bias.download_if_necessary()
|
|
||||||
lora_bias = load_state_dict(lora_bias.path, torch_dtype=torch.bfloat16, device="cuda")
|
|
||||||
lora = merge_lora([lora, lora_bias])
|
|
||||||
save_file(lora, "model_coarse_fine_bias.safetensors")
|
|
||||||
|
|
||||||
|
|
||||||
def generate_image(lora_path, prompt, seed):
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
pipe.load_lora(pipe.dit, lora_path)
|
|
||||||
image = pipe(prompt, seed=seed, height=1024, width=1024, num_inference_steps=50)
|
|
||||||
return image
|
|
||||||
|
|
||||||
|
|
||||||
demo_style()
|
|
||||||
image = generate_image("model_style.safetensors", "a cat", 0)
|
|
||||||
image.save("image_1.jpg")
|
|
||||||
|
|
||||||
demo_coarse_fine_bias()
|
|
||||||
image = generate_image("model_coarse_fine_bias.safetensors", "bowl", 1)
|
|
||||||
image.save("image_2.jpg")
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
vram_config = {
|
|
||||||
"offload_dtype": "disk",
|
|
||||||
"offload_device": "disk",
|
|
||||||
"onload_dtype": torch.float8_e4m3fn,
|
|
||||||
"onload_device": "cpu",
|
|
||||||
"preparing_dtype": torch.float8_e4m3fn,
|
|
||||||
"preparing_device": "cuda",
|
|
||||||
"computation_dtype": torch.bfloat16,
|
|
||||||
"computation_device": "cuda",
|
|
||||||
}
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image-2512", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
|
||||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
|
||||||
)
|
|
||||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
|
||||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
|
||||||
image.save("image.jpg")
|
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
|
||||||
from modelscope import dataset_snapshot_download
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
vram_config = {
|
|
||||||
"offload_dtype": "disk",
|
|
||||||
"offload_device": "disk",
|
|
||||||
"onload_dtype": torch.float8_e4m3fn,
|
|
||||||
"onload_device": "cpu",
|
|
||||||
"preparing_dtype": torch.float8_e4m3fn,
|
|
||||||
"preparing_device": "cuda",
|
|
||||||
"computation_dtype": torch.bfloat16,
|
|
||||||
"computation_device": "cuda",
|
|
||||||
}
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
|
||||||
],
|
|
||||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset_snapshot_download(
|
|
||||||
"DiffSynth-Studio/example_image_dataset",
|
|
||||||
allow_file_pattern="qwen_image_edit/*",
|
|
||||||
local_dir="data/example_image_dataset",
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = "生成这两个人的合影"
|
|
||||||
edit_image = [
|
|
||||||
Image.open("data/example_image_dataset/qwen_image_edit/image1.jpg"),
|
|
||||||
Image.open("data/example_image_dataset/qwen_image_edit/image2.jpg"),
|
|
||||||
]
|
|
||||||
image = pipe(
|
|
||||||
prompt,
|
|
||||||
edit_image=edit_image,
|
|
||||||
seed=1,
|
|
||||||
num_inference_steps=40,
|
|
||||||
height=1152,
|
|
||||||
width=896,
|
|
||||||
edit_image_auto_resize=True,
|
|
||||||
zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511
|
|
||||||
)
|
|
||||||
image.save("image.jpg")
|
|
||||||
|
|
||||||
# Qwen-Image-Edit-2511 is a multi-image editing model.
|
|
||||||
# Please use a list to input `edit_image`, even if the input contains only one image.
|
|
||||||
# edit_image = [Image.open("image.jpg")]
|
|
||||||
# Please do not input the image directly.
|
|
||||||
# edit_image = Image.open("image.jpg")
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
|
||||||
from modelscope import dataset_snapshot_download
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
vram_config = {
|
|
||||||
"offload_dtype": "disk",
|
|
||||||
"offload_device": "disk",
|
|
||||||
"onload_dtype": torch.float8_e4m3fn,
|
|
||||||
"onload_device": "cpu",
|
|
||||||
"preparing_dtype": torch.float8_e4m3fn,
|
|
||||||
"preparing_device": "cuda",
|
|
||||||
"computation_dtype": torch.bfloat16,
|
|
||||||
"computation_device": "cuda",
|
|
||||||
}
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
|
||||||
],
|
|
||||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset_snapshot_download(
|
|
||||||
"DiffSynth-Studio/example_image_dataset",
|
|
||||||
allow_patterns="layer/image.png",
|
|
||||||
local_dir="data/example_image_dataset"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prompt should be provided to the pipeline. Our pipeline will not generate the prompt.
|
|
||||||
prompt = 'A cheerful child with brown hair is waving enthusiastically under a bright blue sky filled with colorful confetti and balloons. The word "HELLO!" is prominently displayed in bold red letters above the child, while "Have a Great Day!" appears in elegant cursive at the bottom right corner. The scene is vibrant and festive, with a mix of pastel colors and dynamic shapes creating a joyful atmosphere.'
|
|
||||||
# Height and width should be consistent with input_image and be divided evenly by 16
|
|
||||||
input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480))
|
|
||||||
images = pipe(
|
|
||||||
prompt,
|
|
||||||
seed=1, num_inference_steps=50,
|
|
||||||
height=480, width=864,
|
|
||||||
layer_input_image=input_image, layer_num=3,
|
|
||||||
)
|
|
||||||
for i, image in enumerate(images):
|
|
||||||
if i == 0: continue # The first image is the input image.
|
|
||||||
image.save(f"image_{i}.png")
|
|
||||||
@@ -1,134 +0,0 @@
|
|||||||
from diffsynth.pipelines.qwen_image import (
|
|
||||||
QwenImagePipeline, ModelConfig,
|
|
||||||
QwenImageUnit_Image2LoRAEncode, QwenImageUnit_Image2LoRADecode
|
|
||||||
)
|
|
||||||
from diffsynth.utils.lora import merge_lora
|
|
||||||
from diffsynth import load_state_dict
|
|
||||||
from modelscope import snapshot_download
|
|
||||||
from safetensors.torch import save_file
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
vram_config = {
|
|
||||||
"offload_dtype": "disk",
|
|
||||||
"offload_device": "disk",
|
|
||||||
"onload_dtype": torch.bfloat16,
|
|
||||||
"onload_device": "cpu",
|
|
||||||
"preparing_dtype": torch.bfloat16,
|
|
||||||
"preparing_device": "cuda",
|
|
||||||
"computation_dtype": torch.bfloat16,
|
|
||||||
"computation_device": "cuda",
|
|
||||||
}
|
|
||||||
vram_config_disk_offload = {
|
|
||||||
"offload_dtype": "disk",
|
|
||||||
"offload_device": "disk",
|
|
||||||
"onload_dtype": "disk",
|
|
||||||
"onload_device": "disk",
|
|
||||||
"preparing_dtype": torch.bfloat16,
|
|
||||||
"preparing_device": "cuda",
|
|
||||||
"computation_dtype": torch.bfloat16,
|
|
||||||
"computation_device": "cuda",
|
|
||||||
}
|
|
||||||
|
|
||||||
def demo_style():
|
|
||||||
# Load models
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config_disk_offload),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config_disk_offload),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Style.safetensors", **vram_config_disk_offload),
|
|
||||||
],
|
|
||||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
|
||||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load images
|
|
||||||
snapshot_download(
|
|
||||||
model_id="DiffSynth-Studio/Qwen-Image-i2L",
|
|
||||||
allow_file_pattern="assets/style/1/*",
|
|
||||||
local_dir="data/examples"
|
|
||||||
)
|
|
||||||
images = [
|
|
||||||
Image.open("data/examples/assets/style/1/0.jpg"),
|
|
||||||
Image.open("data/examples/assets/style/1/1.jpg"),
|
|
||||||
Image.open("data/examples/assets/style/1/2.jpg"),
|
|
||||||
Image.open("data/examples/assets/style/1/3.jpg"),
|
|
||||||
Image.open("data/examples/assets/style/1/4.jpg"),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Model inference
|
|
||||||
with torch.no_grad():
|
|
||||||
embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
|
|
||||||
lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
|
|
||||||
save_file(lora, "model_style.safetensors")
|
|
||||||
|
|
||||||
|
|
||||||
def demo_coarse_fine_bias():
|
|
||||||
# Load models
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config_disk_offload),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config_disk_offload),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config_disk_offload),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Coarse.safetensors", **vram_config_disk_offload),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Fine.safetensors", **vram_config_disk_offload),
|
|
||||||
],
|
|
||||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
|
||||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load images
|
|
||||||
snapshot_download(
|
|
||||||
model_id="DiffSynth-Studio/Qwen-Image-i2L",
|
|
||||||
allow_file_pattern="assets/lora/3/*",
|
|
||||||
local_dir="data/examples"
|
|
||||||
)
|
|
||||||
images = [
|
|
||||||
Image.open("data/examples/assets/lora/3/0.jpg"),
|
|
||||||
Image.open("data/examples/assets/lora/3/1.jpg"),
|
|
||||||
Image.open("data/examples/assets/lora/3/2.jpg"),
|
|
||||||
Image.open("data/examples/assets/lora/3/3.jpg"),
|
|
||||||
Image.open("data/examples/assets/lora/3/4.jpg"),
|
|
||||||
Image.open("data/examples/assets/lora/3/5.jpg"),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Model inference
|
|
||||||
with torch.no_grad():
|
|
||||||
embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
|
|
||||||
lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
|
|
||||||
lora_bias = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Bias.safetensors")
|
|
||||||
lora_bias.download_if_necessary()
|
|
||||||
lora_bias = load_state_dict(lora_bias.path, torch_dtype=torch.bfloat16, device="cuda")
|
|
||||||
lora = merge_lora([lora, lora_bias])
|
|
||||||
save_file(lora, "model_coarse_fine_bias.safetensors")
|
|
||||||
|
|
||||||
|
|
||||||
def generate_image(lora_path, prompt, seed):
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
|
||||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
|
||||||
)
|
|
||||||
pipe.load_lora(pipe.dit, lora_path)
|
|
||||||
image = pipe(prompt, seed=seed, height=1024, width=1024, num_inference_steps=50)
|
|
||||||
return image
|
|
||||||
|
|
||||||
|
|
||||||
demo_style()
|
|
||||||
image = generate_image("model_style.safetensors", "a cat", 0)
|
|
||||||
image.save("image_1.jpg")
|
|
||||||
|
|
||||||
demo_coarse_fine_bias()
|
|
||||||
image = generate_image("model_coarse_fine_bias.safetensors", "bowl", 1)
|
|
||||||
image.save("image_2.jpg")
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
|
|
||||||
--dataset_base_path data/example_image_dataset \
|
|
||||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--dataset_repeat 50 \
|
|
||||||
--model_id_with_origin_paths "Qwen/Qwen-Image-2512:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--learning_rate 1e-5 \
|
|
||||||
--num_epochs 2 \
|
|
||||||
--remove_prefix_in_ckpt "pipe.dit." \
|
|
||||||
--output_path "./models/train/Qwen-Image-2512_full" \
|
|
||||||
--trainable_models "dit" \
|
|
||||||
--use_gradient_checkpointing \
|
|
||||||
--find_unused_parameters
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
|
|
||||||
--dataset_base_path data/example_image_dataset \
|
|
||||||
--dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
|
||||||
--data_file_keys "image,edit_image" \
|
|
||||||
--extra_inputs "edit_image" \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--dataset_repeat 50 \
|
|
||||||
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2511:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--learning_rate 1e-5 \
|
|
||||||
--num_epochs 2 \
|
|
||||||
--remove_prefix_in_ckpt "pipe.dit." \
|
|
||||||
--output_path "./models/train/Qwen-Image-Edit-2511_full" \
|
|
||||||
--trainable_models "dit" \
|
|
||||||
--use_gradient_checkpointing \
|
|
||||||
--find_unused_parameters \
|
|
||||||
--zero_cond_t # This is a special parameter introduced by Qwen-Image-Edit-2511. Please enable it for this model.
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
# Example Dataset: https://modelscope.cn/datasets/DiffSynth-Studio/example_image_dataset/tree/master/layer
|
|
||||||
|
|
||||||
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
|
|
||||||
--dataset_base_path data/example_image_dataset/layer \
|
|
||||||
--dataset_metadata_path data/example_image_dataset/layer/metadata_layered.json \
|
|
||||||
--data_file_keys "image,layer_input_image" \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--dataset_repeat 50 \
|
|
||||||
--model_id_with_origin_paths "Qwen/Qwen-Image-Layered:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--learning_rate 1e-5 \
|
|
||||||
--num_epochs 2 \
|
|
||||||
--remove_prefix_in_ckpt "pipe.dit." \
|
|
||||||
--output_path "./models/train/Qwen-Image-Layered_full" \
|
|
||||||
--trainable_models "dit" \
|
|
||||||
--extra_inputs "layer_num,layer_input_image" \
|
|
||||||
--use_gradient_checkpointing \
|
|
||||||
--dataset_num_workers 8 \
|
|
||||||
--find_unused_parameters
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
accelerate launch examples/qwen_image/model_training/train.py \
|
|
||||||
--dataset_base_path data/example_image_dataset \
|
|
||||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--dataset_repeat 50 \
|
|
||||||
--model_id_with_origin_paths "Qwen/Qwen-Image-2512:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--learning_rate 1e-4 \
|
|
||||||
--num_epochs 5 \
|
|
||||||
--remove_prefix_in_ckpt "pipe.dit." \
|
|
||||||
--output_path "./models/train/Qwen-Image-2512_lora" \
|
|
||||||
--lora_base_model "dit" \
|
|
||||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
|
||||||
--lora_rank 32 \
|
|
||||||
--use_gradient_checkpointing \
|
|
||||||
--dataset_num_workers 8 \
|
|
||||||
--find_unused_parameters
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
accelerate launch examples/qwen_image/model_training/train.py \
|
|
||||||
--dataset_base_path data/example_image_dataset \
|
|
||||||
--dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
|
||||||
--data_file_keys "image,edit_image" \
|
|
||||||
--extra_inputs "edit_image" \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--dataset_repeat 50 \
|
|
||||||
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2511:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--learning_rate 1e-4 \
|
|
||||||
--num_epochs 5 \
|
|
||||||
--remove_prefix_in_ckpt "pipe.dit." \
|
|
||||||
--output_path "./models/train/Qwen-Image-Edit-2511_lora" \
|
|
||||||
--lora_base_model "dit" \
|
|
||||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
|
||||||
--lora_rank 32 \
|
|
||||||
--use_gradient_checkpointing \
|
|
||||||
--dataset_num_workers 8 \
|
|
||||||
--find_unused_parameters \
|
|
||||||
--zero_cond_t # This is a special parameter introduced by Qwen-Image-Edit-2511. Please enable it for this model.
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
# Example Dataset: https://modelscope.cn/datasets/DiffSynth-Studio/example_image_dataset/tree/master/layer
|
|
||||||
|
|
||||||
accelerate launch examples/qwen_image/model_training/train.py \
|
|
||||||
--dataset_base_path data/example_image_dataset/layer \
|
|
||||||
--dataset_metadata_path data/example_image_dataset/layer/metadata_layered.json \
|
|
||||||
--data_file_keys "image,layer_input_image" \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--dataset_repeat 50 \
|
|
||||||
--model_id_with_origin_paths "Qwen/Qwen-Image-Layered:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--learning_rate 1e-4 \
|
|
||||||
--num_epochs 5 \
|
|
||||||
--remove_prefix_in_ckpt "pipe.dit." \
|
|
||||||
--output_path "./models/train/Qwen-Image-Layered_lora" \
|
|
||||||
--lora_base_model "dit" \
|
|
||||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
|
||||||
--lora_rank 32 \
|
|
||||||
--extra_inputs "layer_num,layer_input_image" \
|
|
||||||
--use_gradient_checkpointing \
|
|
||||||
--dataset_num_workers 8 \
|
|
||||||
--find_unused_parameters
|
|
||||||
@@ -2,7 +2,6 @@ import torch, os, argparse, accelerate
|
|||||||
from diffsynth.core import UnifiedDataset
|
from diffsynth.core import UnifiedDataset
|
||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||||
from diffsynth.diffusion import *
|
from diffsynth.diffusion import *
|
||||||
from diffsynth.core.data.operators import *
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
@@ -21,7 +20,6 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
|||||||
offload_models=None,
|
offload_models=None,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
task="sft",
|
task="sft",
|
||||||
zero_cond_t=False,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Load models
|
# Load models
|
||||||
@@ -45,7 +43,6 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
|||||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||||
self.fp8_models = fp8_models
|
self.fp8_models = fp8_models
|
||||||
self.task = task
|
self.task = task
|
||||||
self.zero_cond_t = zero_cond_t
|
|
||||||
self.task_to_loss = {
|
self.task_to_loss = {
|
||||||
"sft:data_process": lambda pipe, *args: args,
|
"sft:data_process": lambda pipe, *args: args,
|
||||||
"direct_distill:data_process": lambda pipe, *args: args,
|
"direct_distill:data_process": lambda pipe, *args: args,
|
||||||
@@ -59,6 +56,11 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
|||||||
inputs_posi = {"prompt": data["prompt"]}
|
inputs_posi = {"prompt": data["prompt"]}
|
||||||
inputs_nega = {"negative_prompt": ""}
|
inputs_nega = {"negative_prompt": ""}
|
||||||
inputs_shared = {
|
inputs_shared = {
|
||||||
|
# Assume you are using this pipeline for inference,
|
||||||
|
# please fill in the input parameters.
|
||||||
|
"input_image": data["image"],
|
||||||
|
"height": data["image"].size[1],
|
||||||
|
"width": data["image"].size[0],
|
||||||
# Please do not modify the following parameters
|
# Please do not modify the following parameters
|
||||||
# unless you clearly know what this will cause.
|
# unless you clearly know what this will cause.
|
||||||
"cfg_scale": 1,
|
"cfg_scale": 1,
|
||||||
@@ -66,22 +68,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
|||||||
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||||
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||||
"edit_image_auto_resize": True,
|
"edit_image_auto_resize": True,
|
||||||
"zero_cond_t": self.zero_cond_t,
|
|
||||||
}
|
}
|
||||||
# Assume you are using this pipeline for inference,
|
|
||||||
# please fill in the input parameters.
|
|
||||||
if isinstance(data["image"], list):
|
|
||||||
inputs_shared.update({
|
|
||||||
"input_image": data["image"],
|
|
||||||
"height": data["image"][0].size[1],
|
|
||||||
"width": data["image"][0].size[0],
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
inputs_shared.update({
|
|
||||||
"input_image": data["image"],
|
|
||||||
"height": data["image"].size[1],
|
|
||||||
"width": data["image"].size[0],
|
|
||||||
})
|
|
||||||
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||||
return inputs_shared, inputs_posi, inputs_nega
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
|
|
||||||
@@ -100,7 +87,6 @@ def qwen_image_parser():
|
|||||||
parser = add_image_size_config(parser)
|
parser = add_image_size_config(parser)
|
||||||
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
||||||
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
|
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
|
||||||
parser.add_argument("--zero_cond_t", default=False, action="store_true", help="A special parameter introduced by Qwen-Image-Edit-2511. Please enable it for this model.")
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@@ -123,15 +109,7 @@ if __name__ == "__main__":
|
|||||||
width=args.width,
|
width=args.width,
|
||||||
height_division_factor=16,
|
height_division_factor=16,
|
||||||
width_division_factor=16,
|
width_division_factor=16,
|
||||||
),
|
)
|
||||||
special_operator_map={
|
|
||||||
# Qwen-Image-Layered
|
|
||||||
"layer_input_image": ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16),
|
|
||||||
"image": RouteByType(operator_map=[
|
|
||||||
(str, ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16)),
|
|
||||||
(list, SequencialProcess(ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16))),
|
|
||||||
])
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
model = QwenImageTrainingModule(
|
model = QwenImageTrainingModule(
|
||||||
model_paths=args.model_paths,
|
model_paths=args.model_paths,
|
||||||
@@ -152,7 +130,6 @@ if __name__ == "__main__":
|
|||||||
offload_models=args.offload_models,
|
offload_models=args.offload_models,
|
||||||
task=args.task,
|
task=args.task,
|
||||||
device=accelerator.device,
|
device=accelerator.device,
|
||||||
zero_cond_t=args.zero_cond_t,
|
|
||||||
)
|
)
|
||||||
model_logger = ModelLogger(
|
model_logger = ModelLogger(
|
||||||
args.output_path,
|
args.output_path,
|
||||||
|
|||||||
@@ -1,20 +0,0 @@
|
|||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
|
||||||
from diffsynth import load_state_dict
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image-2512", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
state_dict = load_state_dict("models/train/Qwen-Image-2512_full/epoch-1.safetensors")
|
|
||||||
pipe.dit.load_state_dict(state_dict)
|
|
||||||
prompt = "a dog"
|
|
||||||
image = pipe(prompt, seed=0)
|
|
||||||
image.save("image.jpg")
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
|
||||||
from diffsynth import load_state_dict
|
|
||||||
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=None,
|
|
||||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
|
||||||
)
|
|
||||||
state_dict = load_state_dict("models/train/Qwen-Image-Edit-2511_full/epoch-1.safetensors")
|
|
||||||
pipe.dit.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2."
|
|
||||||
images = [
|
|
||||||
Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)),
|
|
||||||
Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)),
|
|
||||||
]
|
|
||||||
image = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024, zero_cond_t=True)
|
|
||||||
image.save("image.jpg")
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
|
||||||
from diffsynth import load_state_dict
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
state_dict = load_state_dict("models/train/Qwen-Image-Layered_full/epoch-1.safetensors")
|
|
||||||
pipe.dit.load_state_dict(state_dict)
|
|
||||||
prompt = "a poster"
|
|
||||||
input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480))
|
|
||||||
images = pipe(
|
|
||||||
prompt, seed=0,
|
|
||||||
height=480, width=864,
|
|
||||||
layer_input_image=input_image, layer_num=3,
|
|
||||||
)
|
|
||||||
for i, image in enumerate(images):
|
|
||||||
if i == 0: continue # The first image is the input image.
|
|
||||||
image.save(f"image_{i}.png")
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image-2512", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-2512_lora/epoch-4.safetensors")
|
|
||||||
prompt = "a dog"
|
|
||||||
image = pipe(prompt, seed=0)
|
|
||||||
image.save("image.jpg")
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
|
||||||
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=None,
|
|
||||||
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
|
||||||
)
|
|
||||||
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Edit-2511_lora/epoch-4.safetensors")
|
|
||||||
|
|
||||||
prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2."
|
|
||||||
images = [
|
|
||||||
Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)),
|
|
||||||
Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)),
|
|
||||||
]
|
|
||||||
image = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024, zero_cond_t=True)
|
|
||||||
image.save("image.jpg")
|
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
|
||||||
from diffsynth import load_state_dict
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = QwenImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Layered_lora/epoch-4.safetensors")
|
|
||||||
prompt = "a poster"
|
|
||||||
input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480))
|
|
||||||
images = pipe(
|
|
||||||
prompt, seed=0,
|
|
||||||
height=480, width=864,
|
|
||||||
layer_input_image=input_image, layer_num=3,
|
|
||||||
)
|
|
||||||
for i, image in enumerate(images):
|
|
||||||
if i == 0: continue # The first image is the input image.
|
|
||||||
image.save(f"image_{i}.png")
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from diffsynth.utils.data import save_video, VideoData
|
|
||||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
pipe = WanVideoPipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
use_usp=True,
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
|
||||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
|
||||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Text-to-video
|
|
||||||
video = pipe(
|
|
||||||
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
|
|
||||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
|
||||||
seed=0, tiled=True,
|
|
||||||
)
|
|
||||||
if dist.get_rank() == 0:
|
|
||||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
|
||||||
@@ -27,24 +27,23 @@ def speech_to_video(
|
|||||||
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
|
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
|
||||||
pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None
|
pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None
|
||||||
|
|
||||||
with torch.no_grad():
|
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
|
||||||
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
|
pipe=pipe,
|
||||||
pipe=pipe,
|
input_audio=input_audio,
|
||||||
input_audio=input_audio,
|
audio_sample_rate=sample_rate,
|
||||||
audio_sample_rate=sample_rate,
|
s2v_pose_video=pose_video,
|
||||||
s2v_pose_video=pose_video,
|
num_frames=infer_frames + 1,
|
||||||
num_frames=infer_frames + 1,
|
height=height,
|
||||||
height=height,
|
width=width,
|
||||||
width=width,
|
fps=fps,
|
||||||
fps=fps,
|
)
|
||||||
)
|
|
||||||
num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat
|
num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat
|
||||||
print(f"Generating {num_repeat} video clips...")
|
print(f"Generating {num_repeat} video clips...")
|
||||||
motion_video = None
|
motion_videos = []
|
||||||
video = []
|
video = []
|
||||||
for r in range(num_repeat):
|
for r in range(num_repeat):
|
||||||
s2v_pose_latents = pose_latents[r] if pose_latents is not None else None
|
s2v_pose_latents = pose_latents[r] if pose_latents is not None else None
|
||||||
current_clip_tensor = pipe(
|
current_clip = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
input_image=input_image,
|
input_image=input_image,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
@@ -54,21 +53,15 @@ def speech_to_video(
|
|||||||
width=width,
|
width=width,
|
||||||
audio_embeds=audio_embeds[r],
|
audio_embeds=audio_embeds[r],
|
||||||
s2v_pose_latents=s2v_pose_latents,
|
s2v_pose_latents=s2v_pose_latents,
|
||||||
motion_video=motion_video,
|
motion_video=motion_videos,
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
output_type="floatpoint",
|
|
||||||
)
|
)
|
||||||
# (B, C, T, H, W)
|
current_clip = current_clip[-infer_frames:]
|
||||||
current_clip_tensor = current_clip_tensor[:,:,-infer_frames:,:,:]
|
|
||||||
if r == 0:
|
if r == 0:
|
||||||
current_clip_tensor = current_clip_tensor[:,:,3:,:,:]
|
current_clip = current_clip[3:]
|
||||||
overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])
|
overlap_frames_num = min(motion_frames, len(current_clip))
|
||||||
motion_video = current_clip_tensor[:,:,-overlap_frames_num:,:,:].clone()
|
motion_videos = motion_videos[overlap_frames_num:] + current_clip[-overlap_frames_num:]
|
||||||
else:
|
video.extend(current_clip)
|
||||||
overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])
|
|
||||||
motion_video = torch.cat((motion_video[:,:,overlap_frames_num:,:,:], current_clip_tensor[:,:,-overlap_frames_num:,:,:]), dim=2)
|
|
||||||
current_clip_quantized = pipe.vae_output_to_video(current_clip_tensor)
|
|
||||||
video.extend(current_clip_quantized)
|
|
||||||
save_video_with_audio(video, save_path, audio_path, fps=16, quality=5)
|
save_video_with_audio(video, save_path, audio_path, fps=16, quality=5)
|
||||||
print(f"processed the {r+1}th clip of total {num_repeat} clips.")
|
print(f"processed the {r+1}th clip of total {num_repeat} clips.")
|
||||||
return video
|
return video
|
||||||
|
|||||||
@@ -27,24 +27,23 @@ def speech_to_video(
|
|||||||
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
|
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
|
||||||
pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None
|
pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None
|
||||||
|
|
||||||
with torch.no_grad():
|
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
|
||||||
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
|
pipe=pipe,
|
||||||
pipe=pipe,
|
input_audio=input_audio,
|
||||||
input_audio=input_audio,
|
audio_sample_rate=sample_rate,
|
||||||
audio_sample_rate=sample_rate,
|
s2v_pose_video=pose_video,
|
||||||
s2v_pose_video=pose_video,
|
num_frames=infer_frames + 1,
|
||||||
num_frames=infer_frames + 1,
|
height=height,
|
||||||
height=height,
|
width=width,
|
||||||
width=width,
|
fps=fps,
|
||||||
fps=fps,
|
)
|
||||||
)
|
|
||||||
num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat
|
num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat
|
||||||
print(f"Generating {num_repeat} video clips...")
|
print(f"Generating {num_repeat} video clips...")
|
||||||
motion_video = None
|
motion_videos = []
|
||||||
video = []
|
video = []
|
||||||
for r in range(num_repeat):
|
for r in range(num_repeat):
|
||||||
s2v_pose_latents = pose_latents[r] if pose_latents is not None else None
|
s2v_pose_latents = pose_latents[r] if pose_latents is not None else None
|
||||||
current_clip_tensor = pipe(
|
current_clip = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
input_image=input_image,
|
input_image=input_image,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
@@ -54,24 +53,20 @@ def speech_to_video(
|
|||||||
width=width,
|
width=width,
|
||||||
audio_embeds=audio_embeds[r],
|
audio_embeds=audio_embeds[r],
|
||||||
s2v_pose_latents=s2v_pose_latents,
|
s2v_pose_latents=s2v_pose_latents,
|
||||||
motion_video=motion_video,
|
motion_video=motion_videos,
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
output_type="floatpoint",
|
|
||||||
)
|
)
|
||||||
current_clip_tensor = current_clip_tensor[:,:,-infer_frames:,:,:]
|
current_clip = current_clip[-infer_frames:]
|
||||||
if r == 0:
|
if r == 0:
|
||||||
current_clip_tensor = current_clip_tensor[:,:,3:,:,:]
|
current_clip = current_clip[3:]
|
||||||
overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])
|
overlap_frames_num = min(motion_frames, len(current_clip))
|
||||||
motion_video = current_clip_tensor[:,:,-overlap_frames_num:,:,:].clone()
|
motion_videos = motion_videos[overlap_frames_num:] + current_clip[-overlap_frames_num:]
|
||||||
else:
|
video.extend(current_clip)
|
||||||
overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])
|
|
||||||
motion_video = torch.cat((motion_video[:,:,overlap_frames_num:,:,:], current_clip_tensor[:,:,-overlap_frames_num:,:,:]), dim=2)
|
|
||||||
current_clip_quantized = pipe.vae_output_to_video(current_clip_tensor)
|
|
||||||
video.extend(current_clip_quantized)
|
|
||||||
save_video_with_audio(video, save_path, audio_path, fps=16, quality=5)
|
save_video_with_audio(video, save_path, audio_path, fps=16, quality=5)
|
||||||
print(f"processed the {r+1}th clip of total {num_repeat} clips.")
|
print(f"processed the {r+1}th clip of total {num_repeat} clips.")
|
||||||
return video
|
return video
|
||||||
|
|
||||||
|
|
||||||
vram_config = {
|
vram_config = {
|
||||||
"offload_dtype": "disk",
|
"offload_dtype": "disk",
|
||||||
"offload_device": "disk",
|
"offload_device": "disk",
|
||||||
|
|||||||
@@ -1,62 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import (
|
|
||||||
ZImagePipeline, ModelConfig,
|
|
||||||
ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
|
|
||||||
)
|
|
||||||
from modelscope import snapshot_download
|
|
||||||
from safetensors.torch import save_file
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
# Use `vram_config` to enable LoRA hot-loading
|
|
||||||
vram_config = {
|
|
||||||
"offload_dtype": torch.bfloat16,
|
|
||||||
"offload_device": "cuda",
|
|
||||||
"onload_dtype": torch.bfloat16,
|
|
||||||
"onload_device": "cuda",
|
|
||||||
"preparing_dtype": torch.bfloat16,
|
|
||||||
"preparing_device": "cuda",
|
|
||||||
"computation_dtype": torch.bfloat16,
|
|
||||||
"computation_device": "cuda",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Load models
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L", origin_file_pattern="model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load images
|
|
||||||
snapshot_download(
|
|
||||||
model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L",
|
|
||||||
allow_file_pattern="assets/style/*",
|
|
||||||
local_dir="data/style_input"
|
|
||||||
)
|
|
||||||
images = [Image.open(f"data/style_input/assets/style/1/{i}.jpg") for i in range(6)]
|
|
||||||
|
|
||||||
# Image to LoRA
|
|
||||||
with torch.no_grad():
|
|
||||||
embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
|
|
||||||
lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
|
|
||||||
save_file(lora, "lora.safetensors")
|
|
||||||
|
|
||||||
# Generate images
|
|
||||||
prompt = "a cat"
|
|
||||||
negative_prompt = "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符"
|
|
||||||
image = pipe(
|
|
||||||
prompt=prompt,
|
|
||||||
negative_prompt=negative_prompt,
|
|
||||||
seed=0, cfg_scale=7, num_inference_steps=50,
|
|
||||||
positive_only_lora=lora,
|
|
||||||
sigma_shift=8
|
|
||||||
)
|
|
||||||
image.save("image.jpg")
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
|
|
||||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4)
|
|
||||||
image.save("image_Z-Image-Omni-Base.jpg")
|
|
||||||
|
|
||||||
image = Image.open("image_Z-Image-Omni-Base.jpg")
|
|
||||||
prompt = "Change the women's clothes to white cheongsam, keep other content unchanged"
|
|
||||||
image = pipe(prompt=prompt, edit_image=image, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4)
|
|
||||||
image.save("image_edit_Z-Image-Omni-Base.jpg")
|
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
|
|
||||||
from modelscope import dataset_snapshot_download
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset_snapshot_download(
|
|
||||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
|
||||||
local_dir="./",
|
|
||||||
allow_file_pattern="data/examples/upscale/low_res.png"
|
|
||||||
)
|
|
||||||
controlnet_image = Image.open("data/examples/upscale/low_res.png").resize((1024, 1024))
|
|
||||||
prompt = "这是一张充满都市气息的户外人物肖像照片。画面中是一位年轻男性,他展现出时尚而自信的形象。人物拥有精心打理的短发发型,两侧修剪得较短,顶部保留一定长度,呈现出流行的Undercut造型。他佩戴着一副时尚的浅色墨镜或透明镜框眼镜,为整体造型增添了潮流感。脸上洋溢着温和友善的笑容,神情放松自然,给人以阳光开朗的印象。他身穿一件经典的牛仔外套,这件单品永不过时,展现出休闲又有型的穿衣风格。牛仔外套的蓝色调与整体氛围十分协调,领口处隐约可见内搭的衣物。照片的背景是典型的城市街景,可以看到模糊的建筑物、街道和行人,营造出繁华都市的氛围。背景经过了恰当的虚化处理,使人物主体更加突出。光线明亮而柔和,可能是白天的自然光,为照片带来清新通透的视觉效果。整张照片构图专业,景深控制得当,完美捕捉了一个现代都市年轻人充满活力和自信的瞬间,展现出积极向上的生活态度。"
|
|
||||||
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])
|
|
||||||
image.save("image_tile.jpg")
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
|
|
||||||
from modelscope import dataset_snapshot_download
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Control
|
|
||||||
dataset_snapshot_download(
|
|
||||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
|
||||||
local_dir="./data/example_image_dataset",
|
|
||||||
allow_file_pattern="depth/image_1.jpg"
|
|
||||||
)
|
|
||||||
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1024, 1024))
|
|
||||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
|
||||||
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])
|
|
||||||
image.save("image_control.jpg")
|
|
||||||
|
|
||||||
# Inpaint
|
|
||||||
dataset_snapshot_download(
|
|
||||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
|
||||||
local_dir="./data/example_image_dataset",
|
|
||||||
allow_file_pattern="inpaint/*.jpg"
|
|
||||||
)
|
|
||||||
inpaint_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024))
|
|
||||||
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024))
|
|
||||||
prompt = "一只戴着墨镜的猫"
|
|
||||||
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)])
|
|
||||||
image.save("image_inpaint.jpg")
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
|
|
||||||
from modelscope import dataset_snapshot_download
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Control
|
|
||||||
dataset_snapshot_download(
|
|
||||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
|
||||||
local_dir="./data/example_image_dataset",
|
|
||||||
allow_file_pattern="depth/image_1.jpg"
|
|
||||||
)
|
|
||||||
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1024, 1024))
|
|
||||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
|
||||||
image = pipe(
|
|
||||||
prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)],
|
|
||||||
num_inference_steps=30,
|
|
||||||
)
|
|
||||||
image.save("image_control.jpg")
|
|
||||||
|
|
||||||
# Inpaint
|
|
||||||
dataset_snapshot_download(
|
|
||||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
|
||||||
local_dir="./data/example_image_dataset",
|
|
||||||
allow_file_pattern="inpaint/*.jpg"
|
|
||||||
)
|
|
||||||
inpaint_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024))
|
|
||||||
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024))
|
|
||||||
prompt = "一只戴着墨镜的猫"
|
|
||||||
image = pipe(
|
|
||||||
prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)],
|
|
||||||
num_inference_steps=30,
|
|
||||||
)
|
|
||||||
image.save("image_inpaint.jpg")
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import (
|
|
||||||
ZImagePipeline, ModelConfig,
|
|
||||||
ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
|
|
||||||
)
|
|
||||||
from modelscope import snapshot_download
|
|
||||||
from safetensors.torch import save_file
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
# Use `vram_config` to enable LoRA hot-loading
|
|
||||||
vram_config = {
|
|
||||||
"offload_dtype": torch.bfloat16,
|
|
||||||
"offload_device": "cpu",
|
|
||||||
"onload_dtype": torch.bfloat16,
|
|
||||||
"onload_device": "cpu",
|
|
||||||
"preparing_dtype": torch.bfloat16,
|
|
||||||
"preparing_device": "cuda",
|
|
||||||
"computation_dtype": torch.bfloat16,
|
|
||||||
"computation_device": "cuda",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Load models
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L", origin_file_pattern="model.safetensors", **vram_config),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load images
|
|
||||||
snapshot_download(
|
|
||||||
model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L",
|
|
||||||
allow_file_pattern="assets/style/*",
|
|
||||||
local_dir="data/style_input"
|
|
||||||
)
|
|
||||||
images = [Image.open(f"data/style_input/assets/style/1/{i}.jpg") for i in range(6)]
|
|
||||||
|
|
||||||
# Image to LoRA
|
|
||||||
with torch.no_grad():
|
|
||||||
embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
|
|
||||||
lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
|
|
||||||
save_file(lora, "lora.safetensors")
|
|
||||||
|
|
||||||
# Generate images
|
|
||||||
prompt = "a cat"
|
|
||||||
negative_prompt = "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符"
|
|
||||||
image = pipe(
|
|
||||||
prompt=prompt,
|
|
||||||
negative_prompt=negative_prompt,
|
|
||||||
seed=0, cfg_scale=7, num_inference_steps=50,
|
|
||||||
positive_only_lora=lora,
|
|
||||||
sigma_shift=8
|
|
||||||
)
|
|
||||||
image.save("image.jpg")
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
vram_config = {
|
|
||||||
"offload_dtype": torch.bfloat16,
|
|
||||||
"offload_device": "cpu",
|
|
||||||
"onload_dtype": torch.bfloat16,
|
|
||||||
"onload_device": "cpu",
|
|
||||||
"preparing_dtype": torch.bfloat16,
|
|
||||||
"preparing_device": "cuda",
|
|
||||||
"computation_dtype": torch.bfloat16,
|
|
||||||
"computation_device": "cuda",
|
|
||||||
}
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
|
|
||||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4)
|
|
||||||
image.save("image_Z-Image-Omni-Base.jpg")
|
|
||||||
|
|
||||||
image = Image.open("image_Z-Image-Omni-Base.jpg")
|
|
||||||
prompt = "Change the women's clothes to white cheongsam, keep other content unchanged"
|
|
||||||
image = pipe(prompt=prompt, edit_image=image, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4)
|
|
||||||
image.save("image_edit_Z-Image-Omni-Base.jpg")
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
|
|
||||||
from modelscope import dataset_snapshot_download
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
vram_config = {
|
|
||||||
"offload_dtype": torch.bfloat16,
|
|
||||||
"offload_device": "cpu",
|
|
||||||
"onload_dtype": torch.bfloat16,
|
|
||||||
"onload_device": "cpu",
|
|
||||||
"preparing_dtype": torch.bfloat16,
|
|
||||||
"preparing_device": "cuda",
|
|
||||||
"computation_dtype": torch.bfloat16,
|
|
||||||
"computation_device": "cuda",
|
|
||||||
}
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset_snapshot_download(
|
|
||||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
|
||||||
local_dir="./",
|
|
||||||
allow_file_pattern="data/examples/upscale/low_res.png"
|
|
||||||
)
|
|
||||||
controlnet_image = Image.open("data/examples/upscale/low_res.png").resize((1024, 1024))
|
|
||||||
prompt = "这是一张充满都市气息的户外人物肖像照片。画面中是一位年轻男性,他展现出时尚而自信的形象。人物拥有精心打理的短发发型,两侧修剪得较短,顶部保留一定长度,呈现出流行的Undercut造型。他佩戴着一副时尚的浅色墨镜或透明镜框眼镜,为整体造型增添了潮流感。脸上洋溢着温和友善的笑容,神情放松自然,给人以阳光开朗的印象。他身穿一件经典的牛仔外套,这件单品永不过时,展现出休闲又有型的穿衣风格。牛仔外套的蓝色调与整体氛围十分协调,领口处隐约可见内搭的衣物。照片的背景是典型的城市街景,可以看到模糊的建筑物、街道和行人,营造出繁华都市的氛围。背景经过了恰当的虚化处理,使人物主体更加突出。光线明亮而柔和,可能是白天的自然光,为照片带来清新通透的视觉效果。整张照片构图专业,景深控制得当,完美捕捉了一个现代都市年轻人充满活力和自信的瞬间,展现出积极向上的生活态度。"
|
|
||||||
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])
|
|
||||||
image.save("image_tile.jpg")
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
|
|
||||||
from modelscope import dataset_snapshot_download
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
vram_config = {
|
|
||||||
"offload_dtype": torch.bfloat16,
|
|
||||||
"offload_device": "cpu",
|
|
||||||
"onload_dtype": torch.bfloat16,
|
|
||||||
"onload_device": "cpu",
|
|
||||||
"preparing_dtype": torch.bfloat16,
|
|
||||||
"preparing_device": "cuda",
|
|
||||||
"computation_dtype": torch.bfloat16,
|
|
||||||
"computation_device": "cuda",
|
|
||||||
}
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Control
|
|
||||||
dataset_snapshot_download(
|
|
||||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
|
||||||
local_dir="./data/example_image_dataset",
|
|
||||||
allow_file_pattern="depth/image_1.jpg"
|
|
||||||
)
|
|
||||||
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1024, 1024))
|
|
||||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
|
||||||
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])
|
|
||||||
image.save("image_control.jpg")
|
|
||||||
|
|
||||||
# Inpaint
|
|
||||||
dataset_snapshot_download(
|
|
||||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
|
||||||
local_dir="./data/example_image_dataset",
|
|
||||||
allow_file_pattern="inpaint/*.jpg"
|
|
||||||
)
|
|
||||||
inpaint_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024))
|
|
||||||
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024))
|
|
||||||
prompt = "一只戴着墨镜的猫"
|
|
||||||
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)])
|
|
||||||
image.save("image_inpaint.jpg")
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
|
|
||||||
from modelscope import dataset_snapshot_download
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
vram_config = {
|
|
||||||
"offload_dtype": torch.bfloat16,
|
|
||||||
"offload_device": "cpu",
|
|
||||||
"onload_dtype": torch.bfloat16,
|
|
||||||
"onload_device": "cpu",
|
|
||||||
"preparing_dtype": torch.bfloat16,
|
|
||||||
"preparing_device": "cuda",
|
|
||||||
"computation_dtype": torch.bfloat16,
|
|
||||||
"computation_device": "cuda",
|
|
||||||
}
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Control
|
|
||||||
dataset_snapshot_download(
|
|
||||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
|
||||||
local_dir="./data/example_image_dataset",
|
|
||||||
allow_file_pattern="depth/image_1.jpg"
|
|
||||||
)
|
|
||||||
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1024, 1024))
|
|
||||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
|
||||||
image = pipe(
|
|
||||||
prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)],
|
|
||||||
num_inference_steps=30,
|
|
||||||
)
|
|
||||||
image.save("image_control.jpg")
|
|
||||||
|
|
||||||
# Inpaint
|
|
||||||
dataset_snapshot_download(
|
|
||||||
dataset_id="DiffSynth-Studio/example_image_dataset",
|
|
||||||
local_dir="./data/example_image_dataset",
|
|
||||||
allow_file_pattern="inpaint/*.jpg"
|
|
||||||
)
|
|
||||||
inpaint_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024))
|
|
||||||
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024))
|
|
||||||
prompt = "一只戴着墨镜的猫"
|
|
||||||
image = pipe(
|
|
||||||
prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)],
|
|
||||||
num_inference_steps=30,
|
|
||||||
)
|
|
||||||
image.save("image_inpaint.jpg")
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
# This example is tested on 8*A100
|
|
||||||
# Text to image training
|
|
||||||
accelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \
|
|
||||||
--dataset_base_path data/example_image_dataset \
|
|
||||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--dataset_repeat 400 \
|
|
||||||
--model_id_with_origin_paths "Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--learning_rate 1e-5 \
|
|
||||||
--num_epochs 2 \
|
|
||||||
--remove_prefix_in_ckpt "pipe.dit." \
|
|
||||||
--output_path "./models/train/Z-Image-Omni-Base_full" \
|
|
||||||
--trainable_models "dit" \
|
|
||||||
--use_gradient_checkpointing \
|
|
||||||
--find_unused_parameters \
|
|
||||||
--dataset_num_workers 8
|
|
||||||
|
|
||||||
# Image(s) to image training
|
|
||||||
# accelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \
|
|
||||||
# --dataset_base_path data/example_image_dataset \
|
|
||||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
|
||||||
# --data_file_keys "image,edit_image" \
|
|
||||||
# --extra_inputs "edit_image" \
|
|
||||||
# --max_pixels 1048576 \
|
|
||||||
# --dataset_repeat 400 \
|
|
||||||
# --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
# --learning_rate 1e-5 \
|
|
||||||
# --num_epochs 2 \
|
|
||||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
|
||||||
# --output_path "./models/train/Z-Image-Omni-Base_full_edit" \
|
|
||||||
# --trainable_models "dit" \
|
|
||||||
# --use_gradient_checkpointing \
|
|
||||||
# --find_unused_parameters \
|
|
||||||
# --dataset_num_workers 8
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
accelerate launch examples/z_image/model_training/train.py \
|
|
||||||
--dataset_base_path data/example_image_dataset \
|
|
||||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \
|
|
||||||
--data_file_keys "image,controlnet_image" \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--dataset_repeat 100 \
|
|
||||||
--model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--learning_rate 1e-5 \
|
|
||||||
--num_epochs 2 \
|
|
||||||
--remove_prefix_in_ckpt "pipe.controlnet." \
|
|
||||||
--output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_full" \
|
|
||||||
--trainable_models "controlnet" \
|
|
||||||
--extra_inputs "controlnet_image" \
|
|
||||||
--use_gradient_checkpointing \
|
|
||||||
--dataset_num_workers 8
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
accelerate launch examples/z_image/model_training/train.py \
|
|
||||||
--dataset_base_path data/example_image_dataset \
|
|
||||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \
|
|
||||||
--data_file_keys "image,controlnet_image" \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--dataset_repeat 100 \
|
|
||||||
--model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--learning_rate 1e-5 \
|
|
||||||
--num_epochs 2 \
|
|
||||||
--remove_prefix_in_ckpt "pipe.controlnet." \
|
|
||||||
--output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_full" \
|
|
||||||
--trainable_models "controlnet" \
|
|
||||||
--extra_inputs "controlnet_image" \
|
|
||||||
--use_gradient_checkpointing \
|
|
||||||
--dataset_num_workers 8
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
accelerate launch examples/z_image/model_training/train.py \
|
|
||||||
--dataset_base_path data/example_image_dataset \
|
|
||||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \
|
|
||||||
--data_file_keys "image,controlnet_image" \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--dataset_repeat 100 \
|
|
||||||
--model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--learning_rate 1e-5 \
|
|
||||||
--num_epochs 2 \
|
|
||||||
--remove_prefix_in_ckpt "pipe.controlnet." \
|
|
||||||
--output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_full" \
|
|
||||||
--trainable_models "controlnet" \
|
|
||||||
--extra_inputs "controlnet_image" \
|
|
||||||
--use_gradient_checkpointing \
|
|
||||||
--dataset_num_workers 8
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
# Text to image training
|
|
||||||
accelerate launch examples/z_image/model_training/train.py \
|
|
||||||
--dataset_base_path data/example_image_dataset \
|
|
||||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--dataset_repeat 50 \
|
|
||||||
--model_id_with_origin_paths "Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--learning_rate 1e-4 \
|
|
||||||
--num_epochs 5 \
|
|
||||||
--remove_prefix_in_ckpt "pipe.dit." \
|
|
||||||
--output_path "./models/train/Z-Image-Omni-Base_lora" \
|
|
||||||
--lora_base_model "dit" \
|
|
||||||
--lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
|
|
||||||
--lora_rank 32 \
|
|
||||||
--use_gradient_checkpointing \
|
|
||||||
--find_unused_parameters \
|
|
||||||
--dataset_num_workers 8
|
|
||||||
|
|
||||||
# Image(s) to image training
|
|
||||||
# accelerate launch examples/z_image/model_training/train.py \
|
|
||||||
# --dataset_base_path data/example_image_dataset \
|
|
||||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
|
||||||
# --data_file_keys "image,edit_image" \
|
|
||||||
# --extra_inputs "edit_image" \
|
|
||||||
# --max_pixels 1048576 \
|
|
||||||
# --dataset_repeat 50 \
|
|
||||||
# --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
# --learning_rate 1e-4 \
|
|
||||||
# --num_epochs 5 \
|
|
||||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
|
||||||
# --output_path "./models/train/Z-Image-Omni-Base_lora_edit" \
|
|
||||||
# --lora_base_model "dit" \
|
|
||||||
# --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
|
|
||||||
# --lora_rank 32 \
|
|
||||||
# --use_gradient_checkpointing \
|
|
||||||
# --find_unused_parameters \
|
|
||||||
# --dataset_num_workers 8
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
accelerate launch examples/z_image/model_training/train.py \
|
|
||||||
--dataset_base_path data/example_image_dataset \
|
|
||||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \
|
|
||||||
--data_file_keys "image,controlnet_image" \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--dataset_repeat 100 \
|
|
||||||
--model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--learning_rate 1e-4 \
|
|
||||||
--num_epochs 5 \
|
|
||||||
--remove_prefix_in_ckpt "pipe.dit." \
|
|
||||||
--output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_lora" \
|
|
||||||
--lora_base_model "dit" \
|
|
||||||
--lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
|
|
||||||
--lora_rank 32 \
|
|
||||||
--extra_inputs "controlnet_image" \
|
|
||||||
--use_gradient_checkpointing \
|
|
||||||
--dataset_num_workers 8
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
accelerate launch examples/z_image/model_training/train.py \
|
|
||||||
--dataset_base_path data/example_image_dataset \
|
|
||||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \
|
|
||||||
--data_file_keys "image,controlnet_image" \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--dataset_repeat 100 \
|
|
||||||
--model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--learning_rate 1e-4 \
|
|
||||||
--num_epochs 5 \
|
|
||||||
--remove_prefix_in_ckpt "pipe.dit." \
|
|
||||||
--output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_lora" \
|
|
||||||
--lora_base_model "dit" \
|
|
||||||
--lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
|
|
||||||
--lora_rank 32 \
|
|
||||||
--extra_inputs "controlnet_image" \
|
|
||||||
--use_gradient_checkpointing \
|
|
||||||
--dataset_num_workers 8
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
accelerate launch examples/z_image/model_training/train.py \
|
|
||||||
--dataset_base_path data/example_image_dataset \
|
|
||||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \
|
|
||||||
--data_file_keys "image,controlnet_image" \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--dataset_repeat 100 \
|
|
||||||
--model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
|
|
||||||
--learning_rate 1e-4 \
|
|
||||||
--num_epochs 5 \
|
|
||||||
--remove_prefix_in_ckpt "pipe.dit." \
|
|
||||||
--output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_lora" \
|
|
||||||
--lora_base_model "dit" \
|
|
||||||
--lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
|
|
||||||
--lora_rank 32 \
|
|
||||||
--extra_inputs "controlnet_image" \
|
|
||||||
--use_gradient_checkpointing \
|
|
||||||
--dataset_num_workers 8
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
|
||||||
from diffsynth.core import load_state_dict
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
state_dict = load_state_dict("./models/train/Z-Image-Omni-Base_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
|
|
||||||
pipe.dit.load_state_dict(state_dict)
|
|
||||||
prompt = "a dog"
|
|
||||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4)
|
|
||||||
image.save("image.jpg")
|
|
||||||
|
|
||||||
# Edit
|
|
||||||
# state_dict = load_state_dict("./models/train/Z-Image-Omni-Base_full_edit/epoch-1.safetensors", torch_dtype=torch.bfloat16)
|
|
||||||
# pipe.dit.load_state_dict(state_dict)
|
|
||||||
# prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2."
|
|
||||||
# images = [
|
|
||||||
# Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)),
|
|
||||||
# Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)),
|
|
||||||
# ]
|
|
||||||
# image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4, edit_image=images)
|
|
||||||
# image.save("image.jpg")
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
|
|
||||||
from diffsynth import load_state_dict
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
state_dict = load_state_dict("./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_full/epoch-1.safetensors")
|
|
||||||
pipe.controlnet.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
controlnet_image = Image.open("data/example_image_dataset/upscale/image_1.jpg").resize((1024, 1024))
|
|
||||||
prompt = "a dog"
|
|
||||||
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=1)])
|
|
||||||
image.save("image_tile.jpg")
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
|
|
||||||
from diffsynth import load_state_dict
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
state_dict = load_state_dict("./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_full/epoch-1.safetensors")
|
|
||||||
pipe.controlnet.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024))
|
|
||||||
prompt = "a dog"
|
|
||||||
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])
|
|
||||||
image.save("image_control.jpg")
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
|
|
||||||
from diffsynth import load_state_dict
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
state_dict = load_state_dict("./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_full/epoch-1.safetensors")
|
|
||||||
pipe.controlnet.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024))
|
|
||||||
prompt = "a dog"
|
|
||||||
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])
|
|
||||||
image.save("image_control.jpg")
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
|
|
||||||
pipe.load_lora(pipe.dit, "./models/train/Z-Image-Omni-Base_lora/epoch-4.safetensors")
|
|
||||||
prompt = "a dog"
|
|
||||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4)
|
|
||||||
image.save("image.jpg")
|
|
||||||
|
|
||||||
# Edit
|
|
||||||
# pipe.load_lora(pipe.dit, "./models/train/Z-Image-Omni-Base_lora_edit/epoch-4.safetensors")
|
|
||||||
# prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2."
|
|
||||||
# images = [
|
|
||||||
# Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)),
|
|
||||||
# Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)),
|
|
||||||
# ]
|
|
||||||
# image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4, edit_image=images)
|
|
||||||
# image.save("image.jpg")
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
|
|
||||||
from diffsynth import load_state_dict
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_lora/epoch-4.safetensors")
|
|
||||||
|
|
||||||
controlnet_image = Image.open("data/example_image_dataset/upscale/image_1.jpg").resize((1024, 1024))
|
|
||||||
prompt = "a dog"
|
|
||||||
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=1)])
|
|
||||||
image.save("image_tile.jpg")
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
|
|
||||||
from diffsynth import load_state_dict
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_lora/epoch-4.safetensors")
|
|
||||||
|
|
||||||
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024))
|
|
||||||
prompt = "a dog"
|
|
||||||
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])
|
|
||||||
image.save("image_control.jpg")
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
|
|
||||||
from diffsynth import load_state_dict
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
pipe = ZImagePipeline.from_pretrained(
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device="cuda",
|
|
||||||
model_configs=[
|
|
||||||
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
|
||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
|
||||||
],
|
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
|
||||||
)
|
|
||||||
pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_lora/epoch-4.safetensors")
|
|
||||||
|
|
||||||
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024))
|
|
||||||
prompt = "a dog"
|
|
||||||
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])
|
|
||||||
image.save("image_control.jpg")
|
|
||||||
@@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "diffsynth"
|
name = "diffsynth"
|
||||||
version = "2.0.1"
|
version = "2.0.0"
|
||||||
description = "Enjoy the magic of Diffusion models!"
|
description = "Enjoy the magic of Diffusion models!"
|
||||||
authors = [{name = "ModelScope Team"}]
|
authors = [{name = "ModelScope Team"}]
|
||||||
license = {text = "Apache-2.0"}
|
license = {text = "Apache-2.0"}
|
||||||
requires-python = ">=3.10.1"
|
requires-python = ">=3.10"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"torch>=2.0.0",
|
"torch>=2.0.0",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
@@ -33,8 +33,6 @@ classifiers = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["./"]
|
|
||||||
include = ["diffsynth"]
|
|
||||||
|
|
||||||
[tool.setuptools]
|
[tool.setuptools]
|
||||||
include-package-data = true
|
include-package-data = true
|
||||||
|
|||||||
Reference in New Issue
Block a user