Compare commits

..

45 Commits

Author SHA1 Message Date
Artiprocher
63559a3ad6 support z-image-omni-base 2026-01-05 14:03:15 +08:00
Zhongjie Duan
ab8580f77e Merge pull request #1166 from modelscope/qwen-image-2512
support qwen-image-2512
2025-12-30 16:47:07 +08:00
Artiprocher
6454259853 support qwen-image-2512 2025-12-30 16:43:41 +08:00
Zhongjie Duan
8f1d10fb43 Merge pull request #1150 from modelscope/qwen-image-layered
support qwen-image-layered
2025-12-20 14:05:38 +08:00
Artiprocher
20e1aaf908 bugfix 2025-12-20 14:00:22 +08:00
Artiprocher
c6722b3f56 support qwen-image-layered 2025-12-19 19:06:37 +08:00
Zhongjie Duan
11315d7a40 Merge pull request #1147 from modelscope/qwen-image-edit-2511
Qwen image edit 2511
2025-12-18 19:23:44 +08:00
Artiprocher
68d97a9844 update doc 2025-12-18 19:22:22 +08:00
Artiprocher
4629d4cf9e support qwen-image-edit-2511 2025-12-18 19:16:52 +08:00
Zhongjie Duan
3cb5cec906 Merge pull request #1143 from modelscope/readme-update
update README
2025-12-17 16:32:29 +08:00
Artiprocher
b7e16b9034 update README 2025-12-17 16:30:41 +08:00
Zhongjie Duan
83d1e7361f Merge pull request #1136 from modelscope/bugfix-device
bugfix
2025-12-16 16:12:05 +08:00
Artiprocher
1547c3f786 bugfix 2025-12-16 16:09:29 +08:00
Zhongjie Duan
bfaaf12bf4 Merge pull request #1129 from modelscope/ascend
Support Ascend NPU
2025-12-15 19:13:40 +08:00
Zhongjie Duan
47545e1aab Merge pull request #1126 from Leoooo333/main
Fixed: Wan S2V Long video severe quality downgrade
2025-12-15 19:09:39 +08:00
Artiprocher
7c6905a432 support ascend npu 2025-12-15 15:50:12 +08:00
Artiprocher
2883bc1b76 support ascend npu 2025-12-15 15:48:42 +08:00
Zhongjie Duan
78d8842ddf Merge pull request #1128 from modelscope/amd_install
update installation instructions for AMD
2025-12-15 14:35:50 +08:00
Artiprocher
5821a664a0 update AMD GPU support 2025-12-15 14:30:13 +08:00
Zhongjie Duan
ab9aa1a087 Merge pull request #1124 from lzws/main
add wan usp example
2025-12-15 12:57:58 +08:00
Junming Chen
a4d34d9f3d Append: set video compress quality as original version. 2025-12-14 20:53:26 +00:00
Junming Chen
127cc9007a Fixed: S2V Long video severe quality downgrade 2025-12-14 20:30:34 +00:00
lzws
e1f5db5f5c add wan usp example 2025-12-12 20:24:27 +08:00
Zhongjie Duan
e316fb717f Merge pull request #1122 from modelscope/flux-lora-revert
revert FluxLoRAConverter due to dependency issues
2025-12-12 17:19:48 +08:00
Artiprocher
64c5139502 revert FluxLoRAConverter due to dependency issues 2025-12-12 17:19:13 +08:00
Mahdi-CV
5da9611a74 Update README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-11 09:57:15 -08:00
Mahdi-CV
733750d01b Update README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-11 09:57:06 -08:00
Mahdi-CV
edc95359d0 Update README_zh.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-11 09:56:48 -08:00
lzws
f2d0241e26 Update Z-Image.md 2025-12-11 16:43:38 +08:00
lzws
7b5d7f4af5 Update Z-Image.md 2025-12-11 16:41:46 +08:00
Mahdi Ghodsi
1fa9a6c60c updated README both Eng and Ch to reflect the AMD installation 2025-12-10 16:14:56 -08:00
Mahdi Ghodsi
51efa128d3 adding amd requirements file 2025-12-10 14:40:38 -08:00
Zhongjie Duan
421c6a5fce Merge pull request #1109 from modelscope/bugfix1
fix typo
2025-12-09 23:30:15 +09:00
Artiprocher
864080d8f2 fix typo 2025-12-09 22:29:50 +08:00
Zhongjie Duan
ba372dd295 Merge pull request #1108 from modelscope/i2L
Qwen-Image-i2L (Image to LoRA)
2025-12-09 23:10:02 +09:00
Artiprocher
1ceb02f673 update README 2025-12-09 22:08:47 +08:00
Artiprocher
30f93161fb support i2L 2025-12-09 22:07:35 +08:00
Zhongjie Duan
3ee3cc3104 Merge pull request #1093 from modelscope/diffsynth-2.0-patch
DiffSynth-Studio 2.0 major update
2025-12-04 16:38:31 +08:00
root
c2218f5c73 DiffSynth-Studio 2.0 major update 2025-12-04 16:34:24 +08:00
root
72af7122b3 DiffSynth-Studio 2.0 major update 2025-12-04 16:33:07 +08:00
Zhongjie Duan
afd101f345 Merge pull request #1058 from modelscope/download
support downloading resource
2025-11-18 10:30:16 +08:00
Artiprocher
1313f4dd63 support downloading resource 2025-11-18 10:29:07 +08:00
Zhongjie Duan
8332ecebb7 Merge pull request #1034 from modelscope/video_as_prompt
Video as prompt
2025-11-04 17:32:50 +08:00
Zhongjie Duan
401d7d74a5 Merge pull request #1025 from krahets/patch-1
Fix sinusoidal_embedding calculation for bf16 precision.
2025-11-04 15:08:11 +08:00
Yudong Jin
b8d7d55568 Fix dtype issue in time embedding calculation 2025-11-01 03:11:03 +08:00
59 changed files with 2953 additions and 244 deletions

View File

@@ -33,6 +33,8 @@ We believe that a well-developed open-source code framework can lower the thresh
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
- **December 9, 2025** We release a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in terms of generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research. For more details, please refer to our [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l).
- **December 4, 2025** DiffSynth-Studio 2.0 released! Many new features online
- [Documentation](/docs/en/README.md) online: Our documentation is still continuously being optimized and updated
- [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md) module upgraded, supporting layer-level disk offload, releasing both memory and VRAM simultaneously
@@ -187,21 +189,7 @@ cd DiffSynth-Studio
pip install -e .
```
<details>
<summary>Other installation methods</summary>
Install from PyPI (version updates may be delayed; for latest features, install from source)
```
pip install diffsynth
```
If you meet problems during installation, they might be caused by upstream dependencies. Please check the docs of these packages:
* [torch](https://pytorch.org/get-started/locally/)
* [sentencepiece](https://github.com/google/sentencepiece)
* [cmake](https://cmake.org)
* [cupy](https://docs.cupy.dev/en/stable/install.html)
For more installation methods and instructions for non-NVIDIA GPUs, please refer to the [Installation Guide](/docs/en/Pipeline_Usage/Setup.md).
</details>
@@ -408,8 +396,11 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
@@ -420,6 +411,7 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
</details>

View File

@@ -33,6 +33,8 @@ DiffSynth 目前包括两个开源项目:
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责因此新功能的开发进展会比较缓慢issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
- **2025年12月9日** 我们基于 DiffSynth-Studio 2.0 训练了一个疯狂的模型:[Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)Image to LoRA。这一模型以图像为输入以 LoRA 为输出。尽管这个版本的模型在泛化能力、细节保持能力等方面还有很大改进空间,我们将这些模型开源,以启发更多创新性的研究工作。更多细节,请参考我们的 [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l)。
- **2025年12月4日** DiffSynth-Studio 2.0 发布!众多新功能上线
- [文档](/docs/zh/README.md)上线:我们的文档还在持续优化更新中
- [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)模块升级,支持 Layer 级别的 Disk Offload同时释放内存与显存
@@ -187,21 +189,7 @@ cd DiffSynth-Studio
pip install -e .
```
<details>
<summary>其他安装方式</summary>
从 pypi 安装(存在版本更新延迟,如需使用最新功能,请从源码安装)
```
pip install diffsynth
```
如果在安装过程中遇到问题,可能是由上游依赖包导致的,请参考这些包的文档:
* [torch](https://pytorch.org/get-started/locally/)
* [sentencepiece](https://github.com/google/sentencepiece)
* [cmake](https://cmake.org)
* [cupy](https://docs.cupy.dev/en/stable/install.html)
更多安装方式,以及非 NVIDIA GPU 的安装,请参考[安装文档](/docs/zh/Pipeline_Usage/Setup.md)。
</details>
@@ -408,8 +396,11 @@ Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
@@ -420,6 +411,7 @@ Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
</details>

View File

@@ -31,6 +31,52 @@ qwen_image_series = [
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
"extra_kwargs": {"additional_in_dim": 4},
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
"model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8",
"model_name": "siglip2_image_encoder",
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors")
"model_hash": "5722b5c873720009de96422993b15682",
"model_name": "dinov3_image_encoder",
"model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
},
{
# Example:
"model_hash": "a166c33455cdbd89c0888a3645ca5c0f",
"model_name": "qwen_image_image2lora_coarse",
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
},
{
# Example:
"model_hash": "a5476e691767a4da6d3a6634a10f7408",
"model_name": "qwen_image_image2lora_fine",
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
"extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64}
},
{
# Example:
"model_hash": "0aad514690602ecaff932c701cb4b0bb",
"model_name": "qwen_image_image2lora_style",
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
"extra_kwargs": {"compress_dim": 64, "use_residual": False}
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "8dc8cda05de16c73afa755e2c1ce2839",
"model_name": "qwen_image_dit",
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
"extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True}
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "44b39ddc499e027cfb24f7878d7416b9",
"model_name": "qwen_image_vae",
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
"extra_kwargs": {"image_channels": 4}
},
]
wan_series = [
@@ -481,6 +527,19 @@ z_image_series = [
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
"extra_kwargs": {"use_conv_attention": False},
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors")
"model_hash": "aa3563718e5c3ecde3dfbb020ca61180",
"model_name": "z_image_dit",
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
"extra_kwargs": {"siglip_feat_dim": 1152},
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors")
"model_hash": "89d48e420f45cff95115a9f3e698d44a",
"model_name": "siglip_vision_model_428m",
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
},
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series

View File

@@ -13,6 +13,7 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
@@ -32,6 +33,25 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": {
"transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": {
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": {
"diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule",

View File

@@ -3,3 +3,4 @@ from .data import *
from .gradient import *
from .loader import *
from .vram import *
from .device import *

View File

@@ -53,12 +53,14 @@ class ToStr(DataProcessingOperator):
class LoadImage(DataProcessingOperator):
def __init__(self, convert_RGB=True):
def __init__(self, convert_RGB=True, convert_RGBA=False):
self.convert_RGB = convert_RGB
self.convert_RGBA = convert_RGBA
def __call__(self, data: str):
image = Image.open(data)
if self.convert_RGB: image = image.convert("RGB")
if self.convert_RGBA: image = image.convert("RGBA")
return image

View File

@@ -0,0 +1 @@
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type

View File

@@ -0,0 +1,107 @@
import importlib
import torch
from typing import Any
def is_torch_npu_available():
return importlib.util.find_spec("torch_npu") is not None
IS_CUDA_AVAILABLE = torch.cuda.is_available()
IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available()
if IS_NPU_AVAILABLE:
import torch_npu
torch.npu.config.allow_internal_format = False
def get_device_type() -> str:
"""Get device type based on current machine, currently only support CPU, CUDA, NPU."""
if IS_CUDA_AVAILABLE:
device = "cuda"
elif IS_NPU_AVAILABLE:
device = "npu"
else:
device = "cpu"
return device
def get_torch_device() -> Any:
"""Get torch attribute based on device type, e.g. torch.cuda or torch.npu"""
device_name = get_device_type()
try:
return getattr(torch, device_name)
except AttributeError:
print(f"Device namespace '{device_name}' not found in torch, try to load 'torch.cuda'.")
return torch.cuda
def get_device_id() -> int:
"""Get current device id based on device type."""
return get_torch_device().current_device()
def get_device_name() -> str:
"""Get current device name based on device type."""
return f"{get_device_type()}:{get_device_id()}"
def synchronize() -> None:
"""Execute torch synchronize operation."""
get_torch_device().synchronize()
def empty_cache() -> None:
"""Execute torch empty cache operation."""
get_torch_device().empty_cache()
def get_nccl_backend() -> str:
"""Return distributed communication backend type based on device type."""
if IS_CUDA_AVAILABLE:
return "nccl"
elif IS_NPU_AVAILABLE:
return "hccl"
else:
raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")
def enable_high_precision_for_bf16():
"""
Set high accumulation dtype for matmul and reduction.
"""
if IS_CUDA_AVAILABLE:
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
if IS_NPU_AVAILABLE:
torch.npu.matmul.allow_tf32 = False
torch.npu.matmul.allow_bf16_reduced_precision_reduction = False
def parse_device_type(device):
if isinstance(device, str):
if device.startswith("cuda"):
return "cuda"
elif device.startswith("npu"):
return "npu"
else:
return "cpu"
elif isinstance(device, torch.device):
return device.type
def parse_nccl_backend(device_type):
if device_type == "cuda":
return "nccl"
elif device_type == "npu":
return "hccl"
else:
raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.")
def get_available_device_type():
return get_device_type()

View File

@@ -2,6 +2,7 @@ import torch, copy
from typing import Union
from .initialization import skip_model_initialization
from .disk_map import DiskMap
from ..device import parse_device_type
class AutoTorchModule(torch.nn.Module):
@@ -32,6 +33,7 @@ class AutoTorchModule(torch.nn.Module):
)
self.state = 0
self.name = ""
self.computation_device_type = parse_device_type(self.computation_device)
def set_dtype_and_device(
self,
@@ -61,7 +63,8 @@ class AutoTorchModule(torch.nn.Module):
return r
def check_free_vram(self):
gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
device = self.computation_device if self.computation_device != "npu" else "npu:0"
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
return used_memory < self.vram_limit
@@ -307,6 +310,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
self.lora_B_weights = []
self.lora_merger = None
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
self.computation_device_type = parse_device_type(self.computation_device)
if offload_dtype == "disk":
self.disk_map = disk_map

View File

@@ -3,7 +3,7 @@ import torch
import numpy as np
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput
@@ -68,6 +68,7 @@ class BasePipeline(torch.nn.Module):
# The device and torch_dtype is used for the storage of intermediate variables, not models.
self.device = device
self.torch_dtype = torch_dtype
self.device_type = parse_device_type(device)
# The following parameters are used for shape check.
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
@@ -154,7 +155,7 @@ class BasePipeline(torch.nn.Module):
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
torch.cuda.empty_cache()
getattr(torch, self.device_type).empty_cache()
# onload models
for name, model in self.named_children():
if name in model_names:
@@ -176,7 +177,8 @@ class BasePipeline(torch.nn.Module):
def get_vram(self):
return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
device = self.device if self.device != "npu" else "npu:0"
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
def get_module(self, model, name):
if "." in name:

View File

@@ -0,0 +1,94 @@
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch
class DINOv3ImageEncoder(DINOv3ViTModel):
def __init__(self):
config = DINOv3ViTConfig(
architectures = [
"DINOv3ViTModel"
],
attention_dropout = 0.0,
drop_path_rate = 0.0,
dtype = "float32",
hidden_act = "silu",
hidden_size = 4096,
image_size = 224,
initializer_range = 0.02,
intermediate_size = 8192,
key_bias = False,
layer_norm_eps = 1e-05,
layerscale_value = 1.0,
mlp_bias = True,
model_type = "dinov3_vit",
num_attention_heads = 32,
num_channels = 3,
num_hidden_layers = 40,
num_register_tokens = 4,
patch_size = 16,
pos_embed_jitter = None,
pos_embed_rescale = 2.0,
pos_embed_shift = None,
proj_bias = True,
query_bias = False,
rope_theta = 100.0,
transformers_version = "4.56.1",
use_gated_mlp = True,
value_bias = False
)
super().__init__(config)
self.processor = DINOv3ViTImageProcessorFast(
crop_size = None,
data_format = "channels_first",
default_to_square = True,
device = None,
disable_grouping = None,
do_center_crop = None,
do_convert_rgb = None,
do_normalize = True,
do_rescale = True,
do_resize = True,
image_mean = [
0.485,
0.456,
0.406
],
image_processor_type = "DINOv3ViTImageProcessorFast",
image_std = [
0.229,
0.224,
0.225
],
input_data_format = None,
resample = 2,
rescale_factor = 0.00392156862745098,
return_tensors = None,
size = {
"height": 224,
"width": 224
}
)
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
inputs = self.processor(images=image, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
bool_masked_pos = None
head_mask = None
pixel_values = pixel_values.to(torch_dtype)
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
position_embeddings = self.rope_embeddings(pixel_values)
for i, layer_module in enumerate(self.layer):
layer_head_mask = head_mask[i] if head_mask is not None else None
hidden_states = layer_module(
hidden_states,
attention_mask=layer_head_mask,
position_embeddings=position_embeddings,
)
sequence_output = self.norm(hidden_states)
pooled_output = sequence_output[:, 0, :]
return pooled_output

View File

@@ -19,7 +19,7 @@ def get_timestep_embedding(
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent).to(timesteps.device)
emb = torch.exp(exponent)
if align_dtype_to_timestep:
emb = emb.to(timesteps.dtype)
emb = timesteps[:, None].float() * emb[None, :]
@@ -78,7 +78,7 @@ class DiffusersCompatibleTimestepProj(torch.nn.Module):
class TimestepEmbeddings(torch.nn.Module):
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False):
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):
super().__init__()
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
if diffusers_compatible_format:
@@ -87,10 +87,17 @@ class TimestepEmbeddings(torch.nn.Module):
self.timestep_embedder = torch.nn.Sequential(
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
)
self.use_additional_t_cond = use_additional_t_cond
if use_additional_t_cond:
self.addition_t_embedding = torch.nn.Embedding(2, dim_out)
def forward(self, timestep, dtype):
def forward(self, timestep, dtype, addition_t_cond=None):
time_emb = self.time_proj(timestep).to(dtype)
time_emb = self.timestep_embedder(time_emb)
if addition_t_cond is not None:
addition_t_emb = self.addition_t_embedding(addition_t_cond)
addition_t_emb = addition_t_emb.to(dtype=dtype)
time_emb = time_emb + addition_t_emb
return time_emb

View File

@@ -1,4 +1,4 @@
import torch, math
import torch, math, functools
import torch.nn as nn
from typing import Tuple, Optional, Union, List
from einops import rearrange
@@ -225,6 +225,121 @@ class QwenEmbedRope(nn.Module):
return vid_freqs, txt_freqs
class QwenEmbedLayer3DRope(nn.Module):
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
self.pos_freqs = torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.neg_freqs = torch.cat(
[
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.scale_rope = scale_rope
def rope_params(self, index, dim, theta=10000):
"""
Args:
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
"""
assert dim % 2 == 0
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def forward(self, video_fhw, txt_seq_lens, device):
"""
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
txt_length: [bs] a list of 1 integers representing the length of the text
"""
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
video_fhw = [video_fhw]
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
if not isinstance(video_fhw, list):
video_fhw = [video_fhw]
vid_freqs = []
max_vid_index = 0
layer_num = len(video_fhw) - 1
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
if idx != layer_num:
video_freq = self._compute_video_freqs(frame, height, width, idx)
else:
### For the condition image, we set the layer index to -1
video_freq = self._compute_condition_freqs(frame, height, width)
video_freq = video_freq.to(device)
vid_freqs.append(video_freq)
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
max_vid_index = max(max_vid_index, layer_num)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
@functools.lru_cache(maxsize=None)
def _compute_condition_freqs(self, frame, height, width):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
class QwenFeedForward(nn.Module):
def __init__(
self,
@@ -352,9 +467,38 @@ class QwenImageTransformerBlock(nn.Module):
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)
def _modulate(self, x, mod_params):
def _modulate(self, x, mod_params, index=None):
shift, scale, gate = mod_params.chunk(3, dim=-1)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
if index is not None:
# Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
# So shift, scale, gate have shape [2*actual_batch, d]
actual_batch = shift.size(0) // 2
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
# index: [b, l] where b is actual batch size
# Expand to [b, l, 1] to match feature dimension
index_expanded = index.unsqueeze(-1) # [b, l, 1]
# Expand chunks to [b, 1, d] then broadcast to [b, l, d]
shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
scale_0_exp = scale_0.unsqueeze(1)
scale_1_exp = scale_1.unsqueeze(1)
gate_0_exp = gate_0.unsqueeze(1)
gate_1_exp = gate_1.unsqueeze(1)
# Use torch.where to select based on index
shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
else:
shift_result = shift.unsqueeze(1)
scale_result = scale.unsqueeze(1)
gate_result = gate.unsqueeze(1)
return x * (1 + scale_result) + shift_result, gate_result
def forward(
self,
@@ -364,13 +508,16 @@ class QwenImageTransformerBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
enable_fp8_attention = False,
modulate_index: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
if modulate_index is not None:
temb = torch.chunk(temb, 2, dim=0)[0]
txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
img_normed = self.img_norm1(image)
img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)
img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, index=modulate_index)
txt_normed = self.txt_norm1(text)
txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
@@ -387,7 +534,7 @@ class QwenImageTransformerBlock(nn.Module):
text = text + txt_gate * txt_attn_out
img_normed_2 = self.img_norm2(image)
img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)
img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, index=modulate_index)
txt_normed_2 = self.txt_norm2(text)
txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
@@ -405,12 +552,17 @@ class QwenImageDiT(torch.nn.Module):
def __init__(
self,
num_layers: int = 60,
use_layer3d_rope: bool = False,
use_additional_t_cond: bool = False,
):
super().__init__()
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
if not use_layer3d_rope:
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
else:
self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=False, use_additional_t_cond=use_additional_t_cond)
self.txt_norm = RMSNorm(3584, eps=1e-6)
self.img_in = nn.Linear(64, 3072)

View File

@@ -0,0 +1,128 @@
import torch
class CompressedMLP(torch.nn.Module):
def __init__(self, in_dim, mid_dim, out_dim, bias=False):
super().__init__()
self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)
self.proj_out = torch.nn.Linear(mid_dim, out_dim, bias=bias)
def forward(self, x, residual=None):
x = self.proj_in(x)
if residual is not None: x = x + residual
x = self.proj_out(x)
return x
class ImageEmbeddingToLoraMatrix(torch.nn.Module):
def __init__(self, in_dim, compress_dim, lora_a_dim, lora_b_dim, rank):
super().__init__()
self.proj_a = CompressedMLP(in_dim, compress_dim, lora_a_dim * rank)
self.proj_b = CompressedMLP(in_dim, compress_dim, lora_b_dim * rank)
self.lora_a_dim = lora_a_dim
self.lora_b_dim = lora_b_dim
self.rank = rank
def forward(self, x, residual=None):
lora_a = self.proj_a(x, residual).view(self.rank, self.lora_a_dim)
lora_b = self.proj_b(x, residual).view(self.lora_b_dim, self.rank)
return lora_a, lora_b
class SequencialMLP(torch.nn.Module):
def __init__(self, length, in_dim, mid_dim, out_dim, bias=False):
super().__init__()
self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)
self.proj_out = torch.nn.Linear(length * mid_dim, out_dim, bias=bias)
self.length = length
self.in_dim = in_dim
self.mid_dim = mid_dim
def forward(self, x):
x = x.view(self.length, self.in_dim)
x = self.proj_in(x)
x = x.view(1, self.length * self.mid_dim)
x = self.proj_out(x)
return x
class LoRATrainerBlock(torch.nn.Module):
def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024):
super().__init__()
self.lora_patterns = lora_patterns
self.block_id = block_id
self.layers = []
for name, lora_a_dim, lora_b_dim in self.lora_patterns:
self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank))
self.layers = torch.nn.ModuleList(self.layers)
if use_residual:
self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)
else:
self.proj_residual = None
def forward(self, x, residual=None):
lora = {}
if self.proj_residual is not None: residual = self.proj_residual(residual)
for lora_pattern, layer in zip(self.lora_patterns, self.layers):
name = lora_pattern[0]
lora_a, lora_b = layer(x, residual=residual)
lora[f"transformer_blocks.{self.block_id}.{name}.lora_A.default.weight"] = lora_a
lora[f"transformer_blocks.{self.block_id}.{name}.lora_B.default.weight"] = lora_b
return lora
class QwenImageImage2LoRAModel(torch.nn.Module):
def __init__(self, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):
super().__init__()
self.lora_patterns = [
[
("attn.to_q", 3072, 3072),
("attn.to_k", 3072, 3072),
("attn.to_v", 3072, 3072),
("attn.to_out.0", 3072, 3072),
],
[
("img_mlp.net.2", 3072*4, 3072),
("img_mod.1", 3072, 3072*6),
],
[
("attn.add_q_proj", 3072, 3072),
("attn.add_k_proj", 3072, 3072),
("attn.add_v_proj", 3072, 3072),
("attn.to_add_out", 3072, 3072),
],
[
("txt_mlp.net.2", 3072*4, 3072),
("txt_mod.1", 3072, 3072*6),
],
]
self.num_blocks = num_blocks
self.blocks = []
for lora_patterns in self.lora_patterns:
for block_id in range(self.num_blocks):
self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim))
self.blocks = torch.nn.ModuleList(self.blocks)
self.residual_scale = 0.05
self.use_residual = use_residual
def forward(self, x, residual=None):
if residual is not None:
if self.use_residual:
residual = residual * self.residual_scale
else:
residual = None
lora = {}
for block in self.blocks:
lora.update(block(x, residual))
return lora
def initialize_weights(self):
state_dict = self.state_dict()
for name in state_dict:
if ".proj_a." in name:
state_dict[name] = state_dict[name] * 0.3
elif ".proj_b.proj_out." in name:
state_dict[name] = state_dict[name] * 0
elif ".proj_residual.proj_out." in name:
state_dict[name] = state_dict[name] * 0.3
self.load_state_dict(state_dict)

View File

@@ -366,6 +366,7 @@ class QwenImageEncoder3d(nn.Module):
temperal_downsample=[True, True, False],
dropout=0.0,
non_linearity: str = "silu",
image_channels=3
):
super().__init__()
self.dim = dim
@@ -381,7 +382,7 @@ class QwenImageEncoder3d(nn.Module):
scale = 1.0
# init block
self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
self.conv_in = QwenImageCausalConv3d(image_channels, dims[0], 3, padding=1)
# downsample blocks
self.down_blocks = torch.nn.ModuleList([])
@@ -544,6 +545,7 @@ class QwenImageDecoder3d(nn.Module):
temperal_upsample=[False, True, True],
dropout=0.0,
non_linearity: str = "silu",
image_channels=3,
):
super().__init__()
self.dim = dim
@@ -594,7 +596,7 @@ class QwenImageDecoder3d(nn.Module):
# output blocks
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
self.conv_out = QwenImageCausalConv3d(out_dim, image_channels, 3, padding=1)
self.gradient_checkpointing = False
@@ -647,6 +649,7 @@ class QwenImageVAE(torch.nn.Module):
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
dropout: float = 0.0,
image_channels: int = 3,
) -> None:
super().__init__()
@@ -655,13 +658,13 @@ class QwenImageVAE(torch.nn.Module):
self.temperal_upsample = temperal_downsample[::-1]
self.encoder = QwenImageEncoder3d(
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, image_channels=image_channels,
)
self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
self.decoder = QwenImageDecoder3d(
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, image_channels=image_channels,
)
mean = [

View File

@@ -0,0 +1,135 @@
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
import torch
class Siglip2ImageEncoder(SiglipVisionTransformer):
def __init__(self):
config = SiglipVisionConfig(
attention_dropout = 0.0,
dtype = "float32",
hidden_act = "gelu_pytorch_tanh",
hidden_size = 1536,
image_size = 384,
intermediate_size = 6144,
layer_norm_eps = 1e-06,
model_type = "siglip_vision_model",
num_attention_heads = 16,
num_channels = 3,
num_hidden_layers = 40,
patch_size = 16,
transformers_version = "4.56.1",
_attn_implementation = "sdpa"
)
super().__init__(config)
self.processor = SiglipImageProcessor(
do_convert_rgb = None,
do_normalize = True,
do_rescale = True,
do_resize = True,
image_mean = [
0.5,
0.5,
0.5
],
image_processor_type = "SiglipImageProcessor",
image_std = [
0.5,
0.5,
0.5
],
processor_class = "SiglipProcessor",
resample = 2,
rescale_factor = 0.00392156862745098,
size = {
"height": 384,
"width": 384
}
)
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
output_attentions = False
output_hidden_states = False
interpolate_pos_encoding = False
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
last_hidden_state = encoder_outputs.last_hidden_state
last_hidden_state = self.post_layernorm(last_hidden_state)
pooler_output = self.head(last_hidden_state) if self.use_head else None
return pooler_output
class Siglip2ImageEncoder428M(Siglip2VisionModel):
def __init__(self):
config = Siglip2VisionConfig(
attention_dropout = 0.0,
dtype = "bfloat16",
hidden_act = "gelu_pytorch_tanh",
hidden_size = 1152,
intermediate_size = 4304,
layer_norm_eps = 1e-06,
model_type = "siglip2_vision_model",
num_attention_heads = 16,
num_channels = 3,
num_hidden_layers = 27,
num_patches = 256,
patch_size = 16,
transformers_version = "4.57.1"
)
super().__init__(config)
self.processor = Siglip2ImageProcessorFast(
**{
"crop_size": None,
"data_format": "channels_first",
"default_to_square": True,
"device": None,
"disable_grouping": None,
"do_center_crop": None,
"do_convert_rgb": None,
"do_normalize": True,
"do_pad": None,
"do_rescale": True,
"do_resize": True,
"image_mean": [
0.5,
0.5,
0.5
],
"image_processor_type": "Siglip2ImageProcessorFast",
"image_std": [
0.5,
0.5,
0.5
],
"input_data_format": None,
"max_num_patches": 256,
"pad_size": None,
"patch_size": 16,
"processor_class": "Siglip2Processor",
"resample": 2,
"rescale_factor": 0.00392156862745098,
"return_tensors": None,
"size": None
}
)
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
siglip_inputs = self.processor(images=[image], return_tensors="pt").to(device)
shape = siglip_inputs.spatial_shapes[0]
hidden_state = super().forward(**siglip_inputs).last_hidden_state
B, N, C = hidden_state.shape
hidden_state = hidden_state[:, : shape[0] * shape[1]]
hidden_state = hidden_state.view(shape[0], shape[1], C)
hidden_state = hidden_state.to(torch_dtype)
return hidden_state

View File

@@ -13,6 +13,7 @@ from ..core.gradient import gradient_checkpoint_forward
ADALN_EMBED_DIM = 256
SEQ_MULTI_OF = 32
X_PAD_DIM = 64
class TimestepEmbedder(nn.Module):
@@ -86,7 +87,7 @@ class Attention(torch.nn.Module):
self.norm_q = RMSNorm(head_dim, eps=1e-5)
self.norm_k = RMSNorm(head_dim, eps=1e-5)
def forward(self, hidden_states, freqs_cis):
def forward(self, hidden_states, freqs_cis, attention_mask):
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
@@ -123,6 +124,7 @@ class Attention(torch.nn.Module):
key,
value,
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
attn_mask=attention_mask,
)
# Reshape back
@@ -136,6 +138,20 @@ class Attention(torch.nn.Module):
return output
def select_per_token(
value_noisy: torch.Tensor,
value_clean: torch.Tensor,
noise_mask: torch.Tensor,
seq_len: int,
) -> torch.Tensor:
noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1)
return torch.where(
noise_mask_expanded == 1,
value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
value_clean.unsqueeze(1).expand(-1, seq_len, -1),
)
class ZImageTransformerBlock(nn.Module):
def __init__(
self,
@@ -180,40 +196,53 @@ class ZImageTransformerBlock(nn.Module):
attn_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor] = None,
noise_mask: Optional[torch.Tensor] = None,
adaln_noisy: Optional[torch.Tensor] = None,
adaln_clean: Optional[torch.Tensor] = None,
):
if self.modulation:
assert adaln_input is not None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
seq_len = x.shape[1]
if noise_mask is not None:
# Per-token modulation: different modulation for noisy/clean tokens
mod_noisy = self.adaLN_modulation(adaln_noisy)
mod_clean = self.adaLN_modulation(adaln_clean)
scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
else:
# Global modulation: same modulation for all tokens (avoid double select)
mod = self.adaLN_modulation(adaln_input)
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
# Attention block
attn_out = self.attention(
self.attention_norm1(x) * scale_msa,
freqs_cis=freqs_cis,
self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
)
x = x + gate_msa * self.attention_norm2(attn_out)
# FFN block
x = x + gate_mlp * self.ffn_norm2(
self.feed_forward(
self.ffn_norm1(x) * scale_mlp,
)
)
x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
else:
# Attention block
attn_out = self.attention(
self.attention_norm1(x),
freqs_cis=freqs_cis,
)
attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
x = x + self.attention_norm2(attn_out)
# FFN block
x = x + self.ffn_norm2(
self.feed_forward(
self.ffn_norm1(x),
)
)
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
return x
@@ -229,9 +258,21 @@ class FinalLayer(nn.Module):
nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
)
def forward(self, x, c):
scale = 1.0 + self.adaLN_modulation(c)
x = self.norm_final(x) * scale.unsqueeze(1)
def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
seq_len = x.shape[1]
if noise_mask is not None:
# Per-token modulation
scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)
scale_clean = 1.0 + self.adaLN_modulation(c_clean)
scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len)
else:
# Original global modulation
assert c is not None, "Either c or (c_noisy, c_clean) must be provided"
scale = 1.0 + self.adaLN_modulation(c)
scale = scale.unsqueeze(1)
x = self.norm_final(x) * scale
x = self.linear(x)
return x
@@ -299,6 +340,7 @@ class ZImageDiT(nn.Module):
t_scale=1000.0,
axes_dims=[32, 48, 48],
axes_lens=[1024, 512, 512],
siglip_feat_dim=None,
) -> None:
super().__init__()
self.in_channels = in_channels
@@ -359,6 +401,32 @@ class ZImageDiT(nn.Module):
nn.Linear(cap_feat_dim, dim, bias=True),
)
# Optional SigLIP components (for Omni variant)
self.siglip_feat_dim = siglip_feat_dim
if siglip_feat_dim is not None:
self.siglip_embedder = nn.Sequential(
RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True)
)
self.siglip_refiner = nn.ModuleList(
[
ZImageTransformerBlock(
2000 + layer_id,
dim,
n_heads,
n_kv_heads,
norm_eps,
qk_norm,
modulation=False,
)
for layer_id in range(n_refiner_layers)
]
)
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim)))
else:
self.siglip_embedder = None
self.siglip_refiner = None
self.siglip_pad_token = None
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
@@ -375,22 +443,57 @@ class ZImageDiT(nn.Module):
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
def unpatchify(
self,
x: List[torch.Tensor],
size: List[Tuple],
patch_size = 2,
f_patch_size = 1,
x_pos_offsets: Optional[List[Tuple[int, int]]] = None,
) -> List[torch.Tensor]:
pH = pW = patch_size
pF = f_patch_size
bsz = len(x)
assert len(size) == bsz
for i in range(bsz):
F, H, W = size[i]
ori_len = (F // pF) * (H // pH) * (W // pW)
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
x[i] = (
x[i][:ori_len]
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
.permute(6, 0, 3, 1, 4, 2, 5)
.reshape(self.out_channels, F, H, W)
)
return x
if x_pos_offsets is not None:
# Omni: extract target image from unified sequence (cond_images + target)
result = []
for i in range(bsz):
unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]]
cu_len = 0
x_item = None
for j in range(len(size[i])):
if size[i][j] is None:
ori_len = 0
pad_len = SEQ_MULTI_OF
cu_len += pad_len + ori_len
else:
F, H, W = size[i][j]
ori_len = (F // pF) * (H // pH) * (W // pW)
pad_len = (-ori_len) % SEQ_MULTI_OF
x_item = (
unified_x[cu_len : cu_len + ori_len]
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
.permute(6, 0, 3, 1, 4, 2, 5)
.reshape(self.out_channels, F, H, W)
)
cu_len += ori_len + pad_len
result.append(x_item) # Return only the last (target) image
return result
else:
# Original mode: simple unpatchify
for i in range(bsz):
F, H, W = size[i]
ori_len = (F // pF) * (H // pH) * (W // pW)
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
x[i] = (
x[i][:ori_len]
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
.permute(6, 0, 3, 1, 4, 2, 5)
.reshape(self.out_channels, F, H, W)
)
return x
@staticmethod
def create_coordinate_grid(size, start=None, device=None):
@@ -405,8 +508,8 @@ class ZImageDiT(nn.Module):
self,
all_image: List[torch.Tensor],
all_cap_feats: List[torch.Tensor],
patch_size: int,
f_patch_size: int,
patch_size: int = 2,
f_patch_size: int = 1,
):
pH = pW = patch_size
pF = f_patch_size
@@ -490,90 +593,421 @@ class ZImageDiT(nn.Module):
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
all_image_out.append(image_padded_feat)
return all_image_out, all_cap_feats_out, {
"x_size": all_image_size,
"x_pos_ids": all_image_pos_ids,
"cap_pos_ids": all_cap_pos_ids,
"x_pad_mask": all_image_pad_mask,
"cap_pad_mask": all_cap_pad_mask
}
# (
# all_img_out,
# all_cap_out,
# all_img_size,
# all_img_pos_ids,
# all_cap_pos_ids,
# all_img_pad_mask,
# all_cap_pad_mask,
# )
def _prepare_sequence(
self,
feats: List[torch.Tensor],
pos_ids: List[torch.Tensor],
inner_pad_mask: List[torch.Tensor],
pad_token: torch.nn.Parameter,
noise_mask: Optional[List[List[int]]] = None,
device: torch.device = None,
):
"""Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask."""
item_seqlens = [len(f) for f in feats]
max_seqlen = max(item_seqlens)
bsz = len(feats)
# Pad token
feats_cat = torch.cat(feats, dim=0)
feats_cat[torch.cat(inner_pad_mask)] = pad_token
feats = list(feats_cat.split(item_seqlens, dim=0))
# RoPE
freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0))
# Pad to batch
feats = pad_sequence(feats, batch_first=True, padding_value=0.0)
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
# Attention mask
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(item_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None
if noise_mask is not None:
noise_mask_tensor = pad_sequence(
[torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask],
batch_first=True,
padding_value=0,
)[:, : feats.shape[1]]
return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor
def _build_unified_sequence(
self,
x: torch.Tensor,
x_freqs: torch.Tensor,
x_seqlens: List[int],
x_noise_mask: Optional[List[List[int]]],
cap: torch.Tensor,
cap_freqs: torch.Tensor,
cap_seqlens: List[int],
cap_noise_mask: Optional[List[List[int]]],
siglip: Optional[torch.Tensor],
siglip_freqs: Optional[torch.Tensor],
siglip_seqlens: Optional[List[int]],
siglip_noise_mask: Optional[List[List[int]]],
omni_mode: bool,
device: torch.device,
):
"""Build unified sequence: x, cap, and optionally siglip.
Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip]
"""
bsz = len(x_seqlens)
unified = []
unified_freqs = []
unified_noise_mask = []
for i in range(bsz):
x_len, cap_len = x_seqlens[i], cap_seqlens[i]
if omni_mode:
# Omni: [cap, x, siglip]
if siglip is not None and siglip_seqlens is not None:
sig_len = siglip_seqlens[i]
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]]))
unified_freqs.append(
torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]])
)
unified_noise_mask.append(
torch.tensor(
cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device
)
)
else:
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]]))
unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]]))
unified_noise_mask.append(
torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device)
)
else:
# Basic: [x, cap]
unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]]))
unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]]))
# Compute unified seqlens
if omni_mode:
if siglip is not None and siglip_seqlens is not None:
unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)]
else:
unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)]
else:
unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)]
max_seqlen = max(unified_seqlens)
# Pad to batch
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
# Attention mask
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None
if omni_mode:
noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[
:, : unified.shape[1]
]
return unified, unified_freqs, attn_mask, noise_mask_tensor
def _pad_with_ids(
self,
feat: torch.Tensor,
pos_grid_size: Tuple,
pos_start: Tuple,
device: torch.device,
noise_mask_val: Optional[int] = None,
):
"""Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
ori_len = len(feat)
pad_len = (-ori_len) % SEQ_MULTI_OF
total_len = ori_len + pad_len
# Pos IDs
ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
if pad_len > 0:
pad_pos_ids = (
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
.flatten(0, 2)
.repeat(pad_len, 1)
)
pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
pad_mask = torch.cat(
[
torch.zeros(ori_len, dtype=torch.bool, device=device),
torch.ones(pad_len, dtype=torch.bool, device=device),
]
)
else:
pos_ids = ori_pos_ids
padded_feat = feat
pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
return padded_feat, pos_ids, pad_mask, total_len, noise_mask
def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
"""Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
pH, pW, pF = patch_size, patch_size, f_patch_size
C, F, H, W = image.size()
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
def patchify_and_embed_omni(
self,
all_x: List[List[torch.Tensor]],
all_cap_feats: List[List[torch.Tensor]],
all_siglip_feats: List[List[torch.Tensor]],
patch_size: int = 2,
f_patch_size: int = 1,
images_noise_mask: List[List[int]] = None,
):
"""Patchify for omni mode: multiple images per batch item with noise masks."""
bsz = len(all_x)
device = all_x[0][-1].device
dtype = all_x[0][-1].dtype
all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], []
all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], []
all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], []
for i in range(bsz):
num_images = len(all_x[i])
cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], []
cap_end_pos = []
cap_cu_len = 1
# Process captions
for j, cap_item in enumerate(all_cap_feats[i]):
noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1
cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids(
cap_item,
(len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1),
(cap_cu_len, 0, 0),
device,
noise_val,
)
cap_feats_list.append(cap_out)
cap_pos_list.append(cap_pos)
cap_mask_list.append(cap_mask)
cap_lens.append(cap_len)
cap_noise.extend(cap_nm)
cap_cu_len += len(cap_item)
cap_end_pos.append(cap_cu_len)
cap_cu_len += 2 # for image vae and siglip tokens
all_cap_out.append(torch.cat(cap_feats_list, dim=0))
all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0))
all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0))
all_cap_len.append(cap_lens)
all_cap_noise_mask.append(cap_noise)
# Process images
x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], []
for j, x_item in enumerate(all_x[i]):
noise_val = images_noise_mask[i][j]
if x_item is not None:
x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size)
x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids(
x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val
)
x_size.append(size)
else:
x_len = SEQ_MULTI_OF
x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device)
x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1)
x_mask = torch.ones(x_len, dtype=torch.bool, device=device)
x_nm = [noise_val] * x_len
x_size.append(None)
x_feats_list.append(x_out)
x_pos_list.append(x_pos)
x_mask_list.append(x_mask)
x_lens.append(x_len)
x_noise.extend(x_nm)
all_x_out.append(torch.cat(x_feats_list, dim=0))
all_x_pos_ids.append(torch.cat(x_pos_list, dim=0))
all_x_pad_mask.append(torch.cat(x_mask_list, dim=0))
all_x_size.append(x_size)
all_x_len.append(x_lens)
all_x_noise_mask.append(x_noise)
# Process siglip
if all_siglip_feats[i] is None:
all_sig_len.append([0] * num_images)
all_sig_out.append(None)
else:
sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], []
for j, sig_item in enumerate(all_siglip_feats[i]):
noise_val = images_noise_mask[i][j]
if sig_item is not None:
sig_H, sig_W, sig_C = sig_item.size()
sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C)
sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids(
sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val
)
# Scale position IDs to match x resolution
if x_size[j] is not None:
sig_pos = sig_pos.float()
sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1)
sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1)
sig_pos = sig_pos.to(torch.int32)
else:
sig_len = SEQ_MULTI_OF
sig_out = torch.zeros((sig_len, self.siglip_feat_dim), dtype=dtype, device=device)
sig_pos = (
self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1)
)
sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device)
sig_nm = [noise_val] * sig_len
sig_feats_list.append(sig_out)
sig_pos_list.append(sig_pos)
sig_mask_list.append(sig_mask)
sig_lens.append(sig_len)
sig_noise.extend(sig_nm)
all_sig_out.append(torch.cat(sig_feats_list, dim=0))
all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0))
all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0))
all_sig_len.append(sig_lens)
all_sig_noise_mask.append(sig_noise)
# Compute x position offsets
all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)]
return (
all_image_out,
all_cap_feats_out,
all_image_size,
all_image_pos_ids,
all_x_out,
all_cap_out,
all_sig_out,
all_x_size,
all_x_pos_ids,
all_cap_pos_ids,
all_image_pad_mask,
all_sig_pos_ids,
all_x_pad_mask,
all_cap_pad_mask,
all_sig_pad_mask,
all_x_pos_offsets,
all_x_noise_mask,
all_cap_noise_mask,
all_sig_noise_mask,
)
return all_x_out, all_cap_out, all_sig_out, {
"x_size": x_size,
"x_pos_ids": all_x_pos_ids,
"cap_pos_ids": all_cap_pos_ids,
"sig_pos_ids": all_sig_pos_ids,
"x_pad_mask": all_x_pad_mask,
"cap_pad_mask": all_cap_pad_mask,
"sig_pad_mask": all_sig_pad_mask,
"x_pos_offsets": all_x_pos_offsets,
"x_noise_mask": all_x_noise_mask,
"cap_noise_mask": all_cap_noise_mask,
"sig_noise_mask": all_sig_noise_mask,
}
def forward(
self,
x: List[torch.Tensor],
t,
cap_feats: List[torch.Tensor],
siglip_feats = None,
image_noise_mask = None,
patch_size=2,
f_patch_size=1,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
assert patch_size in self.all_patch_size
assert f_patch_size in self.all_f_patch_size
assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size
omni_mode = isinstance(x[0], list)
device = x[0][-1].device if omni_mode else x[0].device
bsz = len(x)
device = x[0].device
t = t * self.t_scale
t = self.t_embedder(t)
if omni_mode:
# Dual embeddings: noisy (t) and clean (t=1)
t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1])
t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1])
adaln_input = None
else:
# Single embedding for all tokens
adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0])
t_noisy = t_clean = None
adaln_input = t
(
x,
cap_feats,
x_size,
x_pos_ids,
cap_pos_ids,
x_inner_pad_mask,
cap_inner_pad_mask,
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
# Patchify
if omni_mode:
(
x,
cap_feats,
siglip_feats,
x_size,
x_pos_ids,
cap_pos_ids,
siglip_pos_ids,
x_pad_mask,
cap_pad_mask,
siglip_pad_mask,
x_pos_offsets,
x_noise_mask,
cap_noise_mask,
siglip_noise_mask,
) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask)
else:
(
x,
cap_feats,
x_size,
x_pos_ids,
cap_pos_ids,
x_pad_mask,
cap_pad_mask,
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None
# x embed & refine
x_item_seqlens = [len(_) for _ in x]
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
x_max_item_seqlen = max(x_item_seqlens)
x = torch.cat(x, dim=0)
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token.to(dtype=x.dtype, device=x.device)
x = list(x.split(x_item_seqlens, dim=0))
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
x = pad_sequence(x, batch_first=True, padding_value=0.0)
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(x_item_seqlens):
x_attn_mask[i, :seq_len] = 1
x_seqlens = [len(xi) for xi in x]
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed
x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence(
list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device
)
for layer in self.noise_refiner:
x = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=x,
attn_mask=x_attn_mask,
freqs_cis=x_freqs_cis,
adaln_input=adaln_input,
x=x, attn_mask=x_mask, freqs_cis=x_freqs, adaln_input=adaln_input, noise_mask=x_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean,
)
# cap embed & refine
cap_item_seqlens = [len(_) for _ in cap_feats]
assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
cap_max_item_seqlen = max(cap_item_seqlens)
cap_feats = torch.cat(cap_feats, dim=0)
cap_feats = self.cap_embedder(cap_feats)
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token.to(dtype=x.dtype, device=x.device)
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(cap_item_seqlens):
cap_attn_mask[i, :seq_len] = 1
# Cap embed & refine
cap_seqlens = [len(ci) for ci in cap_feats]
cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed
cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence(
list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device
)
for layer in self.context_refiner:
cap_feats = gradient_checkpoint_forward(
@@ -581,41 +1015,68 @@ class ZImageDiT(nn.Module):
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=cap_feats,
attn_mask=cap_attn_mask,
freqs_cis=cap_freqs_cis,
attn_mask=cap_mask,
freqs_cis=cap_freqs,
)
# unified
unified = []
unified_freqs_cis = []
for i in range(bsz):
x_len = x_item_seqlens[i]
cap_len = cap_item_seqlens[i]
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
assert unified_item_seqlens == [len(_) for _ in unified]
unified_max_item_seqlen = max(unified_item_seqlens)
# Siglip embed & refine
siglip_seqlens = siglip_freqs = None
if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None:
siglip_seqlens = [len(si) for si in siglip_feats]
siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed
siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence(
list(siglip_feats.split(siglip_seqlens, dim=0)),
siglip_pos_ids,
siglip_pad_mask,
self.siglip_pad_token,
None,
device,
)
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_item_seqlens):
unified_attn_mask[i, :seq_len] = 1
for layer in self.siglip_refiner:
siglip_feats = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=siglip_feats, attn_mask=siglip_mask, freqs_cis=siglip_freqs,
)
for layer in self.layers:
# Unified sequence
unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence(
x,
x_freqs,
x_seqlens,
x_noise_mask,
cap_feats,
cap_freqs,
cap_seqlens,
cap_noise_mask,
siglip_feats,
siglip_freqs,
siglip_seqlens,
siglip_noise_mask,
omni_mode,
device,
)
# Main transformer layers
for layer_idx, layer in enumerate(self.layers):
unified = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=unified,
attn_mask=unified_attn_mask,
freqs_cis=unified_freqs_cis,
adaln_input=adaln_input,
x=unified, attn_mask=unified_mask, freqs_cis=unified_freqs, adaln_input=adaln_input, noise_mask=unified_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean
)
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
unified = list(unified.unbind(dim=0))
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
unified = (
self.all_final_layer[f"{patch_size}-{f_patch_size}"](
unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean
)
if omni_mode
else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input)
)
return x, {}
# Unpatchify
x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets)
return x

View File

@@ -4,15 +4,20 @@ from typing import Union
from tqdm import tqdm
from einops import rearrange
import numpy as np
from math import prod
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from ..utils.lora.merge import merge_lora
from ..models.qwen_image_dit import QwenImageDiT
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
from ..models.qwen_image_vae import QwenImageVAE
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
from ..models.siglip2_image_encoder import Siglip2ImageEncoder
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel
class QwenImagePipeline(BasePipeline):
@@ -30,6 +35,11 @@ class QwenImagePipeline(BasePipeline):
self.vae: QwenImageVAE = None
self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None
self.tokenizer: Qwen2Tokenizer = None
self.siglip2_image_encoder: Siglip2ImageEncoder = None
self.dinov3_image_encoder: DINOv3ImageEncoder = None
self.image2lora_style: QwenImageImage2LoRAModel = None
self.image2lora_coarse: QwenImageImage2LoRAModel = None
self.image2lora_fine: QwenImageImage2LoRAModel = None
self.processor: Qwen2VLProcessor = None
self.in_iteration_models = ("dit", "blockwise_controlnet")
self.units = [
@@ -38,6 +48,7 @@ class QwenImagePipeline(BasePipeline):
QwenImageUnit_InputImageEmbedder(),
QwenImageUnit_Inpaint(),
QwenImageUnit_EditImageEmbedder(),
QwenImageUnit_LayerInputImageEmbedder(),
QwenImageUnit_ContextImageEmbedder(),
QwenImageUnit_PromptEmbedder(),
QwenImageUnit_EntityControl(),
@@ -72,6 +83,11 @@ class QwenImagePipeline(BasePipeline):
processor_config.download_if_necessary()
from transformers import Qwen2VLProcessor
pipe.processor = Qwen2VLProcessor.from_pretrained(processor_config.path)
pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder")
pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder")
pipe.image2lora_style = model_pool.fetch_model("qwen_image_image2lora_style")
pipe.image2lora_coarse = model_pool.fetch_model("qwen_image_image2lora_coarse")
pipe.image2lora_fine = model_pool.fetch_model("qwen_image_image2lora_fine")
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
@@ -111,6 +127,11 @@ class QwenImagePipeline(BasePipeline):
edit_image: Image.Image = None,
edit_image_auto_resize: bool = True,
edit_rope_interpolation: bool = False,
# Qwen-Image-Edit-2511
zero_cond_t: bool = False,
# Qwen-Image-Layered
layer_input_image: Image.Image = None,
layer_num: int = None,
# In-context control
context_image: Image.Image = None,
# Tile
@@ -142,6 +163,9 @@ class QwenImagePipeline(BasePipeline):
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation,
"context_image": context_image,
"zero_cond_t": zero_cond_t,
"layer_input_image": layer_input_image,
"layer_num": layer_num,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -161,7 +185,10 @@ class QwenImagePipeline(BasePipeline):
# Decode
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
if layer_num is None:
image = self.vae_output_to_image(image)
else:
image = [self.vae_output_to_image(i, pattern="C H W") for i in image]
self.load_models_to_device([])
return image
@@ -212,12 +239,15 @@ class QwenImageUnit_ShapeChecker(PipelineUnit):
class QwenImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
input_params=("height", "width", "seed", "rand_device", "layer_num"),
output_params=("noise",),
)
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device, layer_num):
if layer_num is None:
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
else:
noise = pipe.generate_noise((layer_num + 1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
@@ -234,8 +264,15 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if isinstance(input_image, list):
input_latents = []
for image in input_image:
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents.append(pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride))
input_latents = torch.concat(input_latents, dim=0)
else:
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
@@ -243,6 +280,22 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
return {"latents": latents, "input_latents": input_latents}
class QwenImageUnit_LayerInputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("layer_input_image", "tiled", "tile_size", "tile_stride"),
output_params=("layer_input_latents",),
onload_model_names=("vae",)
)
def process(self, pipe: QwenImagePipeline, layer_input_image, tiled, tile_size, tile_stride):
if layer_input_image is None:
return {}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(layer_input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return {"layer_input_latents": latents}
class QwenImageUnit_Inpaint(PipelineUnit):
def __init__(self):
@@ -515,6 +568,116 @@ class QwenImageUnit_EditImageEmbedder(PipelineUnit):
return {"edit_latents": edit_latents, "edit_image": resized_edit_image}
class QwenImageUnit_Image2LoRAEncode(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("image2lora_images",),
output_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"),
onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder", "text_encoder"),
)
from ..core.data.operators import ImageCropAndResize
self.processor_lowres = ImageCropAndResize(height=28*8, width=28*8)
self.processor_highres = ImageCropAndResize(height=1024, width=1024)
def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
valid_lengths = bool_mask.sum(dim=1)
selected = hidden_states[bool_mask]
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
return split_result
def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image):
prompt = [prompt]
template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
drop_idx = 64
txt = [template.format(e) for e in prompt]
model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)
return prompt_embeds.view(1, -1)
def encode_images_using_siglip2(self, pipe: QwenImagePipeline, images: list[Image.Image]):
pipe.load_models_to_device(["siglip2_image_encoder"])
embs = []
for image in images:
image = self.processor_highres(image)
embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype))
embs = torch.stack(embs)
return embs
def encode_images_using_dinov3(self, pipe: QwenImagePipeline, images: list[Image.Image]):
pipe.load_models_to_device(["dinov3_image_encoder"])
embs = []
for image in images:
image = self.processor_highres(image)
embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype))
embs = torch.stack(embs)
return embs
def encode_images_using_qwenvl(self, pipe: QwenImagePipeline, images: list[Image.Image], highres=False):
pipe.load_models_to_device(["text_encoder"])
embs = []
for image in images:
image = self.processor_highres(image) if highres else self.processor_lowres(image)
embs.append(self.encode_prompt_edit(pipe, prompt="", edit_image=image))
embs = torch.stack(embs)
return embs
def encode_images(self, pipe: QwenImagePipeline, images: list[Image.Image]):
if images is None:
return {}
if not isinstance(images, list):
images = [images]
embs_siglip2 = self.encode_images_using_siglip2(pipe, images)
embs_dinov3 = self.encode_images_using_dinov3(pipe, images)
x = torch.concat([embs_siglip2, embs_dinov3], dim=-1)
residual = None
residual_highres = None
if pipe.image2lora_coarse is not None:
residual = self.encode_images_using_qwenvl(pipe, images, highres=False)
if pipe.image2lora_fine is not None:
residual_highres = self.encode_images_using_qwenvl(pipe, images, highres=True)
return x, residual, residual_highres
def process(self, pipe: QwenImagePipeline, image2lora_images):
if image2lora_images is None:
return {}
x, residual, residual_highres = self.encode_images(pipe, image2lora_images)
return {"image2lora_x": x, "image2lora_residual": residual, "image2lora_residual_highres": residual_highres}
class QwenImageUnit_Image2LoRADecode(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"),
output_params=("lora",),
onload_model_names=("image2lora_coarse", "image2lora_fine", "image2lora_style"),
)
def process(self, pipe: QwenImagePipeline, image2lora_x, image2lora_residual, image2lora_residual_highres):
if image2lora_x is None:
return {}
loras = []
if pipe.image2lora_style is not None:
pipe.load_models_to_device(["image2lora_style"])
for x in image2lora_x:
loras.append(pipe.image2lora_style(x=x, residual=None))
if pipe.image2lora_coarse is not None:
pipe.load_models_to_device(["image2lora_coarse"])
for x, residual in zip(image2lora_x, image2lora_residual):
loras.append(pipe.image2lora_coarse(x=x, residual=residual))
if pipe.image2lora_fine is not None:
pipe.load_models_to_device(["image2lora_fine"])
for x, residual in zip(image2lora_x, image2lora_residual_highres):
loras.append(pipe.image2lora_fine(x=x, residual=residual))
lora = merge_lora(loras, alpha=1 / len(image2lora_x))
return {"lora": lora}
class QwenImageUnit_ContextImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
@@ -549,18 +712,26 @@ def model_fn_qwen_image(
entity_prompt_emb_mask=None,
entity_masks=None,
edit_latents=None,
layer_input_latents=None,
layer_num=None,
context_latents=None,
enable_fp8_attention=False,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
edit_rope_interpolation=False,
zero_cond_t=False,
**kwargs
):
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
if layer_num is None:
layer_num = 1
img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)]
else:
layer_num = layer_num + 1
img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)] * layer_num
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
timestep = timestep / 1000
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
image = rearrange(latents, "(B N) C (H P) (W Q) -> B (N H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2, N=layer_num)
image_seq_len = image.shape[1]
if context_latents is not None:
@@ -572,9 +743,27 @@ def model_fn_qwen_image(
img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list]
edit_image = [rearrange(e, "B C (H P) (W Q) -> B (H W) (C P Q)", H=e.shape[2]//2, W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list]
image = torch.cat([image] + edit_image, dim=1)
if layer_input_latents is not None:
layer_num = layer_num + 1
img_shapes += [(layer_input_latents.shape[0], layer_input_latents.shape[2]//2, layer_input_latents.shape[3]//2)]
layer_input_latents = rearrange(layer_input_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
image = torch.cat([image, layer_input_latents], dim=1)
image = dit.img_in(image)
conditioning = dit.time_text_embed(timestep, image.dtype)
if zero_cond_t:
timestep = torch.cat([timestep, timestep * 0], dim=0)
modulate_index = torch.tensor(
[[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in [img_shapes]],
device=timestep.device,
dtype=torch.int,
)
else:
modulate_index = None
conditioning = dit.time_text_embed(
timestep,
image.dtype,
addition_t_cond=None if not dit.time_text_embed.use_additional_t_cond else torch.tensor([0]).to(device=image.device, dtype=torch.long)
)
if entity_prompt_emb is not None:
text, image_rotary_emb, attention_mask = dit.process_entity_masks(
@@ -604,6 +793,7 @@ def model_fn_qwen_image(
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
enable_fp8_attention=enable_fp8_attention,
modulate_index=modulate_index,
)
if blockwise_controlnet_conditioning is not None:
image_slice = image[:, :image_seq_len].clone()
@@ -614,9 +804,11 @@ def model_fn_qwen_image(
)
image[:, :image_seq_len] = image_slice + controlnet_output
if zero_cond_t:
conditioning = conditioning.chunk(2, dim=0)[0]
image = dit.norm_out(image, conditioning)
image = dit.proj_out(image)
image = image[:, :image_seq_len]
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
latents = rearrange(image, "B (N H W) (C P Q) -> (B N) C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2, B=1)
return latents

View File

@@ -126,7 +126,7 @@ class WanVideoPipeline(BasePipeline):
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
if use_usp:
from ..utils.xfuser import initialize_usp
initialize_usp()
initialize_usp(device)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
@@ -241,6 +241,7 @@ class WanVideoPipeline(BasePipeline):
tea_cache_model_id: Optional[str] = "",
# progress_bar
progress_bar_cmd=tqdm,
output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
@@ -320,9 +321,11 @@ class WanVideoPipeline(BasePipeline):
# Decode
self.load_models_to_device(['vae'])
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
video = self.vae_output_to_video(video)
if output_type == "quantized":
video = self.vae_output_to_video(video)
elif output_type == "floatpoint":
pass
self.load_models_to_device([])
return video
@@ -823,9 +826,9 @@ class WanVideoUnit_S2V(PipelineUnit):
pipe.load_models_to_device(["vae"])
motion_frames = 73
kwargs = {}
if motion_video is not None and len(motion_video) > 0:
assert len(motion_video) == motion_frames, f"motion video must have {motion_frames} frames, but got {len(motion_video)}"
motion_latents = pipe.preprocess_video(motion_video)
if motion_video is not None:
assert motion_video.shape[2] == motion_frames, f"motion video must have {motion_frames} frames, but got {motion_video.shape[2]}"
motion_latents = motion_video
kwargs["drop_motion_frames"] = False
else:
motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)

View File

@@ -4,16 +4,18 @@ from typing import Union
from tqdm import tqdm
from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple
from typing import Union, List, Optional, Tuple, Iterable
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.data.operators import ImageCropAndResize
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from transformers import AutoTokenizer
from ..models.z_image_text_encoder import ZImageTextEncoder
from ..models.z_image_dit import ZImageDiT
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M
class ZImagePipeline(BasePipeline):
@@ -28,6 +30,7 @@ class ZImagePipeline(BasePipeline):
self.dit: ZImageDiT = None
self.vae_encoder: FluxVAEEncoder = None
self.vae_decoder: FluxVAEDecoder = None
self.image_encoder: Siglip2ImageEncoder428M = None
self.tokenizer: AutoTokenizer = None
self.in_iteration_models = ("dit",)
self.units = [
@@ -35,6 +38,9 @@ class ZImagePipeline(BasePipeline):
ZImageUnit_PromptEmbedder(),
ZImageUnit_NoiseInitializer(),
ZImageUnit_InputImageEmbedder(),
ZImageUnit_EditImageAutoResize(),
ZImageUnit_EditImageEmbedderVAE(),
ZImageUnit_EditImageEmbedderSiglip(),
]
self.model_fn = model_fn_z_image
@@ -56,6 +62,7 @@ class ZImagePipeline(BasePipeline):
pipe.dit = model_pool.fetch_model("z_image_dit")
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
@@ -75,6 +82,9 @@ class ZImagePipeline(BasePipeline):
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Edit
edit_image: Image.Image = None,
edit_image_auto_resize: bool = True,
# Shape
height: int = 1024,
width: int = 1024,
@@ -83,11 +93,12 @@ class ZImagePipeline(BasePipeline):
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 8,
sigma_shift: float = None,
# Progress bar
progress_bar_cmd = tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
# Parameters
inputs_posi = {
@@ -102,6 +113,7 @@ class ZImagePipeline(BasePipeline):
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -143,12 +155,13 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params=("edit_image",),
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds",),
onload_model_names=("text_encoder",)
)
def encode_prompt(
self,
pipe,
@@ -194,10 +207,81 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
return embeddings_list
def encode_prompt_omni(
self,
pipe,
prompt: Union[str, List[str]],
edit_image=None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
if isinstance(prompt, str):
prompt = [prompt]
def process(self, pipe: ZImagePipeline, prompt):
if edit_image is None:
num_condition_images = 0
elif isinstance(edit_image, list):
num_condition_images = len(edit_image)
else:
num_condition_images = 1
for i, prompt_item in enumerate(prompt):
if num_condition_images == 0:
prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
elif num_condition_images > 0:
prompt_list = ["<|im_start|>user\n<|vision_start|>"]
prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1)
prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"]
prompt_list += ["<|vision_end|><|im_end|>"]
prompt[i] = prompt_list
flattened_prompt = []
prompt_list_lengths = []
for i in range(len(prompt)):
prompt_list_lengths.append(len(prompt[i]))
flattened_prompt.extend(prompt[i])
text_inputs = pipe.tokenizer(
flattened_prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_masks = text_inputs.attention_mask.to(device).bool()
prompt_embeds = pipe.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-2]
embeddings_list = []
start_idx = 0
for i in range(len(prompt_list_lengths)):
batch_embeddings = []
end_idx = start_idx + prompt_list_lengths[i]
for j in range(start_idx, end_idx):
batch_embeddings.append(prompt_embeds[j][prompt_masks[j]])
embeddings_list.append(batch_embeddings)
start_idx = end_idx
return embeddings_list
def process(self, pipe: ZImagePipeline, prompt, edit_image):
pipe.load_models_to_device(self.onload_model_names)
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
if hasattr(pipe, "dit") and pipe.dit.siglip_embedder is not None:
# Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods.
# We determine which encoding method to use based on the model architecture.
# If you are using two-stage split training,
# please use `--offload_models` instead of skipping the DiT model loading.
prompt_embeds = self.encode_prompt_omni(pipe, prompt, edit_image, pipe.device)
else:
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
return {"prompt_embeds": prompt_embeds}
@@ -234,24 +318,197 @@ class ZImageUnit_InputImageEmbedder(PipelineUnit):
return {"latents": latents, "input_latents": input_latents}
class ZImageUnit_EditImageAutoResize(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "edit_image_auto_resize"),
output_params=("edit_image",),
)
def process(self, pipe: ZImagePipeline, edit_image, edit_image_auto_resize):
if edit_image is None:
return {}
if edit_image_auto_resize is None or not edit_image_auto_resize:
return {}
operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)
edit_image = operator(edit_image)
return {"edit_image": edit_image}
class ZImageUnit_EditImageEmbedderSiglip(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image",),
output_params=("image_embeds",),
onload_model_names=("image_encoder",)
)
def process(self, pipe: ZImagePipeline, edit_image):
if edit_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
if not isinstance(edit_image, list):
edit_image = [edit_image]
image_emb = []
for image_ in edit_image:
image_emb.append(pipe.image_encoder(image_, device=pipe.device))
return {"image_embeds": image_emb}
class ZImageUnit_EditImageEmbedderVAE(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image",),
output_params=("image_latents",),
onload_model_names=("vae_encoder",)
)
def process(self, pipe: ZImagePipeline, edit_image):
if edit_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
if not isinstance(edit_image, list):
edit_image = [edit_image]
image_latents = []
for image_ in edit_image:
image_ = pipe.preprocess_image(image_)
image_latents.append(pipe.vae_encoder(image_))
return {"image_latents": image_latents}
def model_fn_z_image(
dit: ZImageDiT,
latents=None,
timestep=None,
prompt_embeds=None,
image_embeds=None,
image_latents=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
# Due to the complex and verbose codebase of Z-Image,
# we are temporarily using this inelegant structure.
# We will refactor this part in the future (if time permits).
if dit.siglip_embedder is None:
return model_fn_z_image_turbo(
dit,
latents,
timestep,
prompt_embeds,
image_embeds,
image_latents,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
**kwargs,
)
latents = [rearrange(latents, "B C H W -> C B H W")]
if dit.siglip_embedder is not None:
if image_latents is not None:
image_latents = [rearrange(image_latent, "B C H W -> C B H W") for image_latent in image_latents]
latents = [image_latents + latents]
image_noise_mask = [[0] * len(image_latents) + [1]]
else:
latents = [latents]
image_noise_mask = [[1]]
image_embeds = [image_embeds]
else:
image_noise_mask = None
timestep = (1000 - timestep) / 1000
model_output = dit(
latents,
timestep,
prompt_embeds,
siglip_feats=image_embeds,
image_noise_mask=image_noise_mask,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)[0][0]
)[0]
model_output = -model_output
model_output = rearrange(model_output, "C B H W -> B C H W")
return model_output
def model_fn_z_image_turbo(
dit: ZImageDiT,
latents=None,
timestep=None,
prompt_embeds=None,
image_embeds=None,
image_latents=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
while isinstance(prompt_embeds, list):
prompt_embeds = prompt_embeds[0]
while isinstance(latents, list):
latents = latents[0]
while isinstance(image_embeds, list):
image_embeds = image_embeds[0]
# Timestep
timestep = 1000 - timestep
t_noisy = dit.t_embedder(timestep)
t_clean = dit.t_embedder(torch.ones_like(timestep) * 1000)
# Patchify
latents = rearrange(latents, "B C H W -> C B H W")
x, cap_feats, patch_metadata = dit.patchify_and_embed([latents], [prompt_embeds])
x = x[0]
cap_feats = cap_feats[0]
# Noise refine
x = dit.all_x_embedder["2-1"](x)
x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0))
x = rearrange(x, "L C -> 1 L C")
x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C")
for layer in dit.noise_refiner:
x = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=x,
attn_mask=None,
freqs_cis=x_freqs_cis,
adaln_input=t_noisy,
)
# Prompt refine
cap_feats = dit.cap_embedder(cap_feats)
cap_feats[torch.cat(patch_metadata.get("cap_pad_mask"))] = dit.cap_pad_token.to(dtype=x.dtype, device=x.device)
cap_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("cap_pos_ids"), dim=0))
cap_feats = rearrange(cap_feats, "L C -> 1 L C")
cap_freqs_cis = rearrange(cap_freqs_cis, "L C -> 1 L C")
for layer in dit.context_refiner:
cap_feats = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=cap_feats,
attn_mask=None,
freqs_cis=cap_freqs_cis,
)
# Unified
unified = torch.cat([x, cap_feats], dim=1)
unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1)
for layer in dit.layers:
unified = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=unified,
attn_mask=None,
freqs_cis=unified_freqs_cis,
adaln_input=t_noisy,
)
# Output
unified = dit.all_final_layer["2-1"](unified, t_noisy)
x = dit.unpatchify([unified[0]], patch_metadata.get("x_size"))[0]
x = rearrange(x, "C B H W -> B C H W")
x = -x
return x

View File

@@ -1 +1,3 @@
from .general import GeneralLoRALoader
from .merge import merge_lora
from .reset_rank import reset_lora_rank

View File

@@ -202,3 +202,99 @@ class FluxLoRALoader(GeneralLoRALoader):
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
return state_dict_
class FluxLoRAConverter:
def __init__(self):
pass
@staticmethod
def align_to_opensource_format(state_dict, alpha=None):
prefix_rename_dict = {
"single_blocks": "lora_unet_single_blocks",
"blocks": "lora_unet_double_blocks",
}
middle_rename_dict = {
"norm.linear": "modulation_lin",
"to_qkv_mlp": "linear1",
"proj_out": "linear2",
"norm1_a.linear": "img_mod_lin",
"norm1_b.linear": "txt_mod_lin",
"attn.a_to_qkv": "img_attn_qkv",
"attn.b_to_qkv": "txt_attn_qkv",
"attn.a_to_out": "img_attn_proj",
"attn.b_to_out": "txt_attn_proj",
"ff_a.0": "img_mlp_0",
"ff_a.2": "img_mlp_2",
"ff_b.0": "txt_mlp_0",
"ff_b.2": "txt_mlp_2",
}
suffix_rename_dict = {
"lora_B.weight": "lora_up.weight",
"lora_A.weight": "lora_down.weight",
}
state_dict_ = {}
for name, param in state_dict.items():
names = name.split(".")
if names[-2] != "lora_A" and names[-2] != "lora_B":
names.pop(-2)
prefix = names[0]
middle = ".".join(names[2:-2])
suffix = ".".join(names[-2:])
block_id = names[1]
if middle not in middle_rename_dict:
continue
rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix]
state_dict_[rename] = param
if rename.endswith("lora_up.weight"):
lora_alpha = alpha if alpha is not None else param.shape[-1]
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((lora_alpha,))[0]
return state_dict_
@staticmethod
def align_to_diffsynth_format(state_dict):
rename_dict = {
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
}
def guess_block_id(name):
names = name.split("_")
for i in names:
if i.isdigit():
return i, name.replace(f"_{i}_", "_blockid_")
return None, None
state_dict_ = {}
for name, param in state_dict.items():
block_id, source_name = guess_block_id(name)
if source_name in rename_dict:
target_name = rename_dict[source_name]
target_name = target_name.replace(".blockid.", f".{block_id}.")
state_dict_[target_name] = param
else:
state_dict_[name] = param
return state_dict_

View File

@@ -0,0 +1,20 @@
import torch
def decomposite(tensor_A, tensor_B, rank):
dtype, device = tensor_A.dtype, tensor_A.device
weight = tensor_B @ tensor_A
U, S, V = torch.pca_lowrank(weight.float(), q=rank)
tensor_A = (V.T).to(dtype=dtype, device=device).contiguous()
tensor_B = (U @ torch.diag(S)).to(dtype=dtype, device=device).contiguous()
return tensor_A, tensor_B
def reset_lora_rank(lora, rank):
lora_merged = {}
keys = [i for i in lora.keys() if ".lora_A." in i]
for key in keys:
tensor_A = lora[key]
tensor_B = lora[key.replace(".lora_A.", ".lora_B.")]
tensor_A, tensor_B = decomposite(tensor_A, tensor_B, rank)
lora_merged[key] = tensor_A
lora_merged[key.replace(".lora_A.", ".lora_B.")] = tensor_B
return lora_merged

View File

@@ -5,19 +5,20 @@ from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ...core.device import parse_nccl_backend, parse_device_type
def initialize_usp():
def initialize_usp(device_type):
import torch.distributed as dist
from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
dist.init_process_group(backend="nccl", init_method="env://")
dist.init_process_group(backend=parse_nccl_backend(device_type), init_method="env://")
init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=1,
ulysses_degree=dist.get_world_size(),
)
torch.cuda.set_device(dist.get_rank())
getattr(torch, device_type).set_device(dist.get_rank())
def sinusoidal_embedding_1d(dim, position):
@@ -141,5 +142,5 @@ def usp_attn_forward(self, x, freqs):
x = x.flatten(2)
del q, k, v
torch.cuda.empty_cache()
getattr(torch, parse_device_type(x.device)).empty_cache()
return self.o(x)

View File

@@ -81,8 +81,11 @@ graph LR;
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
@@ -93,6 +96,7 @@ graph LR;
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |
| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |
| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
Special Training Scripts:

View File

@@ -138,4 +138,4 @@ Training Tips:
* Differential LoRA Training ([code](/examples/z_image/model_training/special/differential_training/)) + Acceleration Configuration Inference
* An additional LoRA needs to be loaded in differential LoRA training, e.g., [ostris/zimage_turbo_training_adapter](https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter)
* Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Trajectory Imitation Distillation Training ([code](/examples/z_image/model_training/special/trajectory_imitation/)) + Acceleration Configuration Inference
* Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Load Distillation Acceleration LoRA During Inference ([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillFix)) + Acceleration Configuration Inference
* Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Load Distillation Acceleration LoRA During Inference ([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillPatch)) + Acceleration Configuration Inference

View File

@@ -0,0 +1,58 @@
# GPU/NPU Support
`DiffSynth-Studio` supports various GPUs and NPUs. This document explains how to run model inference and training on these devices.
Before you begin, please follow the [Installation Guide](/docs/en/Pipeline_Usage/Setup.md) to install the required GPU/NPU dependencies.
## NVIDIA GPU
All sample code provided by this project supports NVIDIA GPUs by default, requiring no additional modifications.
## AMD GPU
AMD provides PyTorch packages based on ROCm, so most models can run without code changes. A small number of models may not be compatible due to their reliance on CUDA-specific instructions.
## Ascend NPU
When using Ascend NPU, you need to replace `"cuda"` with `"npu"` in your code.
For example, here is the inference code for **Wan2.1-T2V-1.3B**, modified for Ascend NPU:
```diff
import torch
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
- "preparing_device": "cuda",
+ "preparing_device": "npu",
"computation_dtype": torch.bfloat16,
- "computation_device": "cuda",
+ "computation_device": "npu",
}
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
- device="cuda",
+ device="npu",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
+ vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
)
video = pipe(
prompt="Documentary-style photography: a lively puppy running swiftly across lush green grass. The puppy has brownish-yellow fur, upright ears, and an alert, joyful expression. Sunlight bathes its body, making the fur appear exceptionally soft and shiny. The background is an open field with occasional wildflowers, and faint blue sky with scattered white clouds in the distance. Strong perspective captures the motion of the running puppy and the vitality of the surrounding grass. Mid-shot, side-moving viewpoint.",
negative_prompt="Overly vibrant colors, overexposed, static, blurry details, subtitles, artistic style, painting, still image, overall grayish tone, worst quality, low quality, JPEG artifacts, ugly, distorted, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, fused fingers, motionless scene, cluttered background, three legs, many people in background, walking backward",
seed=0, tiled=True,
)
save_video(video, "video.mp4", fps=15, quality=5)
```

View File

@@ -14,8 +14,35 @@ Install from PyPI (there may be delays in version updates; for latest features,
pip install diffsynth
```
If you encounter issues during installation, they may be caused by upstream dependency packages. Please refer to the documentation for these packages:
## GPU/NPU Support
* **NVIDIA GPU**
Install as described above.
* **AMD GPU**
You need to install the `torch` package with ROCm support. Taking ROCm 6.4 (as of the article update date: December 15, 2025) on Linux as an example, run the following command:
```shell
pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.4
```
* **Ascend NPU**
Ascend NPU support is provided via the `torch-npu` package. Taking version `2.1.0.post17` (as of the article update date: December 15, 2025) as an example, run the following command:
```shell
pip install torch-npu==2.1.0.post17
```
When using Ascend NPU, please replace `"cuda"` with `"npu"` in your Python code. For details, see [NPU Support](/docs/en/Pipeline_Usage/GPU_support.md#ascend-npu).
## Other Installation Issues
If you encounter issues during installation, they may be caused by upstream dependencies. Please refer to the documentation for these packages:
* [torch](https://pytorch.org/get-started/locally/)
* [Ascend/pytorch](https://github.com/Ascend/pytorch)
* [sentencepiece](https://github.com/google/sentencepiece)
* [cmake](https://cmake.org)
* [cmake](https://cmake.org)

View File

@@ -31,6 +31,7 @@ This section introduces the basic usage of `DiffSynth-Studio`, including how to
* [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md)
* [Model Training](/docs/en/Pipeline_Usage/Model_Training.md)
* [Environment Variables](/docs/en/Pipeline_Usage/Environment_Variables.md)
* [GPU/NPU Support](/docs/en/Pipeline_Usage/GPU_support.md)
## Section 2: Model Details

View File

@@ -81,8 +81,11 @@ graph LR;
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
@@ -93,6 +96,7 @@ graph LR;
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
特殊训练脚本:

View File

@@ -138,4 +138,4 @@ modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir
* 差分 LoRA 训练([code](/examples/z_image/model_training/special/differential_training/) + 加速配置推理
* 差分 LoRA 训练中需加载一个额外的 LoRA例如 [ostris/zimage_turbo_training_adapter](https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter)
* 标准 SFT 训练([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)+ 轨迹模仿蒸馏训练([code](/examples/z_image/model_training/special/trajectory_imitation/)+ 加速配置推理
* 标准 SFT 训练([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)+ 推理时加载蒸馏加速 LoRA[model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillFix) + 加速配置推理
* 标准 SFT 训练([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)+ 推理时加载蒸馏加速 LoRA[model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillPatch) + 加速配置推理

View File

@@ -0,0 +1,58 @@
# GPU/NPU 支持
`DiffSynth-Studio` 支持多种 GPU/NPU本文介绍如何在这些设备上运行模型推理和训练。
在开始前,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)安装好 GPU/NPU 相关的依赖包。
## NVIDIA GPU
本项目提供的所有样例代码默认支持 NVIDIA GPU无需额外修改。
## AMD GPU
AMD 提供了基于 ROCm 的 torch 包,所以大多数模型无需修改代码即可运行,少数模型由于依赖特定的 cuda 指令无法运行。
## Ascend NPU
使用 Ascend NPU 时,需把代码中的 `"cuda"` 改为 `"npu"`
例如Wan2.1-T2V-1.3B 的推理代码:
```diff
import torch
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
- "preparing_device": "cuda",
+ "preparing_device": "npu",
"computation_dtype": torch.bfloat16,
- "computation_device": "cuda",
+ "preparing_device": "npu",
}
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
- device="cuda",
+ device="npu",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
+ vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
)
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
)
save_video(video, "video.mp4", fps=15, quality=5)
```

View File

@@ -14,8 +14,35 @@ pip install -e .
pip install diffsynth
```
## GPU/NPU 支持
* NVIDIA GPU
按照以上方式安装即可。
* AMD GPU
需安装支持 ROCm 的 `torch` 包,以 ROCm 6.4(本文更新于 2025 年 12 月 15 日、Linux 系统为例,请运行以下命令
```shell
pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.4
```
* Ascend NPU
Ascend NPU 通过 `torch-npu` 包提供支持,以 `2.1.0.post17` 版本(本文更新于 2025 年 12 月 15 日)为例,请运行以下命令
```shell
pip install torch-npu==2.1.0.post17
```
使用 Ascend NPU 时,请将 Python 代码中的 `"cuda"` 改为 `"npu"`,详见[NPU 支持](/docs/zh/Pipeline_Usage/GPU_support.md#ascend-npu)。
## 其他安装问题
如果在安装过程中遇到问题,可能是由上游依赖包导致的,请参考这些包的文档:
* [torch](https://pytorch.org/get-started/locally/)
* [Ascend/pytorch](https://github.com/Ascend/pytorch)
* [sentencepiece](https://github.com/google/sentencepiece)
* [cmake](https://cmake.org)

View File

@@ -31,6 +31,7 @@ graph LR;
* [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)
* [模型训练](/docs/zh/Pipeline_Usage/Model_Training.md)
* [环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md)
* [GPU/NPU 支持](/docs/zh/Pipeline_Usage/GPU_support.md)
## Section 2: 模型详解

View File

@@ -0,0 +1,17 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-2512", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")

View File

@@ -0,0 +1,44 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
dataset_snapshot_download(
"DiffSynth-Studio/example_image_dataset",
allow_file_pattern="qwen_image_edit/*",
local_dir="data/example_image_dataset",
)
prompt = "生成这两个人的合影"
edit_image = [
Image.open("data/example_image_dataset/qwen_image_edit/image1.jpg"),
Image.open("data/example_image_dataset/qwen_image_edit/image2.jpg"),
]
image = pipe(
prompt,
edit_image=edit_image,
seed=1,
num_inference_steps=40,
height=1152,
width=896,
edit_image_auto_resize=True,
zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511
)
image.save("image.jpg")
# Qwen-Image-Edit-2511 is a multi-image editing model.
# Please use a list to input `edit_image`, even if the input contains only one image.
# edit_image = [Image.open("image.jpg")]
# Please do not input the image directly.
# edit_image = Image.open("image.jpg")

View File

@@ -0,0 +1,36 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
dataset_snapshot_download(
"DiffSynth-Studio/example_image_dataset",
allow_patterns="layer/image.png",
local_dir="data/example_image_dataset"
)
# Prompt should be provided to the pipeline. Our pipeline will not generate the prompt.
prompt = 'A cheerful child with brown hair is waving enthusiastically under a bright blue sky filled with colorful confetti and balloons. The word "HELLO!" is prominently displayed in bold red letters above the child, while "Have a Great Day!" appears in elegant cursive at the bottom right corner. The scene is vibrant and festive, with a mix of pastel colors and dynamic shapes creating a joyful atmosphere.'
# Height and width should be consistent with input_image and be divided evenly by 16
input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480))
images = pipe(
prompt,
seed=1, num_inference_steps=50,
height=480, width=864,
layer_input_image=input_image, layer_num=3,
)
for i, image in enumerate(images):
if i == 0: continue # The first image is the input image.
image.save(f"image_{i}.png")

View File

@@ -0,0 +1,110 @@
from diffsynth.pipelines.qwen_image import (
QwenImagePipeline, ModelConfig,
QwenImageUnit_Image2LoRAEncode, QwenImageUnit_Image2LoRADecode
)
from diffsynth.utils.lora import merge_lora
from diffsynth import load_state_dict
from modelscope import snapshot_download
from safetensors.torch import save_file
import torch
from PIL import Image
def demo_style():
# Load models
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"),
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"),
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Style.safetensors"),
],
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
# Load images
snapshot_download(
model_id="DiffSynth-Studio/Qwen-Image-i2L",
allow_file_pattern="assets/style/1/*",
local_dir="data/examples"
)
images = [
Image.open("data/examples/assets/style/1/0.jpg"),
Image.open("data/examples/assets/style/1/1.jpg"),
Image.open("data/examples/assets/style/1/2.jpg"),
Image.open("data/examples/assets/style/1/3.jpg"),
Image.open("data/examples/assets/style/1/4.jpg"),
]
# Model inference
with torch.no_grad():
embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
save_file(lora, "model_style.safetensors")
def demo_coarse_fine_bias():
# Load models
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"),
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"),
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Coarse.safetensors"),
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Fine.safetensors"),
],
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
# Load images
snapshot_download(
model_id="DiffSynth-Studio/Qwen-Image-i2L",
allow_file_pattern="assets/lora/3/*",
local_dir="data/examples"
)
images = [
Image.open("data/examples/assets/lora/3/0.jpg"),
Image.open("data/examples/assets/lora/3/1.jpg"),
Image.open("data/examples/assets/lora/3/2.jpg"),
Image.open("data/examples/assets/lora/3/3.jpg"),
Image.open("data/examples/assets/lora/3/4.jpg"),
Image.open("data/examples/assets/lora/3/5.jpg"),
]
# Model inference
with torch.no_grad():
embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
lora_bias = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Bias.safetensors")
lora_bias.download_if_necessary()
lora_bias = load_state_dict(lora_bias.path, torch_dtype=torch.bfloat16, device="cuda")
lora = merge_lora([lora, lora_bias])
save_file(lora, "model_coarse_fine_bias.safetensors")
def generate_image(lora_path, prompt, seed):
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, lora_path)
image = pipe(prompt, seed=seed, height=1024, width=1024, num_inference_steps=50)
return image
demo_style()
image = generate_image("model_style.safetensors", "a cat", 0)
image.save("image_1.jpg")
demo_coarse_fine_bias()
image = generate_image("model_coarse_fine_bias.safetensors", "bowl", 1)
image.save("image_2.jpg")

View File

@@ -0,0 +1,28 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-2512", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")

View File

@@ -0,0 +1,54 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
dataset_snapshot_download(
"DiffSynth-Studio/example_image_dataset",
allow_file_pattern="qwen_image_edit/*",
local_dir="data/example_image_dataset",
)
prompt = "生成这两个人的合影"
edit_image = [
Image.open("data/example_image_dataset/qwen_image_edit/image1.jpg"),
Image.open("data/example_image_dataset/qwen_image_edit/image2.jpg"),
]
image = pipe(
prompt,
edit_image=edit_image,
seed=1,
num_inference_steps=40,
height=1152,
width=896,
edit_image_auto_resize=True,
zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511
)
image.save("image.jpg")
# Qwen-Image-Edit-2511 is a multi-image editing model.
# Please use a list to input `edit_image`, even if the input contains only one image.
# edit_image = [Image.open("image.jpg")]
# Please do not input the image directly.
# edit_image = Image.open("image.jpg")

View File

@@ -0,0 +1,46 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
dataset_snapshot_download(
"DiffSynth-Studio/example_image_dataset",
allow_patterns="layer/image.png",
local_dir="data/example_image_dataset"
)
# Prompt should be provided to the pipeline. Our pipeline will not generate the prompt.
prompt = 'A cheerful child with brown hair is waving enthusiastically under a bright blue sky filled with colorful confetti and balloons. The word "HELLO!" is prominently displayed in bold red letters above the child, while "Have a Great Day!" appears in elegant cursive at the bottom right corner. The scene is vibrant and festive, with a mix of pastel colors and dynamic shapes creating a joyful atmosphere.'
# Height and width should be consistent with input_image and be divided evenly by 16
input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480))
images = pipe(
prompt,
seed=1, num_inference_steps=50,
height=480, width=864,
layer_input_image=input_image, layer_num=3,
)
for i, image in enumerate(images):
if i == 0: continue # The first image is the input image.
image.save(f"image_{i}.png")

View File

@@ -0,0 +1,134 @@
from diffsynth.pipelines.qwen_image import (
QwenImagePipeline, ModelConfig,
QwenImageUnit_Image2LoRAEncode, QwenImageUnit_Image2LoRADecode
)
from diffsynth.utils.lora import merge_lora
from diffsynth import load_state_dict
from modelscope import snapshot_download
from safetensors.torch import save_file
import torch
from PIL import Image
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
vram_config_disk_offload = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": "disk",
"onload_device": "disk",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
def demo_style():
# Load models
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config_disk_offload),
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config_disk_offload),
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Style.safetensors", **vram_config_disk_offload),
],
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# Load images
snapshot_download(
model_id="DiffSynth-Studio/Qwen-Image-i2L",
allow_file_pattern="assets/style/1/*",
local_dir="data/examples"
)
images = [
Image.open("data/examples/assets/style/1/0.jpg"),
Image.open("data/examples/assets/style/1/1.jpg"),
Image.open("data/examples/assets/style/1/2.jpg"),
Image.open("data/examples/assets/style/1/3.jpg"),
Image.open("data/examples/assets/style/1/4.jpg"),
]
# Model inference
with torch.no_grad():
embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
save_file(lora, "model_style.safetensors")
def demo_coarse_fine_bias():
# Load models
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config_disk_offload),
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config_disk_offload),
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config_disk_offload),
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Coarse.safetensors", **vram_config_disk_offload),
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Fine.safetensors", **vram_config_disk_offload),
],
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# Load images
snapshot_download(
model_id="DiffSynth-Studio/Qwen-Image-i2L",
allow_file_pattern="assets/lora/3/*",
local_dir="data/examples"
)
images = [
Image.open("data/examples/assets/lora/3/0.jpg"),
Image.open("data/examples/assets/lora/3/1.jpg"),
Image.open("data/examples/assets/lora/3/2.jpg"),
Image.open("data/examples/assets/lora/3/3.jpg"),
Image.open("data/examples/assets/lora/3/4.jpg"),
Image.open("data/examples/assets/lora/3/5.jpg"),
]
# Model inference
with torch.no_grad():
embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
lora_bias = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Bias.safetensors")
lora_bias.download_if_necessary()
lora_bias = load_state_dict(lora_bias.path, torch_dtype=torch.bfloat16, device="cuda")
lora = merge_lora([lora, lora_bias])
save_file(lora, "model_coarse_fine_bias.safetensors")
def generate_image(lora_path, prompt, seed):
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
pipe.load_lora(pipe.dit, lora_path)
image = pipe(prompt, seed=seed, height=1024, width=1024, num_inference_steps=50)
return image
demo_style()
image = generate_image("model_style.safetensors", "a cat", 0)
image.save("image_1.jpg")
demo_coarse_fine_bias()
image = generate_image("model_coarse_fine_bias.safetensors", "bowl", 1)
image.save("image_2.jpg")

View File

@@ -0,0 +1,13 @@
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image-2512:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-2512_full" \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters

View File

@@ -0,0 +1,16 @@
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
--data_file_keys "image,edit_image" \
--extra_inputs "edit_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2511:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Edit-2511_full" \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--zero_cond_t # This is a special parameter introduced by Qwen-Image-Edit-2511. Please enable it for this model.

View File

@@ -0,0 +1,18 @@
# Example Dataset: https://modelscope.cn/datasets/DiffSynth-Studio/example_image_dataset/tree/master/layer
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset/layer \
--dataset_metadata_path data/example_image_dataset/layer/metadata_layered.json \
--data_file_keys "image,layer_input_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image-Layered:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Layered_full" \
--trainable_models "dit" \
--extra_inputs "layer_num,layer_input_image" \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters

View File

@@ -0,0 +1,16 @@
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image-2512:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-2512_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters

View File

@@ -0,0 +1,19 @@
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
--data_file_keys "image,edit_image" \
--extra_inputs "edit_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2511:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Edit-2511_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters \
--zero_cond_t # This is a special parameter introduced by Qwen-Image-Edit-2511. Please enable it for this model.

View File

@@ -0,0 +1,20 @@
# Example Dataset: https://modelscope.cn/datasets/DiffSynth-Studio/example_image_dataset/tree/master/layer
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset/layer \
--dataset_metadata_path data/example_image_dataset/layer/metadata_layered.json \
--data_file_keys "image,layer_input_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image-Layered:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Layered_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--extra_inputs "layer_num,layer_input_image" \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters

View File

@@ -2,6 +2,7 @@ import torch, os, argparse, accelerate
from diffsynth.core import UnifiedDataset
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.diffusion import *
from diffsynth.core.data.operators import *
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -20,6 +21,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
offload_models=None,
device="cpu",
task="sft",
zero_cond_t=False,
):
super().__init__()
# Load models
@@ -43,6 +45,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.fp8_models = fp8_models
self.task = task
self.zero_cond_t = zero_cond_t
self.task_to_loss = {
"sft:data_process": lambda pipe, *args: args,
"direct_distill:data_process": lambda pipe, *args: args,
@@ -56,11 +59,6 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
inputs_shared = {
# Assume you are using this pipeline for inference,
# please fill in the input parameters.
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
# Please do not modify the following parameters
# unless you clearly know what this will cause.
"cfg_scale": 1,
@@ -68,7 +66,22 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
"edit_image_auto_resize": True,
"zero_cond_t": self.zero_cond_t,
}
# Assume you are using this pipeline for inference,
# please fill in the input parameters.
if isinstance(data["image"], list):
inputs_shared.update({
"input_image": data["image"],
"height": data["image"][0].size[1],
"width": data["image"][0].size[0],
})
else:
inputs_shared.update({
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
})
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
@@ -87,6 +100,7 @@ def qwen_image_parser():
parser = add_image_size_config(parser)
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
parser.add_argument("--zero_cond_t", default=False, action="store_true", help="A special parameter introduced by Qwen-Image-Edit-2511. Please enable it for this model.")
return parser
@@ -109,7 +123,15 @@ if __name__ == "__main__":
width=args.width,
height_division_factor=16,
width_division_factor=16,
)
),
special_operator_map={
# Qwen-Image-Layered
"layer_input_image": ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16),
"image": RouteByType(operator_map=[
(str, ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16)),
(list, SequencialProcess(ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16))),
])
}
)
model = QwenImageTrainingModule(
model_paths=args.model_paths,
@@ -130,6 +152,7 @@ if __name__ == "__main__":
offload_models=args.offload_models,
task=args.task,
device=accelerator.device,
zero_cond_t=args.zero_cond_t,
)
model_logger = ModelLogger(
args.output_path,

View File

@@ -0,0 +1,20 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth import load_state_dict
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-2512", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("models/train/Qwen-Image-2512_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "a dog"
image = pipe(prompt, seed=0)
image.save("image.jpg")

View File

@@ -0,0 +1,26 @@
import torch
from PIL import Image
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth import load_state_dict
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=None,
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
state_dict = load_state_dict("models/train/Qwen-Image-Edit-2511_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2."
images = [
Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)),
Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)),
]
image = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024, zero_cond_t=True)
image.save("image.jpg")

View File

@@ -0,0 +1,28 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth import load_state_dict
from PIL import Image
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("models/train/Qwen-Image-Layered_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "a poster"
input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480))
images = pipe(
prompt, seed=0,
height=480, width=864,
layer_input_image=input_image, layer_num=3,
)
for i, image in enumerate(images):
if i == 0: continue # The first image is the input image.
image.save(f"image_{i}.png")

View File

@@ -0,0 +1,18 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-2512", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-2512_lora/epoch-4.safetensors")
prompt = "a dog"
image = pipe(prompt, seed=0)
image.save("image.jpg")

View File

@@ -0,0 +1,24 @@
import torch
from PIL import Image
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=None,
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Edit-2511_lora/epoch-4.safetensors")
prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2."
images = [
Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)),
Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)),
]
image = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024, zero_cond_t=True)
image.save("image.jpg")

View File

@@ -0,0 +1,27 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth import load_state_dict
from PIL import Image
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Layered_lora/epoch-4.safetensors")
prompt = "a poster"
input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480))
images = pipe(
prompt, seed=0,
height=480, width=864,
layer_input_image=input_image, layer_num=3,
)
for i, image in enumerate(images):
if i == 0: continue # The first image is the input image.
image.save(f"image_{i}.png")

View File

@@ -0,0 +1,26 @@
import torch
from PIL import Image
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
import torch.distributed as dist
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
use_usp=True,
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
)
# Text-to-video
video = pipe(
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
)
if dist.get_rank() == 0:
save_video(video, "video1.mp4", fps=15, quality=5)

View File

@@ -27,23 +27,24 @@ def speech_to_video(
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
pipe=pipe,
input_audio=input_audio,
audio_sample_rate=sample_rate,
s2v_pose_video=pose_video,
num_frames=infer_frames + 1,
height=height,
width=width,
fps=fps,
)
with torch.no_grad():
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
pipe=pipe,
input_audio=input_audio,
audio_sample_rate=sample_rate,
s2v_pose_video=pose_video,
num_frames=infer_frames + 1,
height=height,
width=width,
fps=fps,
)
num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat
print(f"Generating {num_repeat} video clips...")
motion_videos = []
motion_video = None
video = []
for r in range(num_repeat):
s2v_pose_latents = pose_latents[r] if pose_latents is not None else None
current_clip = pipe(
current_clip_tensor = pipe(
prompt=prompt,
input_image=input_image,
negative_prompt=negative_prompt,
@@ -53,15 +54,21 @@ def speech_to_video(
width=width,
audio_embeds=audio_embeds[r],
s2v_pose_latents=s2v_pose_latents,
motion_video=motion_videos,
motion_video=motion_video,
num_inference_steps=num_inference_steps,
output_type="floatpoint",
)
current_clip = current_clip[-infer_frames:]
# (B, C, T, H, W)
current_clip_tensor = current_clip_tensor[:,:,-infer_frames:,:,:]
if r == 0:
current_clip = current_clip[3:]
overlap_frames_num = min(motion_frames, len(current_clip))
motion_videos = motion_videos[overlap_frames_num:] + current_clip[-overlap_frames_num:]
video.extend(current_clip)
current_clip_tensor = current_clip_tensor[:,:,3:,:,:]
overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])
motion_video = current_clip_tensor[:,:,-overlap_frames_num:,:,:].clone()
else:
overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])
motion_video = torch.cat((motion_video[:,:,overlap_frames_num:,:,:], current_clip_tensor[:,:,-overlap_frames_num:,:,:]), dim=2)
current_clip_quantized = pipe.vae_output_to_video(current_clip_tensor)
video.extend(current_clip_quantized)
save_video_with_audio(video, save_path, audio_path, fps=16, quality=5)
print(f"processed the {r+1}th clip of total {num_repeat} clips.")
return video

View File

@@ -27,23 +27,24 @@ def speech_to_video(
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
pipe=pipe,
input_audio=input_audio,
audio_sample_rate=sample_rate,
s2v_pose_video=pose_video,
num_frames=infer_frames + 1,
height=height,
width=width,
fps=fps,
)
with torch.no_grad():
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
pipe=pipe,
input_audio=input_audio,
audio_sample_rate=sample_rate,
s2v_pose_video=pose_video,
num_frames=infer_frames + 1,
height=height,
width=width,
fps=fps,
)
num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat
print(f"Generating {num_repeat} video clips...")
motion_videos = []
motion_video = None
video = []
for r in range(num_repeat):
s2v_pose_latents = pose_latents[r] if pose_latents is not None else None
current_clip = pipe(
current_clip_tensor = pipe(
prompt=prompt,
input_image=input_image,
negative_prompt=negative_prompt,
@@ -53,20 +54,24 @@ def speech_to_video(
width=width,
audio_embeds=audio_embeds[r],
s2v_pose_latents=s2v_pose_latents,
motion_video=motion_videos,
motion_video=motion_video,
num_inference_steps=num_inference_steps,
output_type="floatpoint",
)
current_clip = current_clip[-infer_frames:]
current_clip_tensor = current_clip_tensor[:,:,-infer_frames:,:,:]
if r == 0:
current_clip = current_clip[3:]
overlap_frames_num = min(motion_frames, len(current_clip))
motion_videos = motion_videos[overlap_frames_num:] + current_clip[-overlap_frames_num:]
video.extend(current_clip)
current_clip_tensor = current_clip_tensor[:,:,3:,:,:]
overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])
motion_video = current_clip_tensor[:,:,-overlap_frames_num:,:,:].clone()
else:
overlap_frames_num = min(motion_frames, current_clip_tensor.shape[2])
motion_video = torch.cat((motion_video[:,:,overlap_frames_num:,:,:], current_clip_tensor[:,:,-overlap_frames_num:,:,:]), dim=2)
current_clip_quantized = pipe.vae_output_to_video(current_clip_tensor)
video.extend(current_clip_quantized)
save_video_with_audio(video, save_path, audio_path, fps=16, quality=5)
print(f"processed the {r+1}th clip of total {num_repeat} clips.")
return video
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",

View File

@@ -0,0 +1,23 @@
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
from PIL import Image
import torch
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4)
image.save("image_Z-Image-Omni-Base.jpg")
image = Image.open("image_Z-Image-Omni-Base.jpg")
prompt = "Change the women's clothes to white cheongsam, keep other content unchanged"
image = pipe(prompt=prompt, edit_image=image, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4)
image.save("image_edit_Z-Image-Omni-Base.jpg")