Compare commits

...

95 Commits

Author SHA1 Message Date
Artiprocher
53fe42af1b update version 2026-01-30 13:49:27 +08:00
Artiprocher
ee9a3b4405 support loading models from state dict 2026-01-30 13:47:36 +08:00
Zhongjie Duan
22695e9be0 Merge pull request #1233 from modelscope/z-image-release
Z-Image and Z-Image-i2L
2026-01-27 18:41:28 +08:00
Artiprocher
98290190ec update z-image-i2L demo 2026-01-27 13:42:48 +08:00
Artiprocher
3f4de2cc7f update z-image-i2L examples 2026-01-27 12:16:48 +08:00
Artiprocher
d12bf71bcc support z-image and z-image-i2L 2026-01-27 10:56:15 +08:00
Zhongjie Duan
ffb7a138f7 Merge pull request #1228 from modelscope/klein-bugfix
change klein image resize to crop
2026-01-22 10:34:17 +08:00
Artiprocher
548304667f change klein image resize to crop 2026-01-22 10:33:29 +08:00
Zhongjie Duan
273143136c Merge pull request #1227 from modelscope/modelscope-service-patch
update to 2.0.3
2026-01-21 20:23:13 +08:00
Artiprocher
030ebe649a update to 2.0.3 2026-01-21 20:22:43 +08:00
Zhongjie Duan
90921d2293 Merge pull request #1226 from modelscope/klein-train-fix
improve flux2 training performance
2026-01-21 15:44:52 +08:00
Artiprocher
b61131c693 improve flux2 training performance 2026-01-21 15:44:15 +08:00
Zhongjie Duan
37fbb3248a Merge pull request #1222 from modelscope/trainer-update
support auto detect lora target modules
2026-01-21 11:06:19 +08:00
Artiprocher
d13f533f42 support auto detect lora target modules 2026-01-21 11:05:05 +08:00
Zhongjie Duan
3743b1307c Merge pull request #1219 from modelscope/klein-edit
support klein edit
2026-01-20 12:59:12 +08:00
Artiprocher
a835df984c support klein edit 2026-01-20 12:58:18 +08:00
Zhongjie Duan
3e4b47e424 Merge pull request #1207 from Feng0w0/cuda_replace
[NPU]:Replace 'cuda' in the project with abstract interfaces
2026-01-20 10:13:04 +08:00
Zhongjie Duan
dd8d902624 Merge branch 'main' into cuda_replace 2026-01-20 10:12:31 +08:00
Zhongjie Duan
a8b340c098 Merge pull request #1191 from Feng0w0/wan_rope
[model][NPU]:Wan model rope use torch.complex64 in NPU
2026-01-20 10:05:22 +08:00
Zhongjie Duan
88497b5c13 Merge pull request #1217 from modelscope/klein-update
support klein base models
2026-01-19 21:14:47 +08:00
Artiprocher
1e90c72d94 support klein base models 2026-01-19 21:11:58 +08:00
Zhongjie Duan
3dd82a738e Merge pull request #1215 from lzws/main
update learning rate in wan-vace training scripts
2026-01-19 17:48:42 +08:00
Artiprocher
8ad2d9884b update lr in wan-vace training scripts 2026-01-19 17:43:07 +08:00
Artiprocher
70f531b724 update wan-vace training scripts 2026-01-19 17:37:30 +08:00
Zhongjie Duan
37c2868b61 Merge pull request #1214 from modelscope/klein
Support FLUX.2-klein
2026-01-19 17:36:39 +08:00
Artiprocher
a18e6233b5 update wan-vace training scripts 2026-01-19 17:35:08 +08:00
Artiprocher
2336d5f6b3 update doc 2026-01-19 17:27:32 +08:00
Artiprocher
b6ccb362b9 support flux.2 klein 2026-01-19 16:56:14 +08:00
Artiprocher
ae52d93694 support klein 4b models 2026-01-16 13:09:41 +08:00
feng0w0
ad91d41601 [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-16 10:28:24 +08:00
feng0w0
dce77ec4d1 [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-15 20:35:41 +08:00
feng0w0
5c0b07d939 [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-15 20:34:52 +08:00
feng0w0
19e429d889 Merge remote-tracking branch 'origin/cuda_replace' into cuda_replace 2026-01-15 20:33:21 +08:00
feng0w0
209a350c0f [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-15 20:33:01 +08:00
feng0w0
a3c2744a43 [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-15 20:04:54 +08:00
Zhongjie Duan
55e8346da3 Blog link (#1202)
* update README
2026-01-15 12:31:55 +08:00
Zhongjie Duan
b7979b2633 Merge pull request #1200 from modelscope/flux-compatibility-fix
fix flux compatibility issues
2026-01-14 20:50:18 +08:00
Artiprocher
c90aaa2798 fix flux compatibility issues 2026-01-14 20:49:36 +08:00
Zhongjie Duan
0c617d5d9e Merge pull request #1194 from lzws/main
wan usp bug fix
2026-01-14 16:34:06 +08:00
lzws
fd87b72754 wan usp bug fix 2026-01-14 16:33:02 +08:00
Zhongjie Duan
db75508ba0 Merge pull request #1199 from modelscope/z-image-bugfix
fix RMSNorm precision
2026-01-14 16:32:33 +08:00
Artiprocher
acba342a63 fix RMSNorm precision 2026-01-14 16:29:43 +08:00
feng0w0
d16877e695 [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-13 11:17:51 +08:00
lzws
e99cdcf3b8 wan usp bug fix 2026-01-12 22:08:48 +08:00
Zhongjie Duan
a236a17f17 Merge pull request #1193 from modelscope/qwen-image-layered-control
support qwen-image-layered-control
2026-01-12 17:24:06 +08:00
Artiprocher
03e530dc39 support qwen-image-layered-control 2026-01-12 17:20:01 +08:00
feng0w0
6be244233a [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-12 11:34:41 +08:00
feng0w0
544c391936 [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-12 11:24:11 +08:00
Feng
f4d06ce3fc Merge branch 'modelscope:main' into wan_rope 2026-01-12 11:21:09 +08:00
Zhongjie Duan
ffedb9eb52 Merge pull request #1187 from jiaqixuac/patch-1
Update package inclusion pattern in pyproject.toml
2026-01-12 10:12:20 +08:00
Zhongjie Duan
381067515c Merge pull request #1176 from Feng0w0/z-image-rope
[model][NPU]: Z-image model support NPU
2026-01-12 10:11:22 +08:00
Zhongjie Duan
00f2d1aa5d Merge pull request #1169 from Feng0w0/sample_add
Docs:Supplement NPU training script samples and documentation instruction
2026-01-12 10:08:38 +08:00
Zhongjie Duan
8cc3bece6d Merge pull request #1167 from Feng0w0/install_env
Docs:Supplement NPU environment installation document
2026-01-12 10:07:30 +08:00
Jiaqi Xu
f4bf592064 Update pyproject.toml
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-10 09:32:35 +08:00
Jiaqi Xu
3235393fb5 Update package inclusion pattern in pyproject.toml
Update to install all the sub-packages inside diffsynth. Otherwise, the installed packages only contain __init__.py
2026-01-10 09:28:45 +08:00
feng0w0
3b662da31e [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-09 18:11:40 +08:00
feng0w0
19ce3048c1 [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-09 18:06:41 +08:00
Zhongjie Duan
de0aa946f7 Merge pull request #1184 from modelscope/z-image-omni-base-dev
update package version
2026-01-08 17:27:33 +08:00
Artiprocher
f376202a49 update package version 2026-01-08 17:26:29 +08:00
Zhongjie Duan
a13ecfc46b Merge pull request #1183 from modelscope/z-image-omni-base-dev
fix unused parameters in z-image-omni-base
2026-01-08 17:03:20 +08:00
Artiprocher
10a1853eda fix unused parameters in z-image-omni-base 2026-01-08 17:02:41 +08:00
Zhongjie Duan
0efab85674 Support Z-Image-Omni-Base and its related models
Support Z-Image-Omni-Base and its related models.
2026-01-08 13:43:59 +08:00
Artiprocher
f45a0ffd02 support z-image-omni-base vram management 2026-01-08 13:41:00 +08:00
Artiprocher
8ba528a8f6 bugfix 2026-01-08 13:21:33 +08:00
Artiprocher
dd479e5bff support z-image-omni-base-i2L 2026-01-07 20:36:53 +08:00
Artiprocher
bac39b1cd2 support z-image controlnet 2026-01-07 15:56:53 +08:00
feng0w0
c1c9a4853b [model][NPU]:Z-image model support NPU 2026-01-07 11:42:19 +08:00
feng0w0
3ee5f53a36 [model][NPU]:Z-image model support NPU 2026-01-07 11:31:22 +08:00
Artiprocher
32449a6aa0 support z-image-omni-base training 2026-01-05 20:04:00 +08:00
Zhongjie Duan
a6884f6b3a Merge pull request #1171 from YZBPXX/main
Fix issue where LoRa loads on a device different from Dit
2026-01-05 16:39:02 +08:00
Zhongjie Duan
b078666640 Merge pull request #1173 from modelscope/flux-compatibility-patch
flux compatibility patch
2026-01-05 16:20:25 +08:00
Artiprocher
7604ca1e52 flux compatibility patch 2026-01-05 16:04:20 +08:00
feng0w0
62c3d406d9 Docs:Supplement NPU training script samples and documentation instruction 2026-01-05 15:42:55 +08:00
Artiprocher
5745c9f200 support z-image-omni-base 2026-01-05 14:45:01 +08:00
feng0w0
86829120c2 Docs:Supplement NPU training script samples and documentation instruction 2026-01-05 09:59:11 +08:00
yaozhengbing
60ac96525b Fix issue where LoRa loads on a device different from Dit 2025-12-31 21:31:01 +08:00
feng0w0
07b1f5702f Docs:Supplement NPU training script samples and documentation instruction 2025-12-31 10:01:21 +08:00
feng0w0
507e7e5d36 Docs:Supplement NPU training script samples and documentation instruction 2025-12-30 19:58:47 +08:00
Zhongjie Duan
ab8580f77e Merge pull request #1166 from modelscope/qwen-image-2512
support qwen-image-2512
2025-12-30 16:47:07 +08:00
Artiprocher
6454259853 support qwen-image-2512 2025-12-30 16:43:41 +08:00
feng0w0
9cc1697d4d Docs:Supplement NPU environment installation document 2025-12-30 15:57:13 +08:00
Zhongjie Duan
8f1d10fb43 Merge pull request #1150 from modelscope/qwen-image-layered
support qwen-image-layered
2025-12-20 14:05:38 +08:00
Artiprocher
20e1aaf908 bugfix 2025-12-20 14:00:22 +08:00
Artiprocher
c6722b3f56 support qwen-image-layered 2025-12-19 19:06:37 +08:00
Zhongjie Duan
11315d7a40 Merge pull request #1147 from modelscope/qwen-image-edit-2511
Qwen image edit 2511
2025-12-18 19:23:44 +08:00
Artiprocher
68d97a9844 update doc 2025-12-18 19:22:22 +08:00
Artiprocher
4629d4cf9e support qwen-image-edit-2511 2025-12-18 19:16:52 +08:00
Zhongjie Duan
3cb5cec906 Merge pull request #1143 from modelscope/readme-update
update README
2025-12-17 16:32:29 +08:00
Artiprocher
b7e16b9034 update README 2025-12-17 16:30:41 +08:00
Zhongjie Duan
83d1e7361f Merge pull request #1136 from modelscope/bugfix-device
bugfix
2025-12-16 16:12:05 +08:00
Artiprocher
1547c3f786 bugfix 2025-12-16 16:09:29 +08:00
Zhongjie Duan
bfaaf12bf4 Merge pull request #1129 from modelscope/ascend
Support Ascend NPU
2025-12-15 19:13:40 +08:00
Zhongjie Duan
47545e1aab Merge pull request #1126 from Leoooo333/main
Fixed: Wan S2V Long video severe quality downgrade
2025-12-15 19:09:39 +08:00
Junming Chen
a4d34d9f3d Append: set video compress quality as original version. 2025-12-14 20:53:26 +00:00
Junming Chen
127cc9007a Fixed: S2V Long video severe quality downgrade 2025-12-14 20:30:34 +00:00
154 changed files with 5294 additions and 400 deletions

View File

@@ -22,7 +22,7 @@ jobs:
      - name: Install wheel
        run: pip install wheel==0.44.0 && pip install -r requirements.txt
      - name: Build DiffSynth
-        run: python setup.py sdist bdist_wheel
+        run: python -m build
      - name: Publish package to PyPI
        run: |
          pip install twine

View File

@@ -33,7 +33,15 @@ We believe that a well-developed open-source code framework can lower the thresh
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
-- **December 9, 2025** We release a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in terms of generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research.
+- **January 27, 2026**: [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) is released, and our [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) model is released concurrently. You can use it in [ModelScope Studios](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L). For details, see the [documentation](/docs/zh/Model_Details/Z-Image.md).
- **January 19, 2026**: Added support for [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.
- **January 12, 2026**: We trained and open-sourced a text-guided image layer separation model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a textual description, the model isolates the image layer corresponding to the described content. For more details, please refer to our blog post ([Chinese version](https://modelscope.cn/learn/4938), [English version](https://huggingface.co/blog/kelseye/qwen-image-layered-control)).
- **December 24, 2025**: Based on Qwen-Image-Edit-2511, we trained an In-Context Editing LoRA model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)). This model takes three images as input (Image A, Image B, and Image C), and automatically analyzes the transformation from Image A to Image B, then applies the same transformation to Image C to generate Image D. For more details, please refer to our blog post ([Chinese version](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g), [English version](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)).
- **December 9, 2025** We release a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in terms of generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research. For more details, please refer to our [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l).
- **December 4, 2025** DiffSynth-Studio 2.0 released! Many new features online
  - [Documentation](/docs/en/README.md) online: Our documentation is still continuously being optimized and updated
@@ -263,9 +271,14 @@ image.save("image.jpg")
Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image/)
-| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
+|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
</details>
@@ -315,9 +328,13 @@ image.save("image.jpg")
Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)
-| Model ID | Inference | Low-VRAM Inference | LoRA Training | LoRA Training Validation |
-|-|-|-|-|-|
-|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
+| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
+|-|-|-|-|-|-|-|
+|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
</details>
@@ -396,8 +413,12 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
@@ -766,4 +787,3 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
</details>

View File

@@ -33,7 +33,15 @@ DiffSynth currently consists of two open-source projects:
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
-- **December 9, 2025** We trained a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image to LoRA). This model takes an image as input and outputs a LoRA. Although this version still has much room for improvement in generalization, detail preservation, and other aspects, we open-source these models to inspire more innovative research.
+- **January 27, 2026** [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) is released, and our [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) model is released alongside it. You can try it directly in the [ModelScope Studio](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L); see the [documentation](/docs/zh/Model_Details/Z-Image.md) for details.
- **January 19, 2026** Added support for the [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including full training and inference functionality. The [documentation](/docs/zh/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.
- **January 12, 2026** We trained and open-sourced a text-guided image layer separation model ([model link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a text description, the model separates out the image layer corresponding to the description. For more details, please read our blog ([Chinese](https://modelscope.cn/learn/4938), [English](https://huggingface.co/blog/kelseye/qwen-image-layered-control)).
- **December 24, 2025** Based on Qwen-Image-Edit-2511, we trained an In-Context Editing LoRA model ([model link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)). The model takes three images as input (Image A, Image B, Image C), analyzes the change from Image A to Image B on its own, and applies the same change to Image C to generate Image D. For more details, please read our blog ([Chinese](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g), [English](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)).
- **December 9, 2025** We trained a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image to LoRA). This model takes an image as input and outputs a LoRA. Although this version still has much room for improvement in generalization, detail preservation, and other aspects, we open-source these models to inspire more innovative research. For more details, please refer to our [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l).
- **December 4, 2025** DiffSynth-Studio 2.0 released! Many new features online
  - [Documentation](/docs/zh/README.md) online: our documentation is still being continuously optimized and updated
@@ -265,7 +273,12 @@ Example code for Z-Image is located at: [/examples/z_image/](/examples/z_image/)
|Model ID|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
</details>
@@ -315,9 +328,13 @@ image.save("image.jpg")
Example code for FLUX.2 is located at: [/examples/flux2/](/examples/flux2/)
-|Model ID|Inference|Low-VRAM Inference|LoRA Training|Validation After LoRA Training|
-|-|-|-|-|-|
-|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
+|Model ID|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
+|-|-|-|-|-|-|-|
+|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
</details>
@@ -396,8 +413,12 @@ Example code for Qwen-Image is located at: [/examples/qwen_image/](/examples/qwen_image/
|Model ID|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|

View File

@@ -63,6 +63,20 @@ qwen_image_series = [
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel", "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
"extra_kwargs": {"compress_dim": 64, "use_residual": False} "extra_kwargs": {"compress_dim": 64, "use_residual": False}
}, },
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "8dc8cda05de16c73afa755e2c1ce2839",
"model_name": "qwen_image_dit",
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
"extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True}
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "44b39ddc499e027cfb24f7878d7416b9",
"model_name": "qwen_image_vae",
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
"extra_kwargs": {"image_channels": 4}
},
]
wan_series = [
@@ -303,6 +317,13 @@ flux_series = [
"model_class": "diffsynth.models.flux_dit.FluxDiT", "model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
}, },
{
# Supported due to historical reasons.
"model_hash": "605c56eab23e9e2af863ad8f0813a25d",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
@@ -460,6 +481,13 @@ flux_series = [
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
"extra_kwargs": {"disable_guidance_embedder": True}, "extra_kwargs": {"disable_guidance_embedder": True},
}, },
{
# Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors")
"model_hash": "3394f306c4cbf04334b712bf5aaed95f",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
]
flux2_series = [
@@ -482,6 +510,28 @@ flux2_series = [
"model_name": "flux2_vae", "model_name": "flux2_vae",
"model_class": "diffsynth.models.flux2_vae.Flux2VAE", "model_class": "diffsynth.models.flux2_vae.Flux2VAE",
}, },
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
"model_hash": "3bde7b817fec8143028b6825a63180df",
"model_name": "flux2_dit",
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors")
"model_hash": "9195f3ea256fcd0ae6d929c203470754",
"model_name": "z_image_text_encoder",
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
"extra_kwargs": {"model_size": "8B"},
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors")
"model_hash": "39c6fc48f07bebecedbbaa971ff466c8",
"model_name": "flux2_dit",
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
},
]
z_image_series = [
@@ -513,6 +563,32 @@ z_image_series = [
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers", "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
"extra_kwargs": {"use_conv_attention": False}, "extra_kwargs": {"use_conv_attention": False},
}, },
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors")
"model_hash": "aa3563718e5c3ecde3dfbb020ca61180",
"model_name": "z_image_dit",
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
"extra_kwargs": {"siglip_feat_dim": 1152},
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors")
"model_hash": "89d48e420f45cff95115a9f3e698d44a",
"model_name": "siglip_vision_model_428m",
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
},
{
# Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
"model_hash": "1677708d40029ab380a95f6c731a57d7",
"model_name": "z_image_controlnet",
"model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
},
{
# Example: ???
"model_hash": "9510cb8cd1dd34ee0e4f111c24905510",
"model_name": "z_image_image2lora_style",
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
"extra_kwargs": {"compress_dim": 128},
},
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
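The new entries above follow the registry's existing pattern: each maps a fingerprint of a checkpoint's state dict ("model_hash") to the model class and constructor kwargs used to instantiate it, and the "# Example:" comments show the ModelConfig call that fetches the matching file. A minimal, self-contained sketch of that lookup, reusing two of the newly added Qwen-Image-Layered entries (the actual hashing scheme is not part of this diff and is assumed):

```python
# Illustrative sketch of the MODEL_CONFIGS lookup; not DiffSynth's real loader code.
EXAMPLE_CONFIGS = [
    {   # copied from the diff: Qwen-Image-Layered DiT entry
        "model_hash": "8dc8cda05de16c73afa755e2c1ce2839",
        "model_name": "qwen_image_dit",
        "model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
        "extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True},
    },
    {   # copied from the diff: Qwen-Image-Layered VAE entry (4-channel images)
        "model_hash": "44b39ddc499e027cfb24f7878d7416b9",
        "model_name": "qwen_image_vae",
        "model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
        "extra_kwargs": {"image_channels": 4},
    },
]

def find_config(model_hash: str) -> dict:
    """Return the registry entry whose hash matches the detected checkpoint hash."""
    for config in EXAMPLE_CONFIGS:
        if config["model_hash"] == model_hash:
            return config
    raise KeyError(f"no model config registered for hash {model_hash!r}")

print(find_config("44b39ddc499e027cfb24f7878d7416b9")["model_class"])
# -> diffsynth.models.qwen_image_vae.QwenImageVAE
```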

View File

@@ -13,6 +13,7 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"diffsynth.models.qwen_image_dit.QwenImageDiT": { "diffsynth.models.qwen_image_dit.QwenImageDiT": {
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
@@ -194,4 +195,19 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
}, },
"diffsynth.models.z_image_controlnet.ZImageControlNet": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": {
"transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
}
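Each map added here tells the VRAM manager which submodule classes inside a model (the new Z-Image ControlNet, Image2LoRA head, and SigLIP2 vision encoder) should be swapped for "AutoWrapped" variants, so their weights can live on an offload device and only visit the GPU during their own forward pass. DiffSynth's actual AutoWrappedLinear is not shown in this change set; the stand-in below is only a simplified sketch of that idea:

```python
import torch

# Simplified stand-in for an "auto-wrapped" linear layer: parameters stay on the
# offload device and are copied to the computation device only for the forward call.
class OffloadedLinear(torch.nn.Linear):
    def __init__(self, *args, offload_device="cpu", computation_device="cuda", **kwargs):
        super().__init__(*args, **kwargs)
        self.offload_device = offload_device
        self.computation_device = computation_device
        self.to(offload_device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight.to(self.computation_device)
        bias = self.bias.to(self.computation_device) if self.bias is not None else None
        return torch.nn.functional.linear(x, weight, bias)
```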

View File

@@ -53,12 +53,14 @@ class ToStr(DataProcessingOperator):
class LoadImage(DataProcessingOperator):
-def __init__(self, convert_RGB=True):
+def __init__(self, convert_RGB=True, convert_RGBA=False):
self.convert_RGB = convert_RGB
self.convert_RGBA = convert_RGBA
def __call__(self, data: str):
image = Image.open(data)
if self.convert_RGB: image = image.convert("RGB")
if self.convert_RGBA: image = image.convert("RGBA")
return image
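The new convert_RGBA switch matters for the layered models above, which work with 4-channel images whose alpha channel must survive loading. Since both conversions are applied in order, the useful combination is convert_RGB=False together with convert_RGBA=True. A self-contained illustration (the operator's import path inside diffsynth is not shown in this diff, so a simplified copy is used here):

```python
from PIL import Image

# Simplified copy of the LoadImage operator shown above (base class omitted).
class LoadImage:
    def __init__(self, convert_RGB=True, convert_RGBA=False):
        self.convert_RGB = convert_RGB
        self.convert_RGBA = convert_RGBA

    def __call__(self, path: str) -> Image.Image:
        image = Image.open(path)
        if self.convert_RGB: image = image.convert("RGB")
        if self.convert_RGBA: image = image.convert("RGBA")
        return image

# Keep the alpha channel for RGBA training targets (e.g. image layers):
load_rgba = LoadImage(convert_RGB=False, convert_RGBA=True)
# layer = load_rgba("layer.png")  # hypothetical file; returns a PIL image in RGBA mode
```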

View File

@@ -10,6 +10,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
data_file_keys=tuple(),
main_data_operator=lambda x: x,
special_operator_map=None,
max_data_items=None,
):
self.base_path = base_path
self.metadata_path = metadata_path
@@ -18,6 +19,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
self.main_data_operator = main_data_operator
self.cached_data_operator = LoadTorchPickle()
self.special_operator_map = {} if special_operator_map is None else special_operator_map
+self.max_data_items = max_data_items
self.data = []
self.cached_data = []
self.load_from_cache = metadata_path is None
@@ -97,7 +99,9 @@ class UnifiedDataset(torch.utils.data.Dataset):
return data
def __len__(self):
-if self.load_from_cache:
+if self.max_data_items is not None:
+return self.max_data_items
+elif self.load_from_cache:
return len(self.cached_data) * self.repeat
else:
return len(self.data) * self.repeat
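A toy illustration of the `__len__` rule above (not the library class): an explicit `max_data_items` caps the reported dataset length regardless of how much data exists, which is handy for smoke-testing a training run.

```python
class LengthRuleSketch:
    # Mimics only the length logic of the dataset in the hunk above.
    def __init__(self, n_items, repeat=1, max_data_items=None):
        self.data = list(range(n_items))
        self.repeat = repeat
        self.max_data_items = max_data_items

    def __len__(self):
        if self.max_data_items is not None:
            return self.max_data_items
        return len(self.data) * self.repeat

assert len(LengthRuleSketch(1000, repeat=2)) == 2000
assert len(LengthRuleSketch(1000, repeat=2, max_data_items=100)) == 100
```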

View File

@@ -1 +1,2 @@
-from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type
+from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE
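For readers without an Ascend setup, a standalone sketch of what a device-type probe like this typically does. It is not the library implementation; the `torch_npu` import is only attempted and everything else falls back to CUDA or CPU.

```python
import torch

def available_device_type_sketch():
    # Standalone sketch (not the library code): prefer "npu" when the Ascend
    # extension is usable, then "cuda", then fall back to "cpu".
    try:
        import torch_npu  # noqa: F401  # assumption: torch_npu installed
        if torch.npu.is_available():
            return "npu"
    except ImportError:
        pass
    return "cuda" if torch.cuda.is_available() else "cpu"

print(available_device_type_sketch())
```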

View File

@@ -1,5 +1,5 @@
import torch, glob, os
-from typing import Optional, Union
+from typing import Optional, Union, Dict
from dataclasses import dataclass
from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download
@@ -23,6 +23,7 @@ class ModelConfig:
computation_device: Optional[Union[str, torch.device]] = None
computation_dtype: Optional[torch.dtype] = None
clear_parameters: bool = False
+state_dict: Dict[str, torch.Tensor] = None
def check_input(self):
if self.path is None and self.model_id is None:
@@ -97,6 +98,7 @@ class ModelConfig:
self.reset_local_model_path()
if self.require_downloading():
self.download()
+if self.path is None:
if self.origin_file_pattern is None or self.origin_file_pattern == "":
self.path = os.path.join(self.local_model_path, self.model_id)
else:
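A hedged usage sketch of the new `state_dict` field: a state dict loaded once (or built in memory) can be handed to `ModelConfig`, so the loader reuses it instead of re-reading the file. The import location `diffsynth.core` is inferred from the relative imports visible in these hunks and is otherwise an assumption; the file path is a placeholder.

```python
import torch
# Assumed import location, based on `from ..core import ModelConfig, load_state_dict` above.
from diffsynth.core import ModelConfig, load_state_dict

sd = load_state_dict("model.safetensors", torch_dtype=torch.bfloat16, device="cpu")  # placeholder path
config = ModelConfig(path="model.safetensors", state_dict=sd)  # loading reuses `sd` instead of re-reading the file
```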

View File

@@ -2,16 +2,25 @@ from safetensors import safe_open
import torch, hashlib
-def load_state_dict(file_path, torch_dtype=None, device="cpu"):
+def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
if isinstance(file_path, list):
state_dict = {}
for file_path_ in file_path:
-state_dict.update(load_state_dict(file_path_, torch_dtype, device))
+state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
return state_dict
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else:
-return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
+if verbose >= 1:
print(f"Loading file [started]: {file_path}")
if file_path.endswith(".safetensors"):
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else:
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
# If the state dict is kept in CPU memory, `pin_memory=True` makes the later `model.to("cuda")` transfer faster.
if pin_memory:
for i in state_dict:
state_dict[i] = state_dict[i].pin_memory()
if verbose >= 1:
print(f"Loading file [done]: {file_path}")
return state_dict
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
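A standalone sketch of why `pin_memory=True` helps: pinned (page-locked) CPU tensors can be copied to the GPU faster, and asynchronously with `non_blocking=True`. This is generic PyTorch, shown only to illustrate the option above.

```python
import torch

state_dict = {"weight": torch.randn(1024, 1024)}
if torch.cuda.is_available():               # pinning is only useful with an accelerator
    for name in state_dict:
        state_dict[name] = state_dict[name].pin_memory()
    # Pinned memory allows a faster (and optionally asynchronous) host-to-device copy.
    on_gpu = {k: v.to("cuda", non_blocking=True) for k, v in state_dict.items()}
```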

View File

@@ -5,7 +5,7 @@ from .file import load_state_dict
import torch
-def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
+def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
config = {} if config is None else config
# Why do we use `skip_model_initialization`?
# It skips the random initialization of model parameters,
@@ -20,7 +20,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
dtype = [d for d in dtypes if d != "disk"][0]
if vram_config["offload_device"] != "disk":
-state_dict = DiskMap(path, device, torch_dtype=dtype)
+if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
else:
@@ -35,7 +35,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
# Sometimes a model file contains multiple models,
# and DiskMap can load only the parameters of a single model,
# avoiding the need to load all parameters in the file.
-if use_disk_map:
+if state_dict is not None:
+pass
+elif use_disk_map:
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
else:
state_dict = load_state_dict(path, torch_dtype, device)

View File

@@ -2,7 +2,7 @@ import torch, copy
from typing import Union
from .initialization import skip_model_initialization
from .disk_map import DiskMap
-from ..device import parse_device_type
+from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE
class AutoTorchModule(torch.nn.Module):
@@ -63,7 +63,8 @@ class AutoTorchModule(torch.nn.Module):
return r
def check_free_vram(self):
-gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(self.computation_device)
+device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
+gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
return used_memory < self.vram_limit
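A standalone sketch of the free-VRAM probe used above: `mem_get_info` returns `(free_bytes, total_bytes)`, so used memory is total minus free, reported here in GiB. Plain PyTorch/CUDA, shown only to clarify the check.

```python
import torch

if torch.cuda.is_available():
    free_b, total_b = torch.cuda.mem_get_info(torch.cuda.current_device())
    used_gib = (total_b - free_b) / (1024 ** 3)
    print(f"used VRAM: {used_gib:.2f} GiB")
```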
@@ -309,6 +310,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
self.lora_B_weights = [] self.lora_B_weights = []
self.lora_merger = None self.lora_merger = None
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz] self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
self.computation_device_type = parse_device_type(self.computation_device)
if offload_dtype == "disk": if offload_dtype == "disk":
self.disk_map = disk_map self.disk_map = disk_map

View File

@@ -4,9 +4,11 @@ import numpy as np
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
+from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput
+from ..core.device import get_device_name, IS_NPU_AVAILABLE
class PipelineUnit:
@@ -60,7 +62,7 @@ class BasePipeline(torch.nn.Module):
def __init__(
self,
-device="cuda", torch_dtype=torch.float16,
+device=get_device_type(), torch_dtype=torch.float16,
height_division_factor=64, width_division_factor=64,
time_division_factor=None, time_division_remainder=None,
):
@@ -177,7 +179,8 @@ class BasePipeline(torch.nn.Module):
def get_vram(self):
-return getattr(torch, self.device_type).mem_get_info(self.device)[1] / (1024 ** 3)
+device = self.device if not IS_NPU_AVAILABLE else get_device_name()
+return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
def get_module(self, model, name):
if "." in name:
@@ -234,6 +237,7 @@ class BasePipeline(torch.nn.Module):
alpha=1,
hotload=None,
state_dict=None,
+verbose=1,
):
if state_dict is None:
if isinstance(lora_config, str):
@@ -260,12 +264,13 @@ class BasePipeline(torch.nn.Module):
updated_num += 1
module.lora_A_weights.append(lora[lora_a_name] * alpha)
module.lora_B_weights.append(lora[lora_b_name])
-print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
+if verbose >= 1:
+print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
else:
lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)
-def clear_lora(self):
+def clear_lora(self, verbose=1):
cleared_num = 0
for name, module in self.named_modules():
if isinstance(module, AutoWrappedLinear):
@@ -275,7 +280,8 @@ class BasePipeline(torch.nn.Module):
module.lora_A_weights.clear()
if hasattr(module, "lora_B_weights"):
module.lora_B_weights.clear()
-print(f"{cleared_num} LoRA layers are cleared.")
+if verbose >= 1:
+print(f"{cleared_num} LoRA layers are cleared.")
def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
@@ -290,6 +296,7 @@ class BasePipeline(torch.nn.Module):
vram_config=vram_config,
vram_limit=vram_limit,
clear_parameters=model_config.clear_parameters,
+state_dict=model_config.state_dict,
)
return model_pool
@@ -303,8 +310,13 @@ class BasePipeline(torch.nn.Module):
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
+if inputs_shared.get("positive_only_lora", None) is not None:
+self.clear_lora(verbose=0)
+self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
if cfg_scale != 1.0:
+if inputs_shared.get("positive_only_lora", None) is not None:
+self.clear_lora(verbose=0)
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
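A standalone sketch of the `positive_only_lora` flow introduced here: the LoRA is loaded before the conditional (positive) prediction and cleared before the unconditional one, so only the positive branch of classifier-free guidance sees the adapter. The function below is illustrative, not the pipeline method; the callables are stand-ins for the pipeline's own `load_lora`/`clear_lora`.

```python
def cfg_with_positive_only_lora(model_fn, cfg_scale, posi, nega, lora_state_dict,
                                load_lora, clear_lora):
    # Apply the adapter only for the positive branch.
    if lora_state_dict is not None:
        clear_lora()
        load_lora(lora_state_dict)
    noise_pred_posi = model_fn(**posi)
    if cfg_scale == 1.0:
        return noise_pred_posi
    if lora_state_dict is not None:
        clear_lora()                      # negative branch runs without the adapter
    noise_pred_nega = model_fn(**nega)
    return noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
```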

View File

@@ -89,13 +89,18 @@ class FlowMatchScheduler():
return float(mu)
@staticmethod
-def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=1024//16*1024//16):
+def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):
sigma_min = 1 / num_inference_steps
sigma_max = 1.0
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
-mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
+if dynamic_shift_len is None:
+# If you ask me why I set mu=0.8,
+# I can only say that it yields better training results.
+mu = 0.8
+else:
+mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
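A standalone restatement of the schedule above, useful for sanity-checking the shift: linear sigmas are warped by `exp(mu) / (exp(mu) + (1/sigma - 1))`, and when no `dynamic_shift_len` is given the shift defaults to `mu = 0.8`.

```python
import math
import torch

def shifted_flux2_sigmas(num_inference_steps=30, denoising_strength=1.0, mu=0.8):
    sigma_min, sigma_max = 1 / num_inference_steps, 1.0
    sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
    sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
    sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))   # mu-shifted schedule
    return sigmas, sigmas * 1000                                # timesteps on the 0..1000 scale

sigmas, timesteps = shifted_flux2_sigmas()
```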

View File

@@ -10,7 +10,7 @@ class ModelLogger:
self.num_steps = 0
-def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
+def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):
self.num_steps += 1
if save_steps is not None and self.num_steps % save_steps == 0:
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")

View File

@@ -40,7 +40,7 @@ def launch_training_task(
loss = model(data)
accelerator.backward(loss)
optimizer.step()
-model_logger.on_step_end(accelerator, model, save_steps)
+model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
scheduler.step()
if save_steps is None:
model_logger.on_epoch_end(accelerator, model, epoch_id)

View File

@@ -1,4 +1,4 @@
-import torch, json
+import torch, json, os
from ..core import ModelConfig, load_state_dict
from ..utils.controlnet import ControlNetInput
from peft import LoraConfig, inject_adapter_in_model
@@ -127,16 +127,67 @@ class DiffusionTrainingModule(torch.nn.Module):
if model_id_with_origin_paths is not None:
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
for model_id_with_origin_path in model_id_with_origin_paths:
-model_id, origin_file_pattern = model_id_with_origin_path.split(":")
vram_config = self.parse_vram_config(
fp8=model_id_with_origin_path in fp8_models,
offload=model_id_with_origin_path in offload_models,
device=device
)
-model_configs.append(ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern, **vram_config))
+config = self.parse_path_or_model_id(model_id_with_origin_path)
+model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))
return model_configs
def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):
if model_id_with_origin_path is None:
return default_value
elif os.path.exists(model_id_with_origin_path):
return ModelConfig(path=model_id_with_origin_path)
else:
if ":" not in model_id_with_origin_path:
raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id/origin_file_pattern`.")
split_id = model_id_with_origin_path.rfind(":")
model_id = model_id_with_origin_path[:split_id]
origin_file_pattern = model_id_with_origin_path[split_id + 1:]
return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
def auto_detect_lora_target_modules(
self,
model: torch.nn.Module,
search_for_linear=False,
linear_detector=lambda x: min(x.weight.shape) >= 512,
block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
name_prefix="",
):
lora_target_modules = []
if search_for_linear:
for name, module in model.named_modules():
module_name = name_prefix + ["", "."][name_prefix != ""] + name
if isinstance(module, torch.nn.Linear) and linear_detector(module):
lora_target_modules.append(module_name)
else:
for name, module in model.named_children():
module_name = name_prefix + ["", "."][name_prefix != ""] + name
lora_target_modules += self.auto_detect_lora_target_modules(
module,
search_for_linear=block_list_detector(module),
linear_detector=linear_detector,
block_list_detector=block_list_detector,
name_prefix=module_name,
)
return lora_target_modules
def parse_lora_target_modules(self, model, lora_target_modules):
if lora_target_modules == "":
print("No LoRA target modules specified. The framework will automatically search for them.")
lora_target_modules = self.auto_detect_lora_target_modules(model)
print(f"LoRA will be patched at {lora_target_modules}.")
else:
lora_target_modules = lora_target_modules.split(",")
return lora_target_modules
def switch_pipe_to_training_mode(
self,
pipe,
@@ -166,7 +217,7 @@ class DiffusionTrainingModule(torch.nn.Module):
return
model = self.add_lora_to_model(
getattr(pipe, lora_base_model),
-target_modules=lora_target_modules.split(","),
+target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
lora_rank=lora_rank,
upcast_dtype=pipe.torch_dtype,
)
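A compact, self-contained restatement of the auto-detection heuristic above: recurse until a repeated `ModuleList` of blocks is found, then collect `Linear` layers whose smaller weight dimension is at least 512. It is written against a toy model, not the trainer class.

```python
import torch

def detect_lora_targets(model, prefix="", in_blocks=False):
    # Sketch of the heuristic, not the library method.
    targets = []
    pairs = model.named_modules() if in_blocks else model.named_children()
    for name, module in pairs:
        full = f"{prefix}.{name}" if (prefix and name) else (prefix or name)
        if in_blocks:
            if isinstance(module, torch.nn.Linear) and min(module.weight.shape) >= 512:
                targets.append(full)
        else:
            is_block_list = isinstance(module, torch.nn.ModuleList) and len(module) > 1
            targets += detect_lora_targets(module, full, in_blocks=is_block_list)
    return targets

class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = torch.nn.Linear(1024, 1024)   # large enough -> selected
        self.gate = torch.nn.Linear(1024, 8)      # too small   -> skipped

class ToyDiT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = torch.nn.Linear(64, 1024)                        # outside the block list
        self.blocks = torch.nn.ModuleList(ToyBlock() for _ in range(2))

print(detect_lora_targets(ToyDiT()))   # ['blocks.0.attn', 'blocks.1.attn']
```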

View File

@@ -2,6 +2,8 @@ from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch
+from ..core.device.npu_compatible_device import get_device_type
class DINOv3ImageEncoder(DINOv3ViTModel):
def __init__(self):
@@ -70,7 +72,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
}
)
-def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
inputs = self.processor(images=image, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
bool_masked_pos = None
bool_masked_pos = None bool_masked_pos = None

View File

@@ -823,7 +823,13 @@ class Flux2PosEmbed(nn.Module):
class Flux2TimestepGuidanceEmbeddings(nn.Module):
-def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
+def __init__(
+self,
+in_channels: int = 256,
+embedding_dim: int = 6144,
+bias: bool = False,
+guidance_embeds: bool = True,
+):
super().__init__()
self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -831,20 +837,24 @@ class Flux2TimestepGuidanceEmbeddings(nn.Module):
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
)
-self.guidance_embedder = TimestepEmbedding(
-in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
-)
+if guidance_embeds:
+self.guidance_embedder = TimestepEmbedding(
+in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
+)
+else:
+self.guidance_embedder = None
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
-guidance_proj = self.time_proj(guidance)
-guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
-time_guidance_emb = timesteps_emb + guidance_emb
-return time_guidance_emb
+if guidance is not None and self.guidance_embedder is not None:
+guidance_proj = self.time_proj(guidance)
+guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
+time_guidance_emb = timesteps_emb + guidance_emb
+return time_guidance_emb
+else:
+return timesteps_emb
class Flux2Modulation(nn.Module):
@@ -882,6 +892,7 @@ class Flux2DiT(torch.nn.Module):
axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
rope_theta: int = 2000,
eps: float = 1e-6,
+guidance_embeds: bool = True,
):
super().__init__()
self.out_channels = out_channels or in_channels
@@ -892,7 +903,10 @@ class Flux2DiT(torch.nn.Module):
# 2. Combined timestep + guidance embedding
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
-in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
+in_channels=timestep_guidance_channels,
+embedding_dim=self.inner_dim,
+bias=False,
+guidance_embeds=guidance_embeds,
)
# 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
@@ -953,34 +967,9 @@ class Flux2DiT(torch.nn.Module):
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-return_dict: bool = True,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
-) -> Union[torch.Tensor]:
+):
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
Input `hidden_states`.
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# 0. Handle input arguments
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
@@ -992,7 +981,9 @@ class Flux2DiT(torch.nn.Module):
# 1. Calculate timestep embedding and modulation parameters
timestep = timestep.to(hidden_states.dtype) * 1000
-guidance = guidance.to(hidden_states.dtype) * 1000
+if guidance is not None:
+guidance = guidance.to(hidden_states.dtype) * 1000
temb = self.time_guidance_embed(timestep, guidance)
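A toy sketch of the guidance-optional pattern above: when the module is built without a guidance embedder (for a variant that takes no guidance input), the forward pass must fall back to the timestep embedding alone. The linear layers stand in for the real `Timesteps`/`TimestepEmbedding` modules.

```python
import torch
from torch import nn

class TimestepGuidanceSketch(nn.Module):
    def __init__(self, dim=64, guidance_embeds=True):
        super().__init__()
        self.timestep_embedder = nn.Linear(1, dim)
        self.guidance_embedder = nn.Linear(1, dim) if guidance_embeds else None

    def forward(self, timestep, guidance=None):
        emb = self.timestep_embedder(timestep[:, None].float())
        if guidance is not None and self.guidance_embedder is not None:
            emb = emb + self.guidance_embedder(guidance[:, None].float())
        return emb   # falls back to the timestep embedding when guidance is absent

emb = TimestepGuidanceSketch(guidance_embeds=False)(torch.tensor([500.0]))
```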

View File

@@ -19,7 +19,7 @@ def get_timestep_embedding(
)
exponent = exponent / (half_dim - downscale_freq_shift)
-emb = torch.exp(exponent).to(timesteps.device)
+emb = torch.exp(exponent)
if align_dtype_to_timestep:
emb = emb.to(timesteps.dtype)
emb = timesteps[:, None].float() * emb[None, :]
@@ -78,7 +78,7 @@ class DiffusersCompatibleTimestepProj(torch.nn.Module):
class TimestepEmbeddings(torch.nn.Module):
-def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False):
+def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):
super().__init__()
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
if diffusers_compatible_format:
@@ -87,10 +87,17 @@ class TimestepEmbeddings(torch.nn.Module):
self.timestep_embedder = torch.nn.Sequential(
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
)
+self.use_additional_t_cond = use_additional_t_cond
+if use_additional_t_cond:
+self.addition_t_embedding = torch.nn.Embedding(2, dim_out)
-def forward(self, timestep, dtype):
+def forward(self, timestep, dtype, addition_t_cond=None):
time_emb = self.time_proj(timestep).to(dtype)
time_emb = self.timestep_embedder(time_emb)
+if addition_t_cond is not None:
+addition_t_emb = self.addition_t_embedding(addition_t_cond)
+addition_t_emb = addition_t_emb.to(dtype=dtype)
+time_emb = time_emb + addition_t_emb
return time_emb

View File

@@ -9,6 +9,7 @@ import numpy as np
import torch.nn.functional as F
from einops import rearrange, repeat
from .wan_video_dit import flash_attention
+from ..core.device.npu_compatible_device import get_device_type
from ..core.gradient import gradient_checkpoint_forward
@@ -373,7 +374,7 @@ class FinalLayer_FP32(nn.Module):
B, N, C = x.shape
T, _, _ = latent_shape
-with amp.autocast('cuda', dtype=torch.float32):
+with amp.autocast(get_device_type(), dtype=torch.float32):
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
x = self.linear(x)
@@ -583,7 +584,7 @@ class LongCatSingleStreamBlock(nn.Module):
T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.
# compute modulation params in fp32
-with amp.autocast(device_type='cuda', dtype=torch.float32):
+with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
shift_msa, scale_msa, gate_msa, \
shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
@@ -602,7 +603,7 @@ class LongCatSingleStreamBlock(nn.Module):
else:
x_s = attn_outputs
-with amp.autocast(device_type='cuda', dtype=torch.float32):
+with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
@@ -615,7 +616,7 @@ class LongCatSingleStreamBlock(nn.Module):
# ffn with modulation
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
x_s = self.ffn(x_m)
-with amp.autocast(device_type='cuda', dtype=torch.float32):
+with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
@@ -797,7 +798,7 @@ class LongCatVideoTransformer3DModel(torch.nn.Module):
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
-with amp.autocast(device_type='cuda', dtype=torch.float32):
+with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]

View File

@@ -29,7 +29,7 @@ class ModelPool:
module_map = None
return module_map
-def load_model_file(self, config, path, vram_config, vram_limit=None):
+def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
model_class = self.import_model_class(config["model_class"])
model_config = config.get("extra_kwargs", {})
if "state_dict_converter" in config:
@@ -43,6 +43,7 @@ class ModelPool:
state_dict_converter,
use_disk_map=True,
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
+state_dict=state_dict,
)
return model
@@ -59,7 +60,7 @@ class ModelPool:
}
return vram_config
-def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False):
+def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
print(f"Loading models from: {json.dumps(path, indent=4)}")
if vram_config is None:
vram_config = self.default_vram_config()
@@ -67,7 +68,7 @@ class ModelPool:
loaded = False
for config in MODEL_CONFIGS:
if config["model_hash"] == model_hash:
-model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit)
+model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
if clear_parameters: self.clear_parameters(model)
self.model.append(model)
model_name = config["model_name"]

View File

@@ -583,7 +583,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
is_compileable = is_compileable and not self.generation_config.disable_compile
if is_compileable and (
-self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
+self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices
):
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = self.get_compiled_call(generation_config.compile_config)

View File

@@ -1,4 +1,4 @@
-import torch, math
+import torch, math, functools
import torch.nn as nn
from typing import Tuple, Optional, Union, List
from einops import rearrange
@@ -225,6 +225,121 @@ class QwenEmbedRope(nn.Module):
return vid_freqs, txt_freqs
class QwenEmbedLayer3DRope(nn.Module):
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
self.pos_freqs = torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.neg_freqs = torch.cat(
[
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.scale_rope = scale_rope
def rope_params(self, index, dim, theta=10000):
"""
Args:
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
"""
assert dim % 2 == 0
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def forward(self, video_fhw, txt_seq_lens, device):
"""
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
txt_length: [bs] a list of 1 integers representing the length of the text
"""
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
video_fhw = [video_fhw]
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
if not isinstance(video_fhw, list):
video_fhw = [video_fhw]
vid_freqs = []
max_vid_index = 0
layer_num = len(video_fhw) - 1
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
if idx != layer_num:
video_freq = self._compute_video_freqs(frame, height, width, idx)
else:
### For the condition image, we set the layer index to -1
video_freq = self._compute_condition_freqs(frame, height, width)
video_freq = video_freq.to(device)
vid_freqs.append(video_freq)
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
max_vid_index = max(max_vid_index, layer_num)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
@functools.lru_cache(maxsize=None)
def _compute_condition_freqs(self, frame, height, width):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
class QwenFeedForward(nn.Module):
def __init__(
self,
@@ -352,9 +467,38 @@ class QwenImageTransformerBlock(nn.Module):
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)
-def _modulate(self, x, mod_params):
+def _modulate(self, x, mod_params, index=None):
shift, scale, gate = mod_params.chunk(3, dim=-1)
-return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+if index is not None:
# Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
# So shift, scale, gate have shape [2*actual_batch, d]
actual_batch = shift.size(0) // 2
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
# index: [b, l] where b is actual batch size
# Expand to [b, l, 1] to match feature dimension
index_expanded = index.unsqueeze(-1) # [b, l, 1]
# Expand chunks to [b, 1, d] then broadcast to [b, l, d]
shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
scale_0_exp = scale_0.unsqueeze(1)
scale_1_exp = scale_1.unsqueeze(1)
gate_0_exp = gate_0.unsqueeze(1)
gate_1_exp = gate_1.unsqueeze(1)
# Use torch.where to select based on index
shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
else:
shift_result = shift.unsqueeze(1)
scale_result = scale.unsqueeze(1)
gate_result = gate.unsqueeze(1)
return x * (1 + scale_result) + shift_result, gate_result
def forward(
self,
@@ -364,13 +508,16 @@ class QwenImageTransformerBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
enable_fp8_attention = False,
+modulate_index: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
if modulate_index is not None:
temb = torch.chunk(temb, 2, dim=0)[0]
txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
img_normed = self.img_norm1(image)
-img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)
+img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, index=modulate_index)
txt_normed = self.txt_norm1(text)
txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
@@ -387,7 +534,7 @@ class QwenImageTransformerBlock(nn.Module):
text = text + txt_gate * txt_attn_out
img_normed_2 = self.img_norm2(image)
-img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)
+img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, index=modulate_index)
txt_normed_2 = self.txt_norm2(text)
txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
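A standalone sketch of the index-based modulation used in `_modulate` above: two sets of shift/scale/gate are stacked along the batch dimension, and a per-token `index` selects which set applies via `torch.where`. The shapes mirror the comments in the hunk; the numbers are arbitrary.

```python
import torch

b, l, d = 2, 6, 8
x = torch.randn(b, l, d)                 # token features
shift = torch.randn(2 * b, d)            # [set 0 | set 1] stacked on the batch dim
index = torch.randint(0, 2, (b, l))      # per-token choice: 0 -> set 0, 1 -> set 1

shift_0, shift_1 = shift[:b], shift[b:]
picked = torch.where(index.unsqueeze(-1) == 0,
                     shift_0.unsqueeze(1),   # [b, 1, d] broadcast over tokens
                     shift_1.unsqueeze(1))
out = x + picked                          # same broadcast pattern as in _modulate
```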
@@ -405,12 +552,17 @@ class QwenImageDiT(torch.nn.Module):
def __init__(
self,
num_layers: int = 60,
+use_layer3d_rope: bool = False,
+use_additional_t_cond: bool = False,
):
super().__init__()
-self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
+if not use_layer3d_rope:
+self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
+else:
+self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
-self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
+self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=False, use_additional_t_cond=use_additional_t_cond)
self.txt_norm = RMSNorm(3584, eps=1e-6)
self.img_in = nn.Linear(64, 3072)

View File

@@ -366,6 +366,7 @@ class QwenImageEncoder3d(nn.Module):
temperal_downsample=[True, True, False],
dropout=0.0,
non_linearity: str = "silu",
+image_channels=3
):
super().__init__()
self.dim = dim
@@ -381,7 +382,7 @@ class QwenImageEncoder3d(nn.Module):
scale = 1.0
# init block
-self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
+self.conv_in = QwenImageCausalConv3d(image_channels, dims[0], 3, padding=1)
# downsample blocks
self.down_blocks = torch.nn.ModuleList([])
@@ -544,6 +545,7 @@ class QwenImageDecoder3d(nn.Module):
temperal_upsample=[False, True, True],
dropout=0.0,
non_linearity: str = "silu",
+image_channels=3,
):
super().__init__()
self.dim = dim
@@ -594,7 +596,7 @@ class QwenImageDecoder3d(nn.Module):
# output blocks
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
-self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
+self.conv_out = QwenImageCausalConv3d(out_dim, image_channels, 3, padding=1)
self.gradient_checkpointing = False
@@ -647,6 +649,7 @@ class QwenImageVAE(torch.nn.Module):
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
dropout: float = 0.0,
+image_channels: int = 3,
) -> None:
super().__init__()
@@ -655,13 +658,13 @@ class QwenImageVAE(torch.nn.Module):
self.temperal_upsample = temperal_downsample[::-1]
self.encoder = QwenImageEncoder3d(
-base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
+base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, image_channels=image_channels,
)
self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
self.decoder = QwenImageDecoder3d(
-base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
+base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, image_channels=image_channels,
)
mean = [

View File

@@ -1,7 +1,9 @@
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
-from transformers import SiglipImageProcessor
+from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
import torch
+from diffsynth.core.device.npu_compatible_device import get_device_type
class Siglip2ImageEncoder(SiglipVisionTransformer):
def __init__(self):
@@ -47,7 +49,7 @@ class Siglip2ImageEncoder(SiglipVisionTransformer):
}
)
-def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
output_attentions = False
@@ -68,3 +70,65 @@ class Siglip2ImageEncoder(SiglipVisionTransformer):
pooler_output = self.head(last_hidden_state) if self.use_head else None
return pooler_output
class Siglip2ImageEncoder428M(Siglip2VisionModel):
def __init__(self):
config = Siglip2VisionConfig(
attention_dropout = 0.0,
dtype = "bfloat16",
hidden_act = "gelu_pytorch_tanh",
hidden_size = 1152,
intermediate_size = 4304,
layer_norm_eps = 1e-06,
model_type = "siglip2_vision_model",
num_attention_heads = 16,
num_channels = 3,
num_hidden_layers = 27,
num_patches = 256,
patch_size = 16,
transformers_version = "4.57.1"
)
super().__init__(config)
self.processor = Siglip2ImageProcessorFast(
**{
"data_format": "channels_first",
"default_to_square": True,
"device": None,
"disable_grouping": None,
"do_convert_rgb": None,
"do_normalize": True,
"do_pad": None,
"do_rescale": True,
"do_resize": True,
"image_mean": [
0.5,
0.5,
0.5
],
"image_processor_type": "Siglip2ImageProcessorFast",
"image_std": [
0.5,
0.5,
0.5
],
"input_data_format": None,
"max_num_patches": 256,
"pad_size": None,
"patch_size": 16,
"processor_class": "Siglip2Processor",
"resample": 2,
"rescale_factor": 0.00392156862745098,
"return_tensors": None,
}
)
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
siglip_inputs = self.processor(images=[image], return_tensors="pt").to(device)
shape = siglip_inputs.spatial_shapes[0]
hidden_state = super().forward(**siglip_inputs).last_hidden_state
B, N, C = hidden_state.shape
hidden_state = hidden_state[:, : shape[0] * shape[1]]
hidden_state = hidden_state.view(shape[0], shape[1], C)
hidden_state = hidden_state.to(torch_dtype)
return hidden_state

View File

@@ -1,10 +1,11 @@
import torch
from typing import Optional, Union
from .qwen_image_text_encoder import QwenImageTextEncoder
+from ..core.device.npu_compatible_device import get_device_type, get_torch_device
class Step1xEditEmbedder(torch.nn.Module):
-def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"):
+def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()):
super().__init__()
self.max_length = max_length
self.dtype = dtype
@@ -77,13 +78,13 @@ User Prompt:'''
self.max_length,
self.model.config.hidden_size,
dtype=torch.bfloat16,
-device=torch.cuda.current_device(),
+device=get_torch_device().current_device(),
)
masks = torch.zeros(
len(text_list),
self.max_length,
dtype=torch.long,
-device=torch.cuda.current_device(),
+device=get_torch_device().current_device(),
)
def split_string(s):
@@ -158,7 +159,7 @@ User Prompt:'''
else:
token_list.append(token_each)
-new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
+new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type())
new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
@@ -167,15 +168,15 @@ User Prompt:'''
inputs.input_ids = (
torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
.unsqueeze(0)
-.to("cuda")
+.to(get_device_type())
)
-inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
+inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type())
outputs = self.model_forward(
self.model,
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
-pixel_values=inputs.pixel_values.to("cuda"),
+pixel_values=inputs.pixel_values.to(get_device_type()),
-image_grid_thw=inputs.image_grid_thw.to("cuda"),
+image_grid_thw=inputs.image_grid_thw.to(get_device_type()),
output_hidden_states=True,
)
@@ -188,7 +189,7 @@ User Prompt:'''
masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
(min(self.max_length, emb.shape[1] - 217)),
dtype=torch.long,
-device=torch.cuda.current_device(),
+device=get_torch_device().current_device(),
)
return embs, masks

View File

@@ -5,6 +5,7 @@ import math
from typing import Tuple, Optional
from einops import rearrange
from .wan_video_camera_controller import SimpleAdapter
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
@@ -92,6 +93,7 @@ def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
x.shape[0], x.shape[1], x.shape[2], -1, 2))
+freqs = freqs.to(torch.complex64) if freqs.device == "npu" else freqs
x_out = torch.view_as_real(x_out * freqs).flatten(2)
return x_out.to(x.dtype)

View File

@@ -0,0 +1,154 @@
from .z_image_dit import ZImageTransformerBlock
from ..core.gradient import gradient_checkpoint_forward
from torch.nn.utils.rnn import pad_sequence
import torch
from torch import nn
class ZImageControlTransformerBlock(ZImageTransformerBlock):
def __init__(
self,
layer_id: int = 1000,
dim: int = 3840,
n_heads: int = 30,
n_kv_heads: int = 30,
norm_eps: float = 1e-5,
qk_norm: bool = True,
modulation = True,
block_id = 0
):
super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
self.block_id = block_id
if block_id == 0:
self.before_proj = nn.Linear(self.dim, self.dim)
self.after_proj = nn.Linear(self.dim, self.dim)
def forward(self, c, x, **kwargs):
if self.block_id == 0:
c = self.before_proj(c) + x
all_c = []
else:
all_c = list(torch.unbind(c))
c = all_c.pop(-1)
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
all_c += [c_skip, c]
c = torch.stack(all_c)
return c
class ZImageControlNet(torch.nn.Module):
def __init__(
self,
control_layers_places=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
control_in_dim=33,
dim=3840,
n_refiner_layers=2,
):
super().__init__()
self.control_layers = nn.ModuleList([ZImageControlTransformerBlock(layer_id=i, block_id=i) for i in control_layers_places])
self.control_all_x_embedder = nn.ModuleDict({"2-1": nn.Linear(1 * 2 * 2 * control_in_dim, dim, bias=True)})
self.control_noise_refiner = nn.ModuleList([ZImageControlTransformerBlock(block_id=layer_id) for layer_id in range(n_refiner_layers)])
self.control_layers_mapping = {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 7, 16: 8, 18: 9, 20: 10, 22: 11, 24: 12, 26: 13, 28: 14}
def forward_layers(
self,
x,
cap_feats,
control_context,
control_context_item_seqlens,
kwargs,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
bsz = len(control_context)
# unified
cap_item_seqlens = [len(_) for _ in cap_feats]
control_context_unified = []
for i in range(bsz):
control_context_len = control_context_item_seqlens[i]
cap_len = cap_item_seqlens[i]
control_context_unified.append(torch.cat([control_context[i][:control_context_len], cap_feats[i][:cap_len]]))
c = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
# arguments
new_kwargs = dict(x=x)
new_kwargs.update(kwargs)
for layer in self.control_layers:
c = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
c=c, **new_kwargs
)
hints = torch.unbind(c)[:-1]
return hints
def forward_refiner(
self,
dit,
x,
cap_feats,
control_context,
kwargs,
t=None,
patch_size=2,
f_patch_size=1,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
# embeddings
bsz = len(control_context)
device = control_context[0].device
(
control_context,
control_context_size,
control_context_pos_ids,
control_context_inner_pad_mask,
) = dit.patchify_controlnet(control_context, patch_size, f_patch_size, cap_feats[0].size(0))
# control_context embed & refine
control_context_item_seqlens = [len(_) for _ in control_context]
assert all(_ % 2 == 0 for _ in control_context_item_seqlens)
control_context_max_item_seqlen = max(control_context_item_seqlens)
control_context = torch.cat(control_context, dim=0)
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)
# Match t_embedder output dtype to control_context for layerwise casting compatibility
adaln_input = t.type_as(control_context)
control_context[torch.cat(control_context_inner_pad_mask)] = dit.x_pad_token.to(dtype=control_context.dtype, device=control_context.device)
control_context = list(control_context.split(control_context_item_seqlens, dim=0))
control_context_freqs_cis = list(dit.rope_embedder(torch.cat(control_context_pos_ids, dim=0)).split(control_context_item_seqlens, dim=0))
control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
control_context_freqs_cis = pad_sequence(control_context_freqs_cis, batch_first=True, padding_value=0.0)
control_context_attn_mask = torch.zeros((bsz, control_context_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(control_context_item_seqlens):
control_context_attn_mask[i, :seq_len] = 1
c = control_context
# arguments
new_kwargs = dict(
x=x,
attn_mask=control_context_attn_mask,
freqs_cis=control_context_freqs_cis,
adaln_input=adaln_input,
)
new_kwargs.update(kwargs)
for layer in self.control_noise_refiner:
c = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
c=c, **new_kwargs
)
hints = torch.unbind(c)[:-1]
control_context = torch.unbind(c)[-1]
return hints, control_context, control_context_item_seqlens
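control_layers_mapping hard-codes which base DiT layer consumes which control hint; it is simply the inverse enumeration of control_layers_places, so every second base layer receives one hint and odd layers receive none. A quick check of that relationship:

```python
# control_layers_mapping is just the index of each entry in control_layers_places:
# DiT layer 2*k is paired with control hint k, odd layers receive no hint.
control_layers_places = (0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28)
control_layers_mapping = {layer: idx for idx, layer in enumerate(control_layers_places)}
assert control_layers_mapping[4] == 2
assert 5 not in control_layers_mapping
```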

View File

@@ -6,13 +6,15 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from .general_modules import RMSNorm
from ..core.attention import attention_forward
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
from ..core.gradient import gradient_checkpoint_forward
ADALN_EMBED_DIM = 256
SEQ_MULTI_OF = 32
X_PAD_DIM = 64
class TimestepEmbedder(nn.Module):
@@ -38,7 +40,7 @@ class TimestepEmbedder(nn.Module):
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
with torch.amp.autocast(get_device_type(), enabled=False):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
@@ -86,7 +88,7 @@ class Attention(torch.nn.Module):
self.norm_q = RMSNorm(head_dim, eps=1e-5)
self.norm_k = RMSNorm(head_dim, eps=1e-5)
def forward(self, hidden_states, freqs_cis, attention_mask):
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
@@ -103,7 +105,7 @@ class Attention(torch.nn.Module):
# Apply RoPE
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast(get_device_type(), enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
@@ -123,6 +125,7 @@ class Attention(torch.nn.Module):
key,
value,
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
attn_mask=attention_mask,
)
# Reshape back
@@ -136,6 +139,20 @@ class Attention(torch.nn.Module):
return output
def select_per_token(
value_noisy: torch.Tensor,
value_clean: torch.Tensor,
noise_mask: torch.Tensor,
seq_len: int,
) -> torch.Tensor:
noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1)
return torch.where(
noise_mask_expanded == 1,
value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
value_clean.unsqueeze(1).expand(-1, seq_len, -1),
)
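select_per_token broadcasts one modulation vector per sample to every token, taking the noisy-branch vector where the mask is 1 and the clean-branch vector where it is 0. A toy check of the semantics (shapes are illustrative only):

```python
# Toy demonstration of select_per_token with made-up shapes.
import torch

batch, seq_len, dim = 2, 4, 8
value_noisy = torch.ones(batch, dim)         # e.g. modulation computed from t
value_clean = torch.zeros(batch, dim)        # e.g. modulation computed from t = 1
noise_mask = torch.tensor([[1, 1, 0, 0],
                           [0, 1, 1, 1]])    # 1 = noisy token, 0 = clean token

out = select_per_token(value_noisy, value_clean, noise_mask, seq_len)
assert out.shape == (batch, seq_len, dim)
assert out[0, 0].sum() == dim and out[0, 2].sum() == 0
```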
class ZImageTransformerBlock(nn.Module):
def __init__(
self,
@@ -180,40 +197,53 @@ class ZImageTransformerBlock(nn.Module):
attn_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor] = None,
noise_mask: Optional[torch.Tensor] = None,
adaln_noisy: Optional[torch.Tensor] = None,
adaln_clean: Optional[torch.Tensor] = None,
):
if self.modulation:
seq_len = x.shape[1]
if noise_mask is not None:
# Per-token modulation: different modulation for noisy/clean tokens
mod_noisy = self.adaLN_modulation(adaln_noisy)
mod_clean = self.adaLN_modulation(adaln_clean)
scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
else:
# Global modulation: same modulation for all tokens (avoid double select)
mod = self.adaLN_modulation(adaln_input)
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
# Attention block
attn_out = self.attention(
self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
)
x = x + gate_msa * self.attention_norm2(attn_out)
# FFN block
x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
else:
# Attention block
attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
x = x + self.attention_norm2(attn_out)
# FFN block
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
return x
@@ -229,9 +259,21 @@ class FinalLayer(nn.Module):
nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
)
def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
seq_len = x.shape[1]
if noise_mask is not None:
# Per-token modulation
scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)
scale_clean = 1.0 + self.adaLN_modulation(c_clean)
scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len)
else:
# Original global modulation
assert c is not None, "Either c or (c_noisy, c_clean) must be provided"
scale = 1.0 + self.adaLN_modulation(c)
scale = scale.unsqueeze(1)
x = self.norm_final(x) * scale
x = self.linear(x)
return x
@@ -274,7 +316,10 @@ class RopeEmbedder:
result = []
for i in range(len(self.axes_dims)):
index = ids[:, i]
if IS_NPU_AVAILABLE:
result.append(torch.index_select(self.freqs_cis[i], 0, index))
else:
result.append(self.freqs_cis[i][index])
return torch.cat(result, dim=-1)
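On NPU builds the frequency table is gathered with torch.index_select instead of advanced indexing; for a 1-D integer index the two are equivalent, so this is purely a kernel-compatibility choice. A quick equivalence check:

```python
# Fancy indexing and index_select agree for a 1-D index tensor.
import torch

freqs = torch.randn(16, 4)            # stand-in for self.freqs_cis[i]
index = torch.tensor([0, 3, 3, 7])    # stand-in for ids[:, i]
assert torch.equal(freqs[index], torch.index_select(freqs, 0, index))
```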
@@ -299,6 +344,7 @@ class ZImageDiT(nn.Module):
t_scale=1000.0,
axes_dims=[32, 48, 48],
axes_lens=[1024, 512, 512],
siglip_feat_dim=None,
) -> None:
super().__init__()
self.in_channels = in_channels
@@ -359,6 +405,32 @@ class ZImageDiT(nn.Module):
nn.Linear(cap_feat_dim, dim, bias=True),
)
# Optional SigLIP components (for Omni variant)
self.siglip_feat_dim = siglip_feat_dim
if siglip_feat_dim is not None:
self.siglip_embedder = nn.Sequential(
RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True)
)
self.siglip_refiner = nn.ModuleList(
[
ZImageTransformerBlock(
2000 + layer_id,
dim,
n_heads,
n_kv_heads,
norm_eps,
qk_norm,
modulation=False,
)
for layer_id in range(n_refiner_layers)
]
)
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim)))
else:
self.siglip_embedder = None
self.siglip_refiner = None
self.siglip_pad_token = None
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
@@ -375,22 +447,57 @@ class ZImageDiT(nn.Module):
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
def unpatchify(
self,
x: List[torch.Tensor],
size: List[Tuple],
patch_size = 2,
f_patch_size = 1,
x_pos_offsets: Optional[List[Tuple[int, int]]] = None,
) -> List[torch.Tensor]:
pH = pW = patch_size
pF = f_patch_size
bsz = len(x)
assert len(size) == bsz
if x_pos_offsets is not None:
# Omni: extract target image from unified sequence (cond_images + target)
result = []
for i in range(bsz):
unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]]
cu_len = 0
x_item = None
for j in range(len(size[i])):
if size[i][j] is None:
ori_len = 0
pad_len = SEQ_MULTI_OF
cu_len += pad_len + ori_len
else:
F, H, W = size[i][j]
ori_len = (F // pF) * (H // pH) * (W // pW)
pad_len = (-ori_len) % SEQ_MULTI_OF
x_item = (
unified_x[cu_len : cu_len + ori_len]
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
.permute(6, 0, 3, 1, 4, 2, 5)
.reshape(self.out_channels, F, H, W)
)
cu_len += ori_len + pad_len
result.append(x_item) # Return only the last (target) image
return result
else:
# Original mode: simple unpatchify
for i in range(bsz):
F, H, W = size[i]
ori_len = (F // pF) * (H // pH) * (W // pW)
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
x[i] = (
x[i][:ori_len]
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
.permute(6, 0, 3, 1, 4, 2, 5)
.reshape(self.out_channels, F, H, W)
)
return x
@staticmethod
def create_coordinate_grid(size, start=None, device=None):
@@ -405,8 +512,8 @@ class ZImageDiT(nn.Module):
self,
all_image: List[torch.Tensor],
all_cap_feats: List[torch.Tensor],
patch_size: int = 2,
f_patch_size: int = 1,
):
pH = pW = patch_size
pF = f_patch_size
@@ -490,90 +597,487 @@ class ZImageDiT(nn.Module):
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
all_image_out.append(image_padded_feat)
return all_image_out, all_cap_feats_out, {
"x_size": all_image_size,
"x_pos_ids": all_image_pos_ids,
"cap_pos_ids": all_cap_pos_ids,
"x_pad_mask": all_image_pad_mask,
"cap_pad_mask": all_cap_pad_mask
}
# (
# all_img_out,
# all_cap_out,
# all_img_size,
# all_img_pos_ids,
# all_cap_pos_ids,
# all_img_pad_mask,
# all_cap_pad_mask,
# )
def patchify_controlnet(
self,
all_image: List[torch.Tensor],
patch_size: int = 2,
f_patch_size: int = 1,
cap_padding_len: int = None,
):
pH = pW = patch_size
pF = f_patch_size
device = all_image[0].device
all_image_out = []
all_image_size = []
all_image_pos_ids = []
all_image_pad_mask = []
for i, image in enumerate(all_image):
### Process Image
C, F, H, W = image.size()
all_image_size.append((F, H, W))
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
image_ori_len = len(image)
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
image_ori_pos_ids = self.create_coordinate_grid(
size=(F_tokens, H_tokens, W_tokens),
start=(cap_padding_len + 1, 0, 0),
device=device,
).flatten(0, 2)
image_padding_pos_ids = (
self.create_coordinate_grid(
size=(1, 1, 1),
start=(0, 0, 0),
device=device,
)
.flatten(0, 2)
.repeat(image_padding_len, 1)
)
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
all_image_pos_ids.append(image_padded_pos_ids)
# pad mask
all_image_pad_mask.append(
torch.cat(
[
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
],
dim=0,
)
)
# padded feature
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
all_image_out.append(image_padded_feat)
return (
all_image_out,
all_image_size,
all_image_pos_ids,
all_image_pad_mask,
)
def _prepare_sequence(
self,
feats: List[torch.Tensor],
pos_ids: List[torch.Tensor],
inner_pad_mask: List[torch.Tensor],
pad_token: torch.nn.Parameter,
noise_mask: Optional[List[List[int]]] = None,
device: torch.device = None,
):
"""Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask."""
item_seqlens = [len(f) for f in feats]
max_seqlen = max(item_seqlens)
bsz = len(feats)
# Pad token
feats_cat = torch.cat(feats, dim=0)
feats_cat[torch.cat(inner_pad_mask)] = pad_token.to(dtype=feats_cat.dtype, device=feats_cat.device)
feats = list(feats_cat.split(item_seqlens, dim=0))
# RoPE
freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0))
# Pad to batch
feats = pad_sequence(feats, batch_first=True, padding_value=0.0)
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
# Attention mask
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(item_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None
if noise_mask is not None:
noise_mask_tensor = pad_sequence(
[torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask],
batch_first=True,
padding_value=0,
)[:, : feats.shape[1]]
return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor
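_prepare_sequence centralises the pad-token fill, RoPE lookup, batch padding, and boolean attention mask that the old forward() did inline. The padding and masking pattern it relies on, shown stand-alone:

```python
# Stand-alone sketch of the pad + attention-mask pattern used by _prepare_sequence.
import torch
from torch.nn.utils.rnn import pad_sequence

feats = [torch.randn(5, 8), torch.randn(3, 8)]        # variable-length items
item_seqlens = [len(f) for f in feats]
batch = pad_sequence(feats, batch_first=True)          # (2, 5, 8), zero-padded
attn_mask = torch.zeros(len(feats), max(item_seqlens), dtype=torch.bool)
for i, seq_len in enumerate(item_seqlens):
    attn_mask[i, :seq_len] = True                      # mark real tokens only
```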
def _build_unified_sequence(
self,
x: torch.Tensor,
x_freqs: torch.Tensor,
x_seqlens: List[int],
x_noise_mask: Optional[List[List[int]]],
cap: torch.Tensor,
cap_freqs: torch.Tensor,
cap_seqlens: List[int],
cap_noise_mask: Optional[List[List[int]]],
siglip: Optional[torch.Tensor],
siglip_freqs: Optional[torch.Tensor],
siglip_seqlens: Optional[List[int]],
siglip_noise_mask: Optional[List[List[int]]],
omni_mode: bool,
device: torch.device,
):
"""Build unified sequence: x, cap, and optionally siglip.
Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip]
"""
bsz = len(x_seqlens)
unified = []
unified_freqs = []
unified_noise_mask = []
for i in range(bsz):
x_len, cap_len = x_seqlens[i], cap_seqlens[i]
if omni_mode:
# Omni: [cap, x, siglip]
if siglip is not None and siglip_seqlens is not None:
sig_len = siglip_seqlens[i]
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]]))
unified_freqs.append(
torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]])
)
unified_noise_mask.append(
torch.tensor(
cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device
)
)
else:
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]]))
unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]]))
unified_noise_mask.append(
torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device)
)
else:
# Basic: [x, cap]
unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]]))
unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]]))
# Compute unified seqlens
if omni_mode:
if siglip is not None and siglip_seqlens is not None:
unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)]
else:
unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)]
else:
unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)]
max_seqlen = max(unified_seqlens)
# Pad to batch
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
# Attention mask
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None
if omni_mode:
noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[
:, : unified.shape[1]
]
return unified, unified_freqs, attn_mask, noise_mask_tensor
def _pad_with_ids(
self,
feat: torch.Tensor,
pos_grid_size: Tuple,
pos_start: Tuple,
device: torch.device,
noise_mask_val: Optional[int] = None,
):
"""Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
ori_len = len(feat)
pad_len = (-ori_len) % SEQ_MULTI_OF
total_len = ori_len + pad_len
# Pos IDs
ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
if pad_len > 0:
pad_pos_ids = (
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
.flatten(0, 2)
.repeat(pad_len, 1)
)
pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
pad_mask = torch.cat(
[
torch.zeros(ori_len, dtype=torch.bool, device=device),
torch.ones(pad_len, dtype=torch.bool, device=device),
]
)
else:
pos_ids = ori_pos_ids
padded_feat = feat
pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
return padded_feat, pos_ids, pad_mask, total_len, noise_mask
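_pad_with_ids rounds every sub-sequence up to a multiple of SEQ_MULTI_OF (32) with the `(-ori_len) % SEQ_MULTI_OF` idiom; the arithmetic is easy to sanity-check:

```python
# The pad-length formula rounds a sequence up to the next multiple of SEQ_MULTI_OF.
SEQ_MULTI_OF = 32
for ori_len, expected_total in [(1, 32), (32, 32), (33, 64), (95, 96)]:
    pad_len = (-ori_len) % SEQ_MULTI_OF
    assert ori_len + pad_len == expected_total
```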
def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
"""Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
pH, pW, pF = patch_size, patch_size, f_patch_size
C, F, H, W = image.size()
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
def patchify_and_embed_omni(
self,
all_x: List[List[torch.Tensor]],
all_cap_feats: List[List[torch.Tensor]],
all_siglip_feats: List[List[torch.Tensor]],
patch_size: int = 2,
f_patch_size: int = 1,
images_noise_mask: List[List[int]] = None,
):
"""Patchify for omni mode: multiple images per batch item with noise masks."""
bsz = len(all_x)
device = all_x[0][-1].device
dtype = all_x[0][-1].dtype
all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], []
all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], []
all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], []
for i in range(bsz):
num_images = len(all_x[i])
cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], []
cap_end_pos = []
cap_cu_len = 1
# Process captions
for j, cap_item in enumerate(all_cap_feats[i]):
noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1
cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids(
cap_item,
(len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1),
(cap_cu_len, 0, 0),
device,
noise_val,
)
cap_feats_list.append(cap_out)
cap_pos_list.append(cap_pos)
cap_mask_list.append(cap_mask)
cap_lens.append(cap_len)
cap_noise.extend(cap_nm)
cap_cu_len += len(cap_item)
cap_end_pos.append(cap_cu_len)
cap_cu_len += 2 # for image vae and siglip tokens
all_cap_out.append(torch.cat(cap_feats_list, dim=0))
all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0))
all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0))
all_cap_len.append(cap_lens)
all_cap_noise_mask.append(cap_noise)
# Process images
x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], []
for j, x_item in enumerate(all_x[i]):
noise_val = images_noise_mask[i][j]
if x_item is not None:
x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size)
x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids(
x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val
)
x_size.append(size)
else:
x_len = SEQ_MULTI_OF
x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device)
x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1)
x_mask = torch.ones(x_len, dtype=torch.bool, device=device)
x_nm = [noise_val] * x_len
x_size.append(None)
x_feats_list.append(x_out)
x_pos_list.append(x_pos)
x_mask_list.append(x_mask)
x_lens.append(x_len)
x_noise.extend(x_nm)
all_x_out.append(torch.cat(x_feats_list, dim=0))
all_x_pos_ids.append(torch.cat(x_pos_list, dim=0))
all_x_pad_mask.append(torch.cat(x_mask_list, dim=0))
all_x_size.append(x_size)
all_x_len.append(x_lens)
all_x_noise_mask.append(x_noise)
# Process siglip
if all_siglip_feats[i] is None:
all_sig_len.append([0] * num_images)
all_sig_out.append(None)
else:
sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], []
for j, sig_item in enumerate(all_siglip_feats[i]):
noise_val = images_noise_mask[i][j]
if sig_item is not None:
sig_H, sig_W, sig_C = sig_item.size()
sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C)
sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids(
sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val
)
# Scale position IDs to match x resolution
if x_size[j] is not None:
sig_pos = sig_pos.float()
sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1)
sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1)
sig_pos = sig_pos.to(torch.int32)
else:
sig_len = SEQ_MULTI_OF
sig_out = torch.zeros((sig_len, self.siglip_feat_dim), dtype=dtype, device=device)
sig_pos = (
self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1)
)
sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device)
sig_nm = [noise_val] * sig_len
sig_feats_list.append(sig_out)
sig_pos_list.append(sig_pos)
sig_mask_list.append(sig_mask)
sig_lens.append(sig_len)
sig_noise.extend(sig_nm)
all_sig_out.append(torch.cat(sig_feats_list, dim=0))
all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0))
all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0))
all_sig_len.append(sig_lens)
all_sig_noise_mask.append(sig_noise)
# Compute x position offsets
all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)]
return (
all_x_out,
all_cap_out,
all_sig_out,
all_x_size,
all_x_pos_ids,
all_cap_pos_ids,
all_sig_pos_ids,
all_x_pad_mask,
all_cap_pad_mask,
all_sig_pad_mask,
all_x_pos_offsets,
all_x_noise_mask,
all_cap_noise_mask,
all_sig_noise_mask,
)
return all_x_out, all_cap_out, all_sig_out, {
"x_size": x_size,
"x_pos_ids": all_x_pos_ids,
"cap_pos_ids": all_cap_pos_ids,
"sig_pos_ids": all_sig_pos_ids,
"x_pad_mask": all_x_pad_mask,
"cap_pad_mask": all_cap_pad_mask,
"sig_pad_mask": all_sig_pad_mask,
"x_pos_offsets": all_x_pos_offsets,
"x_noise_mask": all_x_noise_mask,
"cap_noise_mask": all_cap_noise_mask,
"sig_noise_mask": all_sig_noise_mask,
}
def forward(
self,
x: List[torch.Tensor],
t,
cap_feats: List[torch.Tensor],
siglip_feats = None,
image_noise_mask = None,
patch_size=2,
f_patch_size=1,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size
omni_mode = isinstance(x[0], list)
device = x[0][-1].device if omni_mode else x[0].device
if omni_mode:
# Dual embeddings: noisy (t) and clean (t=1)
t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1])
t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1])
adaln_input = None
else:
# Single embedding for all tokens
adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0])
t_noisy = t_clean = None
# Patchify
if omni_mode:
(
x,
cap_feats,
siglip_feats,
x_size,
x_pos_ids,
cap_pos_ids,
siglip_pos_ids,
x_pad_mask,
cap_pad_mask,
siglip_pad_mask,
x_pos_offsets,
x_noise_mask,
cap_noise_mask,
siglip_noise_mask,
) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask)
else:
(
x,
cap_feats,
x_size,
x_pos_ids,
cap_pos_ids,
x_pad_mask,
cap_pad_mask,
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None
# x embed & refine
x_seqlens = [len(xi) for xi in x]
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed
x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence(
list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device
)
for layer in self.noise_refiner:
x = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=x, attn_mask=x_mask, freqs_cis=x_freqs, adaln_input=adaln_input, noise_mask=x_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean,
)
# Cap embed & refine
cap_seqlens = [len(ci) for ci in cap_feats]
cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed
cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence(
list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device
)
for layer in self.context_refiner:
cap_feats = gradient_checkpoint_forward(
@@ -581,41 +1085,68 @@ class ZImageDiT(nn.Module):
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=cap_feats,
attn_mask=cap_mask,
freqs_cis=cap_freqs,
)
# Siglip embed & refine
siglip_seqlens = siglip_freqs = None
if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None:
siglip_seqlens = [len(si) for si in siglip_feats]
siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed
siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence(
list(siglip_feats.split(siglip_seqlens, dim=0)),
siglip_pos_ids,
siglip_pad_mask,
self.siglip_pad_token,
None,
device,
)
for layer in self.siglip_refiner:
siglip_feats = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=siglip_feats, attn_mask=siglip_mask, freqs_cis=siglip_freqs,
)
# Unified sequence
unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence(
x,
x_freqs,
x_seqlens,
x_noise_mask,
cap_feats,
cap_freqs,
cap_seqlens,
cap_noise_mask,
siglip_feats,
siglip_freqs,
siglip_seqlens,
siglip_noise_mask,
omni_mode,
device,
)
# Main transformer layers
for layer_idx, layer in enumerate(self.layers):
unified = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=unified, attn_mask=unified_mask, freqs_cis=unified_freqs, adaln_input=adaln_input, noise_mask=unified_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean
)
unified = (
self.all_final_layer[f"{patch_size}-{f_patch_size}"](
unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean
)
if omni_mode
else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input)
)
# Unpatchify
x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets)
return x

View File

@@ -0,0 +1,189 @@
import torch
from .qwen_image_image2lora import ImageEmbeddingToLoraMatrix, SequencialMLP
class LoRATrainerBlock(torch.nn.Module):
def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024, prefix="transformer_blocks"):
super().__init__()
self.prefix = prefix
self.lora_patterns = lora_patterns
self.block_id = block_id
self.layers = []
for name, lora_a_dim, lora_b_dim in self.lora_patterns:
self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank))
self.layers = torch.nn.ModuleList(self.layers)
if use_residual:
self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)
else:
self.proj_residual = None
def forward(self, x, residual=None):
lora = {}
if self.proj_residual is not None: residual = self.proj_residual(residual)
for lora_pattern, layer in zip(self.lora_patterns, self.layers):
name = lora_pattern[0]
lora_a, lora_b = layer(x, residual=residual)
lora[f"{self.prefix}.{self.block_id}.{name}.lora_A.default.weight"] = lora_a
lora[f"{self.prefix}.{self.block_id}.{name}.lora_B.default.weight"] = lora_b
return lora
class ZImageImage2LoRAComponent(torch.nn.Module):
def __init__(self, lora_patterns, prefix, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):
super().__init__()
self.lora_patterns = lora_patterns
self.num_blocks = num_blocks
self.blocks = []
for lora_patterns in self.lora_patterns:
for block_id in range(self.num_blocks):
self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim, prefix=prefix))
self.blocks = torch.nn.ModuleList(self.blocks)
self.residual_scale = 0.05
self.use_residual = use_residual
def forward(self, x, residual=None):
if residual is not None:
if self.use_residual:
residual = residual * self.residual_scale
else:
residual = None
lora = {}
for block in self.blocks:
lora.update(block(x, residual))
return lora
class ZImageImage2LoRAModel(torch.nn.Module):
def __init__(self, use_residual=False, compress_dim=64, rank=4, residual_length=64+7, residual_mid_dim=1024):
super().__init__()
lora_patterns = [
[
("attention.to_q", 3840, 3840),
("attention.to_k", 3840, 3840),
("attention.to_v", 3840, 3840),
("attention.to_out.0", 3840, 3840),
],
[
("feed_forward.w1", 3840, 10240),
("feed_forward.w2", 10240, 3840),
("feed_forward.w3", 3840, 10240),
],
]
config = {
"lora_patterns": lora_patterns,
"use_residual": use_residual,
"compress_dim": compress_dim,
"rank": rank,
"residual_length": residual_length,
"residual_mid_dim": residual_mid_dim,
}
self.layers_lora = ZImageImage2LoRAComponent(
prefix="layers",
num_blocks=30,
**config,
)
self.context_refiner_lora = ZImageImage2LoRAComponent(
prefix="context_refiner",
num_blocks=2,
**config,
)
self.noise_refiner_lora = ZImageImage2LoRAComponent(
prefix="noise_refiner",
num_blocks=2,
**config,
)
def forward(self, x, residual=None):
lora = {}
lora.update(self.layers_lora(x, residual=residual))
lora.update(self.context_refiner_lora(x, residual=residual))
lora.update(self.noise_refiner_lora(x, residual=residual))
return lora
def initialize_weights(self):
state_dict = self.state_dict()
for name in state_dict:
if ".proj_a." in name:
state_dict[name] = state_dict[name] * 0.3
elif ".proj_b.proj_out." in name:
state_dict[name] = state_dict[name] * 0
elif ".proj_residual.proj_out." in name:
state_dict[name] = state_dict[name] * 0.3
self.load_state_dict(state_dict)
class ImageEmb2LoRAWeightCompressed(torch.nn.Module):
def __init__(self, in_dim, out_dim, emb_dim, rank):
super().__init__()
self.lora_a = torch.nn.Parameter(torch.randn((rank, in_dim)))
self.lora_b = torch.nn.Parameter(torch.randn((out_dim, rank)))
self.proj = torch.nn.Linear(emb_dim, rank * rank, bias=True)
self.rank = rank
def forward(self, x):
x = self.proj(x).view(self.rank, self.rank)
lora_a = x @ self.lora_a
lora_b = self.lora_b
return lora_a, lora_b
class ZImageImage2LoRAModelCompressed(torch.nn.Module):
def __init__(self, emb_dim=1536+4096, rank=32):
super().__init__()
target_layers = [
("attention.to_q", 3840, 3840),
("attention.to_k", 3840, 3840),
("attention.to_v", 3840, 3840),
("attention.to_out.0", 3840, 3840),
("feed_forward.w1", 3840, 10240),
("feed_forward.w2", 10240, 3840),
("feed_forward.w3", 3840, 10240),
]
self.lora_patterns = [
{
"prefix": "layers",
"num_layers": 30,
"target_layers": target_layers,
},
{
"prefix": "context_refiner",
"num_layers": 2,
"target_layers": target_layers,
},
{
"prefix": "noise_refiner",
"num_layers": 2,
"target_layers": target_layers,
},
]
module_dict = {}
for lora_pattern in self.lora_patterns:
prefix, num_layers, target_layers = lora_pattern["prefix"], lora_pattern["num_layers"], lora_pattern["target_layers"]
for layer_id in range(num_layers):
for layer_name, in_dim, out_dim in target_layers:
name = f"{prefix}.{layer_id}.{layer_name}".replace(".", "___")
model = ImageEmb2LoRAWeightCompressed(in_dim, out_dim, emb_dim, rank)
module_dict[name] = model
self.module_dict = torch.nn.ModuleDict(module_dict)
def forward(self, x, residual=None):
lora = {}
for name, module in self.module_dict.items():
name = name.replace("___", ".")
name_a, name_b = f"{name}.lora_A.default.weight", f"{name}.lora_B.default.weight"
lora_a, lora_b = module(x)
lora[name_a] = lora_a
lora[name_b] = lora_b
return lora
def initialize_weights(self):
state_dict = self.state_dict()
for name in state_dict:
if "lora_b" in name:
state_dict[name] = state_dict[name] * 0
elif "lora_a" in name:
state_dict[name] = state_dict[name] * 0.2
elif "proj.weight" in name:
print(name)
state_dict[name] = state_dict[name] * 0.2
self.load_state_dict(state_dict)
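Both image-to-LoRA variants emit a flat state dict keyed in PEFT style (`layers.{i}.attention.to_q.lora_A.default.weight` / `...lora_B.default.weight`). A hedged sketch of how one such pair composes with a base weight, assuming the usual low-rank update W' = W + B @ A with scaling omitted:

```python
# Hedged sketch: merging one generated LoRA pair into a base weight.
import torch

def apply_lora_pair(weight: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor) -> torch.Tensor:
    # weight: (out_dim, in_dim), lora_a: (rank, in_dim), lora_b: (out_dim, rank)
    return weight + lora_b @ lora_a

base = torch.randn(3840, 3840)      # e.g. attention.to_q in the 3840-dim blocks
lora_a = torch.randn(4, 3840)
lora_b = torch.zeros(3840, 4)       # zero-initialised lora_B leaves the base unchanged
assert torch.allclose(apply_lora_pair(base, lora_a, lora_b), base)
```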

View File

@@ -3,38 +3,71 @@ import torch
class ZImageTextEncoder(torch.nn.Module):
def __init__(self, model_size="4B"):
super().__init__()
config_dict = {
"4B": Qwen3Config(**{
"architectures": [
"Qwen3ForCausalLM"
],
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2560,
"initializer_range": 0.02,
"intermediate_size": 9728,
"max_position_embeddings": 40960,
"max_window_layers": 36,
"model_type": "qwen3",
"num_attention_heads": 32,
"num_hidden_layers": 36,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 1000000,
"sliding_window": None,
"tie_word_embeddings": True,
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0",
"use_cache": True,
"use_sliding_window": False,
"vocab_size": 151936
}),
"8B": Qwen3Config(**{
"architectures": [
"Qwen3ForCausalLM"
],
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"dtype": "bfloat16",
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 12288,
"max_position_embeddings": 40960,
"max_window_layers": 36,
"model_type": "qwen3",
"num_attention_heads": 32,
"num_hidden_layers": 36,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 1000000,
"sliding_window": None,
"tie_word_embeddings": False,
"transformers_version": "4.56.1",
"use_cache": True,
"use_sliding_window": False,
"vocab_size": 151936
})
}
config = config_dict[model_size]
self.model = Qwen3Model(config)
def forward(self, *args, **kwargs):
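The model_size switch only changes the Qwen3 config (2560 hidden / 9728 intermediate for 4B versus 4096 / 12288 for 8B); weights are loaded elsewhere. A minimal instantiation sketch, assuming a transformers version that provides Qwen3Config and Qwen3Model:

```python
# Random-initialised encoders just to show the size switch; real checkpoints are
# loaded by the surrounding pipeline code.
encoder_4b = ZImageTextEncoder(model_size="4B")
encoder_8b = ZImageTextEncoder(model_size="8B")
assert encoder_4b.model.config.hidden_size == 2560
assert encoder_8b.model.config.hidden_size == 4096
```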

View File

@@ -1,4 +1,4 @@
import torch, math, torchvision
from PIL import Image
from typing import Union
from tqdm import tqdm
@@ -6,25 +6,28 @@ from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from transformers import AutoProcessor, AutoTokenizer
from ..models.flux2_text_encoder import Flux2TextEncoder
from ..models.flux2_dit import Flux2DiT
from ..models.flux2_vae import Flux2VAE
from ..models.z_image_text_encoder import ZImageTextEncoder
class Flux2ImagePipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("FLUX.2")
self.text_encoder: Flux2TextEncoder = None
self.text_encoder_qwen3: ZImageTextEncoder = None
self.dit: Flux2DiT = None
self.vae: Flux2VAE = None
self.tokenizer: AutoProcessor = None
@@ -32,8 +35,10 @@ class Flux2ImagePipeline(BasePipeline):
self.units = [
Flux2Unit_ShapeChecker(),
Flux2Unit_PromptEmbedder(),
Flux2Unit_Qwen3PromptEmbedder(),
Flux2Unit_NoiseInitializer(),
Flux2Unit_InputImageEmbedder(),
Flux2Unit_EditImageEmbedder(),
Flux2Unit_ImageIDs(),
]
self.model_fn = model_fn_flux2
@@ -42,7 +47,7 @@ class Flux2ImagePipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
vram_limit: float = None,
@@ -53,11 +58,12 @@ class Flux2ImagePipeline(BasePipeline):
# Fetch models
pipe.text_encoder = model_pool.fetch_model("flux2_text_encoder")
pipe.text_encoder_qwen3 = model_pool.fetch_model("z_image_text_encoder")
pipe.dit = model_pool.fetch_model("flux2_dit")
pipe.vae = model_pool.fetch_model("flux2_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
@@ -75,6 +81,9 @@ class Flux2ImagePipeline(BasePipeline):
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Edit
edit_image: Union[Image.Image, List[Image.Image]] = None,
edit_image_auto_resize: bool = True,
# Shape
height: int = 1024,
width: int = 1024,
@@ -98,6 +107,7 @@ class Flux2ImagePipeline(BasePipeline):
inputs_shared = {
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
"input_image": input_image, "denoising_strength": denoising_strength,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
@@ -275,6 +285,10 @@ class Flux2Unit_PromptEmbedder(PipelineUnit):
return prompt_embeds, text_ids
def process(self, pipe: Flux2ImagePipeline, prompt):
# Skip if Qwen3 text encoder is available (handled by Qwen3PromptEmbedder)
if pipe.text_encoder_qwen3 is not None:
return {}
pipe.load_models_to_device(self.onload_model_names)
prompt_embeds, text_ids = self.encode_prompt(
pipe.text_encoder, pipe.tokenizer, prompt,
@@ -283,6 +297,136 @@ class Flux2Unit_PromptEmbedder(PipelineUnit):
return {"prompt_embeds": prompt_embeds, "text_ids": text_ids}
class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_emb", "prompt_emb_mask"),
onload_model_names=("text_encoder_qwen3",)
)
self.hidden_states_layers = (9, 18, 27) # Qwen3 layers
def get_qwen3_prompt_embeds(
self,
text_encoder: ZImageTextEncoder,
tokenizer: AutoTokenizer,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
all_input_ids = []
all_attention_masks = []
for single_prompt in prompt:
messages = [{"role": "user", "content": single_prompt}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
all_input_ids.append(inputs["input_ids"])
all_attention_masks.append(inputs["attention_mask"])
input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
# Forward pass through the model
with torch.inference_mode():
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
def prepare_text_ids(
self,
x: torch.Tensor, # (B, L, D) or (L, D)
t_coord: Optional[torch.Tensor] = None,
):
B, L, _ = x.shape
out_ids = []
for i in range(B):
t = torch.arange(1) if t_coord is None else t_coord[i]
h = torch.arange(1)
w = torch.arange(1)
l = torch.arange(L)
coords = torch.cartesian_prod(t, h, w, l)
out_ids.append(coords)
return torch.stack(out_ids)
def encode_prompt(
self,
text_encoder: ZImageTextEncoder,
tokenizer: AutoTokenizer,
prompt: Union[str, List[str]],
dtype = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_embeds = self.get_qwen3_prompt_embeds(
text_encoder=text_encoder,
tokenizer=tokenizer,
prompt=prompt,
dtype=dtype,
device=device,
max_sequence_length=max_sequence_length,
)
batch_size, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
text_ids = self.prepare_text_ids(prompt_embeds)
text_ids = text_ids.to(device)
return prompt_embeds, text_ids
def process(self, pipe: Flux2ImagePipeline, prompt):
# Check if Qwen3 text encoder is available
if pipe.text_encoder_qwen3 is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
prompt_embeds, text_ids = self.encode_prompt(
pipe.text_encoder_qwen3, pipe.tokenizer, prompt,
dtype=pipe.torch_dtype, device=pipe.device,
)
return {"prompt_embeds": prompt_embeds, "text_ids": text_ids}
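Flux2Unit_Qwen3PromptEmbedder stacks three intermediate hidden states (layers 9, 18, 27) per token, so the resulting prompt embedding is 3 x hidden_size wide. A shape-only sketch of that permute/reshape:

```python
# Shape check for stacking three hidden-state layers into one prompt embedding.
import torch

batch, seq_len, hidden = 1, 512, 2560     # hidden_size of the 4B Qwen3 config
layer_outputs = [torch.randn(batch, seq_len, hidden) for _ in range(3)]  # layers 9, 18, 27

out = torch.stack(layer_outputs, dim=1)                                  # (B, 3, L, H)
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch, seq_len, 3 * hidden)
assert prompt_embeds.shape == (1, 512, 7680)
```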
class Flux2Unit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
@@ -318,6 +462,75 @@ class Flux2Unit_InputImageEmbedder(PipelineUnit):
return {"latents": latents, "input_latents": input_latents}
class Flux2Unit_EditImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "edit_image_auto_resize"),
output_params=("edit_latents", "edit_image_ids"),
onload_model_names=("vae",)
)
def calculate_dimensions(self, target_area, ratio):
import math
width = math.sqrt(target_area * ratio)
height = width / ratio
width = round(width / 32) * 32
height = round(height / 32) * 32
return width, height
def crop_and_resize(self, image, target_height, target_width):
width, height = image.size
scale = max(target_width / width, target_height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
return image
def edit_image_auto_resize(self, edit_image):
calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
return self.crop_and_resize(edit_image, calculated_height, calculated_width)
def process_image_ids(self, image_latents, scale=10):
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
t_coords = [t.view(-1) for t in t_coords]
image_latent_ids = []
for x, t in zip(image_latents, t_coords):
x = x.squeeze(0)
_, height, width = x.shape
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
image_latent_ids.append(x_ids)
image_latent_ids = torch.cat(image_latent_ids, dim=0)
image_latent_ids = image_latent_ids.unsqueeze(0)
return image_latent_ids
def process(self, pipe: Flux2ImagePipeline, edit_image, edit_image_auto_resize):
if edit_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
if isinstance(edit_image, Image.Image):
edit_image = [edit_image]
resized_edit_image, edit_latents = [], []
for image in edit_image:
# Preprocess
if edit_image_auto_resize is None or edit_image_auto_resize:
image = self.edit_image_auto_resize(image)
resized_edit_image.append(image)
# Encode
image = pipe.preprocess_image(image)
latents = pipe.vae.encode(image)
edit_latents.append(latents)
edit_image_ids = self.process_image_ids(edit_latents).to(pipe.device)
edit_latents = torch.concat([rearrange(latents, "B C H W -> B (H W) C") for latents in edit_latents], dim=1)
return {"edit_latents": edit_latents, "edit_image_ids": edit_image_ids}
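Flux2Unit_EditImageEmbedder keeps edit images at roughly a one-megapixel budget: calculate_dimensions solves for a width and height at the input aspect ratio whose product is about 1024*1024, rounding each side to a multiple of 32, and crop_and_resize then scales and center-crops to that size. A quick numeric check of the sizing:

```python
# Numeric check of the ~1 MP, multiple-of-32 sizing used for edit images.
import math

def calculate_dimensions(target_area, ratio):
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    return round(width / 32) * 32, round(height / 32) * 32

assert calculate_dimensions(1024 * 1024, 1.0) == (1024, 1024)
assert calculate_dimensions(1024 * 1024, 16 / 9) == (1376, 768)
```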
class Flux2Unit_ImageIDs(PipelineUnit):
def __init__(self):
super().__init__(
@@ -352,10 +565,17 @@ def model_fn_flux2(
prompt_embeds=None,
text_ids=None,
image_ids=None,
edit_latents=None,
edit_image_ids=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
image_seq_len = latents.shape[1]
if edit_latents is not None:
image_seq_len = latents.shape[1]
latents = torch.concat([latents, edit_latents], dim=1)
image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
model_output = dit(
hidden_states=latents,
@@ -367,4 +587,5 @@ def model_fn_flux2(
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
model_output = model_output[:, :image_seq_len]
return model_output return model_output
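For orientation, here is a minimal usage sketch of the edit-image path added above. It assumes `Flux2ImagePipeline` exposes the same `from_pretrained`/call interface as the other pipelines in this change; the import paths and `ModelConfig` entries are illustrative rather than the exact configuration shipped in the examples.

```python
# Hedged sketch, not the shipped example: module paths and ModelConfig entries are assumptions.
import torch
from PIL import Image
from diffsynth import ModelConfig                               # assumed re-export of diffsynth.core.ModelConfig
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline  # module path assumed

pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="black-forest-labs/FLUX.2-dev")],  # illustrative
)
image = pipe(
    prompt="replace the sky with a sunset",
    edit_image=Image.open("input.jpg"),   # a single PIL image or a list of images
    edit_image_auto_resize=True,          # center-crop and resize to roughly a 1024x1024 area, 32-aligned
    seed=0,
)
image.save("edited.jpg")
```

With auto-resize enabled, each reference image is cropped and resized so that width and height are multiples of 32 and the total area stays close to 1024×1024, matching `calculate_dimensions` above.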

View File

@@ -6,6 +6,7 @@ from einops import rearrange, repeat
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -55,7 +56,7 @@ class MultiControlNet(torch.nn.Module):
class FluxImagePipeline(BasePipeline):
    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
@@ -117,7 +118,7 @@ class FluxImagePipeline(BasePipeline):
    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"),
        tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"),
@@ -377,7 +378,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit):
        text_encoder_2,
        prompt,
        positive=True,
        device=get_device_type(),
        t5_sequence_length=512,
    ):
        pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -558,7 +559,7 @@ class FluxImageUnit_EntityControl(PipelineUnit):
        text_encoder_2,
        prompt,
        positive=True,
        device=get_device_type(),
        t5_sequence_length=512,
    ):
        pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -793,7 +794,7 @@ class FluxImageUnit_ValueControl(PipelineUnit):
class InfinitYou(torch.nn.Module):
    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__()
        from facexlib.recognition import init_recognition_model
        from insightface.app import FaceAnalysis

View File

@@ -4,7 +4,9 @@ from typing import Union
from tqdm import tqdm
from einops import rearrange
import numpy as np
from math import prod
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -21,7 +23,7 @@ from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel
class QwenImagePipeline(BasePipeline):
    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
@@ -47,6 +49,7 @@ class QwenImagePipeline(BasePipeline):
            QwenImageUnit_InputImageEmbedder(),
            QwenImageUnit_Inpaint(),
            QwenImageUnit_EditImageEmbedder(),
            QwenImageUnit_LayerInputImageEmbedder(),
            QwenImageUnit_ContextImageEmbedder(),
            QwenImageUnit_PromptEmbedder(),
            QwenImageUnit_EntityControl(),
@@ -58,7 +61,7 @@ class QwenImagePipeline(BasePipeline):
    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
        processor_config: ModelConfig = None,
@@ -125,6 +128,11 @@ class QwenImagePipeline(BasePipeline):
        edit_image: Image.Image = None,
        edit_image_auto_resize: bool = True,
        edit_rope_interpolation: bool = False,
        # Qwen-Image-Edit-2511
        zero_cond_t: bool = False,
        # Qwen-Image-Layered
        layer_input_image: Image.Image = None,
        layer_num: int = None,
        # In-context control
        context_image: Image.Image = None,
        # Tile
@@ -156,6 +164,9 @@ class QwenImagePipeline(BasePipeline):
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation, "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation,
"context_image": context_image, "context_image": context_image,
"zero_cond_t": zero_cond_t,
"layer_input_image": layer_input_image,
"layer_num": layer_num,
} }
for unit in self.units: for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -175,7 +186,10 @@ class QwenImagePipeline(BasePipeline):
        # Decode
        self.load_models_to_device(['vae'])
        image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        if layer_num is None:
            image = self.vae_output_to_image(image)
        else:
            image = [self.vae_output_to_image(i, pattern="C H W") for i in image]
        self.load_models_to_device([])
        return image
@@ -226,12 +240,15 @@ class QwenImageUnit_ShapeChecker(PipelineUnit):
class QwenImageUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device", "layer_num"),
            output_params=("noise",),
        )
    def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device, layer_num):
        if layer_num is None:
            noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
        else:
            noise = pipe.generate_noise((layer_num + 1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
        return {"noise": noise}
@@ -248,8 +265,15 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
        if input_image is None:
            return {"latents": noise, "input_latents": None}
        pipe.load_models_to_device(['vae'])
        if isinstance(input_image, list):
            input_latents = []
            for image in input_image:
                image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
                input_latents.append(pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride))
            input_latents = torch.concat(input_latents, dim=0)
        else:
            image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
            input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        if pipe.scheduler.training:
            return {"latents": noise, "input_latents": input_latents}
        else:
@@ -257,6 +281,22 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
return {"latents": latents, "input_latents": input_latents} return {"latents": latents, "input_latents": input_latents}
class QwenImageUnit_LayerInputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("layer_input_image", "tiled", "tile_size", "tile_stride"),
output_params=("layer_input_latents",),
onload_model_names=("vae",)
)
def process(self, pipe: QwenImagePipeline, layer_input_image, tiled, tile_size, tile_stride):
if layer_input_image is None:
return {}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(layer_input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return {"layer_input_latents": latents}
class QwenImageUnit_Inpaint(PipelineUnit):
    def __init__(self):
@@ -673,18 +713,26 @@ def model_fn_qwen_image(
    entity_prompt_emb_mask=None,
    entity_masks=None,
    edit_latents=None,
    layer_input_latents=None,
    layer_num=None,
    context_latents=None,
    enable_fp8_attention=False,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    edit_rope_interpolation=False,
    zero_cond_t=False,
    **kwargs
):
    if layer_num is None:
        layer_num = 1
        img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)]
    else:
        layer_num = layer_num + 1
        img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)] * layer_num
    txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
    timestep = timestep / 1000
    image = rearrange(latents, "(B N) C (H P) (W Q) -> B (N H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2, N=layer_num)
    image_seq_len = image.shape[1]
    if context_latents is not None:
@@ -696,9 +744,27 @@ def model_fn_qwen_image(
        img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list]
        edit_image = [rearrange(e, "B C (H P) (W Q) -> B (H W) (C P Q)", H=e.shape[2]//2, W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list]
        image = torch.cat([image] + edit_image, dim=1)
    if layer_input_latents is not None:
        layer_num = layer_num + 1
        img_shapes += [(layer_input_latents.shape[0], layer_input_latents.shape[2]//2, layer_input_latents.shape[3]//2)]
        layer_input_latents = rearrange(layer_input_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
        image = torch.cat([image, layer_input_latents], dim=1)
    image = dit.img_in(image)
    if zero_cond_t:
        timestep = torch.cat([timestep, timestep * 0], dim=0)
        modulate_index = torch.tensor(
            [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in [img_shapes]],
            device=timestep.device,
            dtype=torch.int,
        )
    else:
        modulate_index = None
    conditioning = dit.time_text_embed(
        timestep,
        image.dtype,
        addition_t_cond=None if not dit.time_text_embed.use_additional_t_cond else torch.tensor([0]).to(device=image.device, dtype=torch.long)
    )
    if entity_prompt_emb is not None:
        text, image_rotary_emb, attention_mask = dit.process_entity_masks(
@@ -728,6 +794,7 @@ def model_fn_qwen_image(
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
            enable_fp8_attention=enable_fp8_attention,
            modulate_index=modulate_index,
        )
        if blockwise_controlnet_conditioning is not None:
            image_slice = image[:, :image_seq_len].clone()
@@ -738,9 +805,11 @@ def model_fn_qwen_image(
            )
            image[:, :image_seq_len] = image_slice + controlnet_output
    if zero_cond_t:
        conditioning = conditioning.chunk(2, dim=0)[0]
    image = dit.norm_out(image, conditioning)
    image = dit.proj_out(image)
    image = image[:, :image_seq_len]
    latents = rearrange(image, "B (N H W) (C P Q) -> (B N) C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2, B=1)
    return latents
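A minimal sketch of how the new layered path might be driven, assuming the public `QwenImagePipeline` call mirrors the parameters added above; the `ModelConfig` entry and module path are illustrative (the shipped example is examples/qwen_image/model_inference/Qwen-Image-Layered.py).

```python
# Hedged sketch: ModelConfig entries and import paths are assumptions, not the shipped example.
import torch
from diffsynth import ModelConfig                              # assumed re-export of diffsynth.core.ModelConfig
from diffsynth.pipelines.qwen_image import QwenImagePipeline   # module path assumed

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="Qwen/Qwen-Image-Layered")],  # illustrative
)
layers = pipe(
    prompt="a product poster decomposed into background, subject, and text layers",
    layer_num=4,   # the noise initializer samples layer_num + 1 latents
    seed=0,
)
# With layer_num set, decoding returns a list of PIL images, one per decoded layer.
for i, layer in enumerate(layers):
    layer.save(f"layer_{i}.png")
```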

View File

@@ -11,6 +11,7 @@ from typing import Optional
from typing_extensions import Literal
from transformers import Wav2Vec2Processor
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
@@ -30,7 +31,7 @@ from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
class WanVideoPipeline(BasePipeline):
    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
@@ -98,7 +99,7 @@ class WanVideoPipeline(BasePipeline):
    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
        audio_processor_config: ModelConfig = None,
@@ -122,11 +123,15 @@ class WanVideoPipeline(BasePipeline):
                model_config.model_id = redirect_dict[model_config.origin_file_pattern][0]
                model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1]
        if use_usp:
            from ..utils.xfuser import initialize_usp
            initialize_usp(device)
            import torch.distributed as dist
            from ..core.device.npu_compatible_device import get_device_name
            if dist.is_available() and dist.is_initialized():
                device = get_device_name()
        # Initialize pipeline
        pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)
        # Fetch models
@@ -241,6 +246,7 @@ class WanVideoPipeline(BasePipeline):
        tea_cache_model_id: Optional[str] = "",
        # progress_bar
        progress_bar_cmd=tqdm,
        output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
    ):
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
@@ -320,9 +326,11 @@ class WanVideoPipeline(BasePipeline):
        # Decode
        self.load_models_to_device(['vae'])
        video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        if output_type == "quantized":
            video = self.vae_output_to_video(video)
        elif output_type == "floatpoint":
            pass
        self.load_models_to_device([])
        return video
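The new `output_type` flag only changes post-processing after VAE decoding: `"quantized"` keeps the existing behaviour, while `"floatpoint"` skips `vae_output_to_video()` and returns the decoded tensor unchanged. A hedged sketch (model configuration and import path are illustrative):

```python
# Hedged sketch: ModelConfig entry and import path are assumptions.
import torch
from diffsynth import ModelConfig                             # assumed re-export of diffsynth.core.ModelConfig
from diffsynth.pipelines.wan_video import WanVideoPipeline    # module path assumed

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B")],  # illustrative
)
video = pipe(prompt="a cat surfing", output_type="floatpoint")
# `video` is the raw decoded tensor; pass output_type="quantized" (the default)
# to get the usual post-processed frames instead.
```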
@@ -823,9 +831,9 @@ class WanVideoUnit_S2V(PipelineUnit):
        pipe.load_models_to_device(["vae"])
        motion_frames = 73
        kwargs = {}
        if motion_video is not None:
            assert motion_video.shape[2] == motion_frames, f"motion video must have {motion_frames} frames, but got {motion_video.shape[2]}"
            motion_latents = motion_video
            kwargs["drop_motion_frames"] = False
        else:
            motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
@@ -957,7 +965,7 @@ class WanVideoUnit_AnimateInpaint(PipelineUnit):
            onload_model_names=("vae",)
        )
    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()):
        if mask_pixel_values is None:
            msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
        else:

View File

@@ -4,21 +4,29 @@ from typing import Union
from tqdm import tqdm
from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple, Iterable, Dict
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.data.operators import ImageCropAndResize
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from ..utils.lora import merge_lora
from transformers import AutoTokenizer
from ..models.z_image_text_encoder import ZImageTextEncoder
from ..models.z_image_dit import ZImageDiT
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M
from ..models.z_image_controlnet import ZImageControlNet
from ..models.siglip2_image_encoder import Siglip2ImageEncoder
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
from ..models.z_image_image2lora import ZImageImage2LoRAModel
class ZImagePipeline(BasePipeline):
    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
@@ -28,13 +36,22 @@ class ZImagePipeline(BasePipeline):
        self.dit: ZImageDiT = None
        self.vae_encoder: FluxVAEEncoder = None
        self.vae_decoder: FluxVAEDecoder = None
        self.image_encoder: Siglip2ImageEncoder428M = None
        self.controlnet: ZImageControlNet = None
        self.siglip2_image_encoder: Siglip2ImageEncoder = None
        self.dinov3_image_encoder: DINOv3ImageEncoder = None
        self.image2lora_style: ZImageImage2LoRAModel = None
        self.tokenizer: AutoTokenizer = None
        self.in_iteration_models = ("dit", "controlnet")
        self.units = [
            ZImageUnit_ShapeChecker(),
            ZImageUnit_PromptEmbedder(),
            ZImageUnit_NoiseInitializer(),
            ZImageUnit_InputImageEmbedder(),
            ZImageUnit_EditImageAutoResize(),
            ZImageUnit_EditImageEmbedderVAE(),
            ZImageUnit_EditImageEmbedderSiglip(),
            ZImageUnit_PAIControlNet(),
        ]
        self.model_fn = model_fn_z_image
@@ -42,7 +59,7 @@ class ZImagePipeline(BasePipeline):
    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
        vram_limit: float = None,
@@ -56,6 +73,11 @@ class ZImagePipeline(BasePipeline):
        pipe.dit = model_pool.fetch_model("z_image_dit")
        pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
        pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
        pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m")
        pipe.controlnet = model_pool.fetch_model("z_image_controlnet")
        pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder")
        pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder")
        pipe.image2lora_style = model_pool.fetch_model("z_image_image2lora_style")
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
@@ -75,6 +97,9 @@ class ZImagePipeline(BasePipeline):
        # Image
        input_image: Image.Image = None,
        denoising_strength: float = 1.0,
        # Edit
        edit_image: Image.Image = None,
        edit_image_auto_resize: bool = True,
        # Shape
        height: int = 1024,
        width: int = 1024,
@@ -83,11 +108,17 @@ class ZImagePipeline(BasePipeline):
        rand_device: str = "cpu",
        # Steps
        num_inference_steps: int = 8,
        sigma_shift: float = None,
        # ControlNet
        controlnet_inputs: List[ControlNetInput] = None,
        # Image to LoRA
        image2lora_images: List[Image.Image] = None,
        positive_only_lora: Dict[str, torch.Tensor] = None,
        # Progress bar
        progress_bar_cmd = tqdm,
    ):
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
        # Parameters
        inputs_posi = {
@@ -102,6 +133,9 @@ class ZImagePipeline(BasePipeline):
"height": height, "width": width, "height": height, "width": width,
"seed": seed, "rand_device": rand_device, "seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps, "num_inference_steps": num_inference_steps,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
"controlnet_inputs": controlnet_inputs,
"image2lora_images": image2lora_images, "positive_only_lora": positive_only_lora,
} }
for unit in self.units: for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -143,12 +177,13 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params=("edit_image",),
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("prompt_embeds",),
            onload_model_names=("text_encoder",)
        )
    def encode_prompt(
        self,
        pipe,
@@ -194,10 +229,81 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
            embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
        return embeddings_list
def encode_prompt_omni(
self,
pipe,
prompt: Union[str, List[str]],
edit_image=None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
if isinstance(prompt, str):
prompt = [prompt]
        if edit_image is None:
num_condition_images = 0
elif isinstance(edit_image, list):
num_condition_images = len(edit_image)
else:
num_condition_images = 1
for i, prompt_item in enumerate(prompt):
if num_condition_images == 0:
prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
elif num_condition_images > 0:
prompt_list = ["<|im_start|>user\n<|vision_start|>"]
prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1)
prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"]
prompt_list += ["<|vision_end|><|im_end|>"]
prompt[i] = prompt_list
flattened_prompt = []
prompt_list_lengths = []
for i in range(len(prompt)):
prompt_list_lengths.append(len(prompt[i]))
flattened_prompt.extend(prompt[i])
text_inputs = pipe.tokenizer(
flattened_prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_masks = text_inputs.attention_mask.to(device).bool()
prompt_embeds = pipe.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-2]
embeddings_list = []
start_idx = 0
for i in range(len(prompt_list_lengths)):
batch_embeddings = []
end_idx = start_idx + prompt_list_lengths[i]
for j in range(start_idx, end_idx):
batch_embeddings.append(prompt_embeds[j][prompt_masks[j]])
embeddings_list.append(batch_embeddings)
start_idx = end_idx
return embeddings_list
def process(self, pipe: ZImagePipeline, prompt, edit_image):
        pipe.load_models_to_device(self.onload_model_names)
        if hasattr(pipe, "dit") and pipe.dit.siglip_embedder is not None:
            # Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods.
            # We determine which encoding method to use based on the model architecture.
            # If you are using two-stage split training,
            # please use `--offload_models` instead of skipping the DiT model loading.
            prompt_embeds = self.encode_prompt_omni(pipe, prompt, edit_image, pipe.device)
        else:
            prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
        return {"prompt_embeds": prompt_embeds}
@@ -234,24 +340,330 @@ class ZImageUnit_InputImageEmbedder(PipelineUnit):
return {"latents": latents, "input_latents": input_latents} return {"latents": latents, "input_latents": input_latents}
class ZImageUnit_EditImageAutoResize(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "edit_image_auto_resize"),
output_params=("edit_image",),
)
def process(self, pipe: ZImagePipeline, edit_image, edit_image_auto_resize):
if edit_image is None:
return {}
if edit_image_auto_resize is None or not edit_image_auto_resize:
return {}
operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)
if not isinstance(edit_image, list):
edit_image = [edit_image]
edit_image = [operator(i) for i in edit_image]
return {"edit_image": edit_image}
class ZImageUnit_EditImageEmbedderSiglip(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image",),
output_params=("image_embeds",),
onload_model_names=("image_encoder",)
)
def process(self, pipe: ZImagePipeline, edit_image):
if edit_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
if not isinstance(edit_image, list):
edit_image = [edit_image]
image_emb = []
for image_ in edit_image:
image_emb.append(pipe.image_encoder(image_, device=pipe.device))
return {"image_embeds": image_emb}
class ZImageUnit_EditImageEmbedderVAE(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image",),
output_params=("image_latents",),
onload_model_names=("vae_encoder",)
)
def process(self, pipe: ZImagePipeline, edit_image):
if edit_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
if not isinstance(edit_image, list):
edit_image = [edit_image]
image_latents = []
for image_ in edit_image:
image_ = pipe.preprocess_image(image_)
image_latents.append(pipe.vae_encoder(image_))
return {"image_latents": image_latents}
class ZImageUnit_PAIControlNet(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("controlnet_inputs", "height", "width"),
output_params=("control_context", "control_scale"),
onload_model_names=("vae_encoder",)
)
def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width):
if controlnet_inputs is None:
return {}
if len(controlnet_inputs) != 1:
print("Z-Image ControlNet doesn't support multi-ControlNet. Only one image will be used.")
controlnet_input = controlnet_inputs[0]
pipe.load_models_to_device(self.onload_model_names)
control_image = controlnet_input.image
if control_image is not None:
control_image = pipe.preprocess_image(control_image)
control_latents = pipe.vae_encoder(control_image)
else:
control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1
inpaint_mask = controlnet_input.inpaint_mask
if inpaint_mask is not None:
inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1)
inpaint_image = controlnet_input.inpaint_image
inpaint_image = pipe.preprocess_image(inpaint_image)
inpaint_image = inpaint_image * (inpaint_mask < 0.5)
inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1]
else:
inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device)
inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device)
inpaint_latent = pipe.vae_encoder(inpaint_image)
control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1)
control_context = rearrange(control_context, "B C H W -> B C 1 H W")
return {"control_context": control_context, "control_scale": controlnet_input.scale}
def model_fn_z_image(
    dit: ZImageDiT,
    controlnet: ZImageControlNet = None,
    latents=None,
    timestep=None,
    prompt_embeds=None,
    image_embeds=None,
    image_latents=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
# Due to the complex and verbose codebase of Z-Image,
# we are temporarily using this inelegant structure.
# We will refactor this part in the future (if time permits).
if dit.siglip_embedder is None:
return model_fn_z_image_turbo(
dit,
controlnet=controlnet,
latents=latents,
timestep=timestep,
prompt_embeds=prompt_embeds,
image_embeds=image_embeds,
image_latents=image_latents,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
**kwargs,
)
    latents = [rearrange(latents, "B C H W -> C B H W")]
if dit.siglip_embedder is not None:
if image_latents is not None:
image_latents = [rearrange(image_latent, "B C H W -> C B H W") for image_latent in image_latents]
latents = [image_latents + latents]
image_noise_mask = [[0] * len(image_latents) + [1]]
else:
latents = [latents]
image_noise_mask = [[1]]
image_embeds = [image_embeds]
else:
image_noise_mask = None
    timestep = (1000 - timestep) / 1000
    model_output = dit(
        latents,
        timestep,
        prompt_embeds,
        siglip_feats=image_embeds,
        image_noise_mask=image_noise_mask,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )[0]
    model_output = -model_output
    model_output = rearrange(model_output, "C B H W -> B C H W")
    return model_output
class ZImageUnit_Image2LoRAEncode(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("image2lora_images",),
output_params=("image2lora_x",),
onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder",),
)
from ..core.data.operators import ImageCropAndResize
self.processor_highres = ImageCropAndResize(height=1024, width=1024)
def encode_images_using_siglip2(self, pipe: ZImagePipeline, images: list[Image.Image]):
pipe.load_models_to_device(["siglip2_image_encoder"])
embs = []
for image in images:
image = self.processor_highres(image)
embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype))
embs = torch.stack(embs)
return embs
def encode_images_using_dinov3(self, pipe: ZImagePipeline, images: list[Image.Image]):
pipe.load_models_to_device(["dinov3_image_encoder"])
embs = []
for image in images:
image = self.processor_highres(image)
embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype))
embs = torch.stack(embs)
return embs
def encode_images(self, pipe: ZImagePipeline, images: list[Image.Image]):
if images is None:
return {}
if not isinstance(images, list):
images = [images]
embs_siglip2 = self.encode_images_using_siglip2(pipe, images)
embs_dinov3 = self.encode_images_using_dinov3(pipe, images)
x = torch.concat([embs_siglip2, embs_dinov3], dim=-1)
return x
def process(self, pipe: ZImagePipeline, image2lora_images):
if image2lora_images is None:
return {}
x = self.encode_images(pipe, image2lora_images)
return {"image2lora_x": x}
class ZImageUnit_Image2LoRADecode(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("image2lora_x",),
output_params=("lora",),
onload_model_names=("image2lora_style",),
)
def process(self, pipe: ZImagePipeline, image2lora_x):
if image2lora_x is None:
return {}
loras = []
if pipe.image2lora_style is not None:
pipe.load_models_to_device(["image2lora_style"])
for x in image2lora_x:
loras.append(pipe.image2lora_style(x=x, residual=None))
lora = merge_lora(loras, alpha=1 / len(image2lora_x))
return {"lora": lora}
def model_fn_z_image_turbo(
dit: ZImageDiT,
controlnet: ZImageControlNet = None,
latents=None,
timestep=None,
prompt_embeds=None,
image_embeds=None,
image_latents=None,
control_context=None,
control_scale=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
while isinstance(prompt_embeds, list):
prompt_embeds = prompt_embeds[0]
while isinstance(latents, list):
latents = latents[0]
while isinstance(image_embeds, list):
image_embeds = image_embeds[0]
# Timestep
timestep = 1000 - timestep
t_noisy = dit.t_embedder(timestep)
t_clean = dit.t_embedder(torch.ones_like(timestep) * 1000)
# Patchify
latents = rearrange(latents, "B C H W -> C B H W")
x, cap_feats, patch_metadata = dit.patchify_and_embed([latents], [prompt_embeds])
x = x[0]
cap_feats = cap_feats[0]
# Noise refine
x = dit.all_x_embedder["2-1"](x)
x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device)
x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0))
x = rearrange(x, "L C -> 1 L C")
x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C")
if control_context is not None:
kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)
refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(
dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1,
use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
for layer_id, layer in enumerate(dit.noise_refiner):
x = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=x,
attn_mask=None,
freqs_cis=x_freqs_cis,
adaln_input=t_noisy,
)
if control_context is not None:
x = x + refiner_hints[layer_id] * control_scale
# Prompt refine
cap_feats = dit.cap_embedder(cap_feats)
cap_feats[torch.cat(patch_metadata.get("cap_pad_mask"))] = dit.cap_pad_token.to(dtype=x.dtype, device=x.device)
cap_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("cap_pos_ids"), dim=0))
cap_feats = rearrange(cap_feats, "L C -> 1 L C")
cap_freqs_cis = rearrange(cap_freqs_cis, "L C -> 1 L C")
for layer in dit.context_refiner:
cap_feats = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=cap_feats,
attn_mask=None,
freqs_cis=cap_freqs_cis,
)
# Unified
unified = torch.cat([x, cap_feats], dim=1)
unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1)
if control_context is not None:
kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)
hints = controlnet.forward_layers(
unified, cap_feats, control_context, control_context_item_seqlens, kwargs,
use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
for layer_id, layer in enumerate(dit.layers):
unified = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=unified,
attn_mask=None,
freqs_cis=unified_freqs_cis,
adaln_input=t_noisy,
)
if control_context is not None:
if layer_id in controlnet.control_layers_mapping:
unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale
# Output
unified = dit.all_final_layer["2-1"](unified, t_noisy)
x = dit.unpatchify([unified[0]], patch_metadata.get("x_size"))[0]
x = rearrange(x, "C B H W -> B C H W")
x = -x
return x
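A minimal sketch of how the new Z-Image inputs might be used. `ControlNetInput` comes from the base pipeline module extended elsewhere in this change; the `ModelConfig` entries and import paths are illustrative, and the edit-image path additionally requires a checkpoint whose DiT has a `siglip_embedder` (the Omni/base variants rather than Z-Image-Turbo).

```python
# Hedged sketch: ModelConfig entries and module paths are assumptions, not the shipped examples.
import torch
from PIL import Image
from diffsynth import ModelConfig                               # assumed re-export of diffsynth.core.ModelConfig
from diffsynth.pipelines.z_image import ZImagePipeline          # module path assumed
from diffsynth.diffusion.base_pipeline import ControlNetInput   # module path assumed

pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo"),  # illustrative; add a ZImageControlNet entry as needed
    ],
)
image = pipe(
    prompt="a cat in a garden, watercolor",
    controlnet_inputs=[ControlNetInput(image=Image.open("canny.png"), scale=0.8)],  # only one input is used
    num_inference_steps=8,
    seed=0,
)
image.save("controlled.jpg")
```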

View File

@@ -1,12 +1,13 @@
from typing_extensions import Literal, TypeAlias
from diffsynth.core.device.npu_compatible_device import get_device_type
Processor_id: TypeAlias = Literal[
    "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
]
class Annotator:
    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device=get_device_type(), skip_processor=False):
        if not skip_processor:
            if processor_id == "canny":
                from controlnet_aux.processor import CannyDetector

View File

@@ -9,5 +9,6 @@ class ControlNetInput:
    start: float = 1.0
    end: float = 0.0
    image: Image.Image = None
    inpaint_image: Image.Image = None
    inpaint_mask: Image.Image = None
    processor_id: str = None

View File

@@ -149,6 +149,8 @@ class FluxLoRALoader(GeneralLoRALoader):
                                  dtype=state_dict_[name].dtype)
            else:
                state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
            mlp = mlp.to(device=state_dict_[name].device)
            if 'lora_A' in name:
                param = torch.concat([
                    state_dict_.pop(name),

View File

@@ -89,4 +89,109 @@ def FluxDiTStateDictConverter(state_dict):
            state_dict_[rename] = state_dict[original_name]
        else:
            pass
return state_dict_
def FluxDiTStateDictConverterFromDiffusers(state_dict):
global_rename_dict = {
"context_embedder": "context_embedder",
"x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "final_norm_out.linear",
"proj_out": "final_proj_out",
}
rename_dict = {
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
}
rename_dict_single = {
"attn.to_q": "a_to_q",
"attn.to_k": "a_to_k",
"attn.to_v": "a_to_v",
"attn.norm_q": "norm_q_a",
"attn.norm_k": "norm_k_a",
"norm.linear": "norm.linear",
"proj_mlp": "proj_in_besides_attn",
"proj_out": "proj_out",
}
state_dict_ = {}
for name in state_dict:
param = state_dict[name]
if name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in global_rename_dict:
if global_rename_dict[prefix] == "final_norm_out.linear":
param = torch.concat([param[3072:], param[:3072]], dim=0)
state_dict_[global_rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
elif prefix.startswith("single_transformer_blocks."):
names = prefix.split(".")
names[0] = "single_blocks"
middle = ".".join(names[2:])
if middle in rename_dict_single:
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
state_dict_[name_] = param
else:
pass
else:
pass
for name in list(state_dict_.keys()):
if "single_blocks." in name and ".a_to_q." in name:
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
if mlp is None:
mlp = torch.zeros(4 * state_dict_[name].shape[0],
*state_dict_[name].shape[1:],
dtype=state_dict_[name].dtype)
else:
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
param = torch.concat([
state_dict_.pop(name),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
mlp,
], dim=0)
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
state_dict_[name_] = param
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
    return state_dict_
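For context, the converter above renames Diffusers-format FLUX DiT keys to this repo's layout and fuses the separate q/k/v (and single-block MLP) projections into the packed `to_qkv` / `to_qkv_mlp` weights. A hedged usage sketch; the converter's module path and the checkpoint filename are assumptions:

```python
# Hedged sketch: the converter's module path and the checkpoint filename are assumptions.
from safetensors.torch import load_file
from diffsynth.models.flux_dit import FluxDiTStateDictConverterFromDiffusers  # module path assumed

state_dict = load_file("flux_transformer/diffusion_pytorch_model.safetensors")  # illustrative path
state_dict = FluxDiTStateDictConverterFromDiffusers(state_dict)
# Keys now use the repo's blocks./single_blocks. naming with fused a_to_qkv, b_to_qkv,
# and to_qkv_mlp tensors, ready to load into the repo's FLUX DiT implementation.
```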

View File

@@ -0,0 +1,6 @@
def ZImageTextEncoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name != "lm_head.weight":
state_dict_[name] = state_dict[name]
return state_dict_

View File

@@ -50,7 +50,7 @@ def rope_apply(x, freqs, num_heads):
        sp_rank = get_sequence_parallel_rank()
        freqs = pad_freqs(freqs, s_per_rank * sp_size)
        freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
        freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device == "npu" else freqs_rank
        x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
    return x_out.to(x.dtype)

View File

@@ -2,6 +2,15 @@
FLUX.2 is an image generation model trained and open-sourced by Black Forest Labs.
## Model Lineage
```mermaid
graph LR;
FLUX.2-Series-->black-forest-labs/FLUX.2-dev;
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-4B;
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-9B;
```
## Installation

Before using this project for model inference and training, please install DiffSynth-Studio first.
@@ -50,16 +59,20 @@ image.save("image.jpg")
## Model Overview

| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|

Special Training Scripts:
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md)
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md)
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md)
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md)

## Model Inference
@@ -135,4 +148,4 @@ We have built a sample image dataset for your testing. You can download this dat
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
```
We provide recommended training scripts for each model; see the table in the "Model Overview" section above. For how to write model training scripts, see [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Detailed Explanation](/docs/Training/).

View File

@@ -81,8 +81,12 @@ graph LR;
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |

View File

@@ -50,9 +50,14 @@ image.save("image.jpg")
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
Special Training Scripts:
@@ -75,6 +80,9 @@ Input parameters for `ZImagePipeline` inference include:
* `seed`: Random seed. Default is `None`, meaning completely random.
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results.
* `num_inference_steps`: Number of inference steps, default value is 8.
* `controlnet_inputs`: Inputs for ControlNet models.
* `edit_image`: Edit images for image editing models, supporting multiple images.
* `positive_only_lora`: LoRA weights used only in positive prompts.
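To make the parameter list above concrete, below is a minimal, hypothetical sketch of a `ZImagePipeline` call. The `diffsynth.pipelines.z_image` import path and the `ModelConfig` file patterns are assumptions modeled on the other pipelines in this repository; the linked example scripts under `examples/z_image/model_inference/` are the authoritative reference.
```python
# Hypothetical sketch only: the module path and origin_file_pattern values are
# assumptions; see examples/z_image/model_inference/ for the tested scripts.
import torch
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig  # assumed import path

pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/*.safetensors"),
    ],
)

# seed fixes the initial noise; rand_device="cpu" keeps results reproducible across GPUs;
# num_inference_steps defaults to 8 for the Turbo model.
image = pipe(
    prompt="a cat sitting on a windowsill, watercolor style",
    seed=0,
    rand_device="cpu",
    num_inference_steps=8,
)
image.save("z_image_turbo_example.jpg")
```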
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.

View File

@@ -13,7 +13,7 @@ All sample code provided by this project supports NVIDIA GPUs by default, requir
AMD provides PyTorch packages based on ROCm, so most models can run without code changes. A small number of models may not be compatible due to their reliance on CUDA-specific instructions.
## Ascend NPU
### Inference
When using Ascend NPU, you need to replace `"cuda"` with `"npu"` in your code.
For example, here is the inference code for **Wan2.1-T2V-1.3B**, modified for Ascend NPU:
@@ -22,6 +22,7 @@ For example, here is the inference code for **Wan2.1-T2V-1.3B**, modified for As
import torch
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from diffsynth.core.device.npu_compatible_device import get_device_name
vram_config = {
"offload_dtype": "disk",
@@ -46,7 +47,7 @@ pipe = WanVideoPipeline.from_pretrained(
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
+ vram_limit=torch.npu.mem_get_info("npu")[1] / (1024 ** 3) - 2, + vram_limit=torch.npu.mem_get_info(get_device_name())[1] / (1024 ** 3) - 2,
)
video = pipe(
@@ -56,3 +57,28 @@ video = pipe(
)
save_video(video, "video.mp4", fps=15, quality=5)
```
### Training
NPU startup script samples have been added for each type of model. The scripts are stored under `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`.
The NPU training scripts set NPU-specific environment variables that improve performance and enable additional parameters for specific models; a combined launch sketch follows the parameter table below.
#### Environment variables
```shell
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
```
`expandable_segments:<value>`: enables the memory pool's expandable-segments feature, i.e. the virtual memory mechanism.
```shell
export CPU_AFFINITY_CONF=1
```
* `0` or unset: CPU core binding is disabled
* `1`: enables coarse-grained core binding
* `2`: enables fine-grained core binding
#### Parameters for specific models
| Model | Parameter | Note |
|----------------|---------------------------|-------------------|
| Wan 14B series | --initialize_model_on_cpu | The 14B model needs to be initialized on the CPU |
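As a combined illustration of the environment variables and the model-specific flag above, a Wan 14B NPU training launch would roughly look like the sketch below. The script path and training arguments are placeholders rather than a tested configuration; refer to the scripts under `examples/wanvideo/model_training/special/npu_training/` for the exact versions.
```shell
# Sketch only: dataset paths and training arguments are placeholders; see
# examples/wanvideo/model_training/special/npu_training/ for the tested scripts.
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True  # expandable memory-pool segments (virtual memory)
export CPU_AFFINITY_CONF=1                              # coarse-grained CPU core binding

accelerate launch examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --learning_rate 1e-5 \
  --num_epochs 1 \
  --trainable_models "dit" \
  --use_gradient_checkpointing \
  --initialize_model_on_cpu  # required for the Wan 14B series
```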

View File

@@ -30,11 +30,16 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6
* **Ascend NPU**
- Ascend NPU support is provided via the `torch-npu` package. Taking version `2.1.0.post17` (as of the article update date: December 15, 2025) as an example, run the following command:
- ```shell
- pip install torch-npu==2.1.0.post17
- ```
+ 1. Install [CANN](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/softwareinst/instg/instg_quick.html?Mode=PmIns&InstallType=local&OS=openEuler&Software=cannToolKit) through official documentation.
+ 2. Install from source
+ ```shell
+ git clone https://github.com/modelscope/DiffSynth-Studio.git
+ cd DiffSynth-Studio
+ # aarch64/ARM
+ pip install -e .[npu_aarch64] --extra-index-url "https://download.pytorch.org/whl/cpu"
+ # x86
+ pip install -e .[npu]
When using Ascend NPU, please replace `"cuda"` with `"npu"` in your Python code. For details, see [NPU Support](/docs/en/Pipeline_Usage/GPU_support.md#ascend-npu).

View File

@@ -2,6 +2,15 @@
FLUX.2 is an image generation model trained and open-sourced by Black Forest Labs.
## Model Lineage
```mermaid
graph LR;
FLUX.2-Series-->black-forest-labs/FLUX.2-dev;
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-4B;
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-9B;
```
## Installation
Before using this project for model inference and training, please install DiffSynth-Studio first.
@@ -50,16 +59,20 @@ image.save("image.jpg")
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
Special Training Scripts:
- * Differential LoRA Training: [doc](/docs/zh/Training/Differential_LoRA.md), [code](/examples/flux/model_training/special/differential_training/)
- * FP8 Precision Training: [doc](/docs/zh/Training/FP8_Precision.md), [code](/examples/flux/model_training/special/fp8_training/)
- * Two-stage Split Training: [doc](/docs/zh/Training/Split_Training.md), [code](/examples/flux/model_training/special/split_training/)
- * End-to-end Direct Distillation: [doc](/docs/zh/Training/Direct_Distill.md), [code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh)
+ * Differential LoRA Training: [doc](/docs/zh/Training/Differential_LoRA.md)
+ * FP8 Precision Training: [doc](/docs/zh/Training/FP8_Precision.md)
+ * Two-stage Split Training: [doc](/docs/zh/Training/Split_Training.md)
+ * End-to-end Direct Distillation: [doc](/docs/zh/Training/Direct_Distill.md)
## Model Inference
@@ -135,4 +148,4 @@ FLUX.2 series models are all trained through [`examples/flux2/model_training/train.py`](/exam
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
```
We have written recommended training scripts for each model; please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/zh/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).

View File

@@ -81,8 +81,12 @@ graph LR;
|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|

View File

@@ -52,7 +52,12 @@ image.save("image.jpg")
|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
Special Training Scripts:
@@ -75,6 +80,9 @@ image.save("image.jpg")
* `seed`: Random seed. Default is `None`, meaning completely random.
* `rand_device`: Computing device for generating the random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results.
* `num_inference_steps`: Number of inference steps, default value is 8.
* `controlnet_inputs`: Inputs for ControlNet models.
* `edit_image`: Images to be edited for image editing models; multiple images are supported.
* `positive_only_lora`: LoRA weights used only with the positive prompt.
If VRAM is insufficient, please enable [VRAM Management](/docs/zh/Pipeline_Usage/VRAM_management.md). We provide recommended low-VRAM configurations for each model in the example code; see the table in the "Model Overview" section above.

View File

@@ -13,7 +13,7 @@
AMD provides ROCm-based torch packages, so most models can run without code changes; a few models cannot run because they rely on CUDA-specific instructions.
## Ascend NPU
### Inference
When using Ascend NPU, replace `"cuda"` in your code with `"npu"`.
For example, the inference code for Wan2.1-T2V-1.3B:
@@ -22,6 +22,7 @@ AMD provides ROCm-based torch packages, so most models can run without code changes
import torch
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from diffsynth.core.device.npu_compatible_device import get_device_name
vram_config = {
"offload_dtype": "disk",
@@ -33,7 +34,7 @@ vram_config = {
+ "preparing_device": "npu",
"computation_dtype": torch.bfloat16,
- "computation_device": "cuda",
+ "preparing_device": "npu", + "computation_device": "npu",
}
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
@@ -46,7 +47,7 @@ pipe = WanVideoPipeline.from_pretrained(
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
+ vram_limit=torch.npu.mem_get_info("npu")[1] / (1024 ** 3) - 2, + vram_limit=torch.npu.mem_get_info(get_device_name())[1] / (1024 ** 3) - 2,
)
video = pipe(
@@ -56,3 +57,28 @@ video = pipe(
)
save_video(video, "video.mp4", fps=15, quality=5)
```
### Training
NPU startup script samples have been added for each type of model. The scripts are stored under the `examples/xxx/special/npu_training` directory, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`.
The NPU training scripts set NPU-specific environment variables that improve performance and enable additional parameters for specific models.
#### Environment Variables
```shell
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
```
`expandable_segments:<value>`: enables the memory pool's expandable-segments feature, i.e. the virtual memory mechanism.
```shell
export CPU_AFFINITY_CONF=1
```
* `0` or unset: CPU core binding is disabled
* `1`: enables coarse-grained core binding
* `2`: enables fine-grained core binding
#### Parameters required for specific models
| Model | Parameter | Note |
|----------------|---------------------------|-------------------|
| Wan 14B series | --initialize_model_on_cpu | The 14B model needs to be initialized on the CPU |

View File

@@ -30,11 +30,16 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6
* Ascend NPU
- Ascend NPU support is provided via the `torch-npu` package. Taking version `2.1.0.post17` as an example (this article was updated on December 15, 2025), run the following command:
- ```shell
- pip install torch-npu==2.1.0.post17
- ```
+ 1. Install [CANN](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/softwareinst/instg/instg_quick.html?Mode=PmIns&InstallType=local&OS=openEuler&Software=cannToolKit) following the official documentation.
+ 2. Install from source
+ ```shell
+ git clone https://github.com/modelscope/DiffSynth-Studio.git
+ cd DiffSynth-Studio
+ # aarch64/ARM
+ pip install -e .[npu_aarch64] --extra-index-url "https://download.pytorch.org/whl/cpu"
+ # x86
+ pip install -e .[npu]
When using Ascend NPU, please replace `"cuda"` in your Python code with `"npu"`. For details, see [NPU Support](/docs/zh/Pipeline_Usage/GPU_support.md#ascend-npu).

View File

@@ -108,7 +108,14 @@ def test_flux():
run_inference("examples/flux/model_training/validate_lora")
def test_z_image():
run_inference("examples/z_image/model_inference")
run_inference("examples/z_image/model_inference_low_vram")
run_train_multi_GPU("examples/z_image/model_training/full")
run_inference("examples/z_image/model_training/validate_full")
run_train_single_GPU("examples/z_image/model_training/lora")
run_inference("examples/z_image/model_training/validate_lora")
if __name__ == "__main__":
test_qwen_image() test_z_image()
test_flux()
test_wan()

View File

@@ -0,0 +1,17 @@
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1
accelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_kontext.csv \
--data_file_keys "image,kontext_images" \
--max_pixels 1048576 \
--dataset_repeat 400 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \
--learning_rate 1e-5 \
--num_epochs 1 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.1-Kontext-dev_full" \
--trainable_models "dit" \
--extra_inputs "kontext_images" \
--use_gradient_checkpointing

View File

@@ -0,0 +1,15 @@
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1
accelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 400 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \
--learning_rate 1e-5 \
--num_epochs 1 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.1-dev_full" \
--trainable_models "dit" \
--use_gradient_checkpointing

View File

@@ -0,0 +1,21 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
image.save("image_FLUX.2-klein-4B.jpg")
prompt = "change the color of the clothes to red"
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
image.save("image_edit_FLUX.2-klein-4B.jpg")

View File

@@ -0,0 +1,21 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
)
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
image.save("image_FLUX.2-klein-9B.jpg")
prompt = "change the color of the clothes to red"
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
image.save("image_edit_FLUX.2-klein-9B.jpg")

View File

@@ -0,0 +1,21 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
image.save("image_FLUX.2-klein-base-4B.jpg")
prompt = "change the color of the clothes to red"
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
image.save("image_edit_FLUX.2-klein-base-4B.jpg")

View File

@@ -0,0 +1,21 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
)
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
image.save("image_FLUX.2-klein-base-9B.jpg")
prompt = "change the color of the clothes to red"
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
image.save("image_edit_FLUX.2-klein-base-9B.jpg")

View File

@@ -0,0 +1,31 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
image.save("image_FLUX.2-klein-4B.jpg")
prompt = "change the color of the clothes to red"
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
image.save("image_edit_FLUX.2-klein-4B.jpg")

View File

@@ -0,0 +1,31 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
image.save("image_FLUX.2-klein-9B.jpg")
prompt = "change the color of the clothes to red"
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
image.save("image_edit_FLUX.2-klein-9B.jpg")

View File

@@ -0,0 +1,31 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
image.save("image_FLUX.2-klein-base-4B.jpg")
prompt = "change the color of the clothes to red"
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
image.save("image_edit_FLUX.2-klein-base-4B.jpg")

View File

@@ -0,0 +1,31 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
)
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
image.save("image_FLUX.2-klein-base-9B.jpg")
prompt = "change the color of the clothes to red"
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
image.save("image_edit_FLUX.2-klein-base-9B.jpg")

View File

@@ -0,0 +1,30 @@
accelerate launch examples/flux2/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.2-klein-4B_full" \
--trainable_models "dit" \
--use_gradient_checkpointing
# Edit
# accelerate launch examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
# --learning_rate 1e-5 \
# --num_epochs 2 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-4B_full" \
# --trainable_models "dit" \
# --use_gradient_checkpointing

View File

@@ -0,0 +1,31 @@
# This script is tested on 8*A100
accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.2-klein-9B_full" \
--trainable_models "dit" \
--use_gradient_checkpointing
# Edit
# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
# --learning_rate 1e-5 \
# --num_epochs 2 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-9B_full" \
# --trainable_models "dit" \
# --use_gradient_checkpointing

View File

@@ -0,0 +1,30 @@
accelerate launch examples/flux2/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.2-klein-base-4B_full" \
--trainable_models "dit" \
--use_gradient_checkpointing
# Edit
# accelerate launch examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
# --learning_rate 1e-5 \
# --num_epochs 2 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-base-4B_full" \
# --trainable_models "dit" \
# --use_gradient_checkpointing

View File

@@ -0,0 +1,31 @@
# This script is tested on 8*A100
accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.2-klein-base-9B_full" \
--trainable_models "dit" \
--use_gradient_checkpointing
# Edit
# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
# --learning_rate 1e-5 \
# --num_epochs 2 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-base-9B_full" \
# --trainable_models "dit" \
# --use_gradient_checkpointing

View File

@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 1
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@@ -0,0 +1,34 @@
accelerate launch examples/flux2/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.2-klein-4B_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
--lora_rank 32 \
--use_gradient_checkpointing
# Edit
# accelerate launch examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
# --learning_rate 1e-4 \
# --num_epochs 5 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-4B_lora" \
# --lora_base_model "dit" \
# --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
# --lora_rank 32 \
# --use_gradient_checkpointing

@@ -0,0 +1,34 @@
accelerate launch examples/flux2/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.2-klein-9B_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
--lora_rank 32 \
--use_gradient_checkpointing
# Edit
# accelerate launch examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
# --learning_rate 1e-4 \
# --num_epochs 5 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-9B_lora" \
# --lora_base_model "dit" \
# --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
# --lora_rank 32 \
# --use_gradient_checkpointing

@@ -0,0 +1,34 @@
accelerate launch examples/flux2/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.2-klein-base-4B_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
--lora_rank 32 \
--use_gradient_checkpointing
# Edit
# accelerate launch examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
# --learning_rate 1e-4 \
# --num_epochs 5 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-base-4B_lora" \
# --lora_base_model "dit" \
# --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
# --lora_rank 32 \
# --use_gradient_checkpointing

@@ -0,0 +1,34 @@
accelerate launch examples/flux2/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.2-klein-base-9B_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
--lora_rank 32 \
--use_gradient_checkpointing
# Edit
# accelerate launch examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
# --learning_rate 1e-4 \
# --num_epochs 5 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-base-9B_lora" \
# --lora_base_model "dit" \
# --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
# --lora_rank 32 \
# --use_gradient_checkpointing

@@ -24,7 +24,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule):
super().__init__()
# Load models
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
- tokenizer_config = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
+ tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"))
self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
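
The change above means `tokenizer_path` can now be either a local path or a `model_id:origin_file_pattern` string, matching the `--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/"` form used in the training scripts. A minimal sketch of what such a helper could look like, assuming a simple split on the first colon (the actual `parse_path_or_model_id` on `DiffusionTrainingModule` may differ in details):

import os
from diffsynth.pipelines.flux2_image import ModelConfig

# Hypothetical sketch of a parse_path_or_model_id-style helper, not the repository implementation.
def parse_path_or_model_id(path_or_id, default_value):
    if path_or_id is None:
        return default_value  # fall back to the default tokenizer config
    if os.path.exists(path_or_id):
        return ModelConfig(path_or_id)  # plain local path, as before
    # "model_id:pattern" form, e.g. "black-forest-labs/FLUX.2-klein-4B:tokenizer/"
    model_id, _, pattern = path_or_id.partition(":")
    return ModelConfig(model_id=model_id, origin_file_pattern=pattern)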

@@ -0,0 +1,20 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
from diffsynth.core import load_state_dict
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("./models/train/FLUX.2-klein-4B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
pipe.dit.load_state_dict(state_dict)
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
image.save("image.jpg")

@@ -0,0 +1,20 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
from diffsynth.core import load_state_dict
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("./models/train/FLUX.2-klein-9B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
pipe.dit.load_state_dict(state_dict)
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
image.save("image.jpg")

@@ -0,0 +1,20 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
from diffsynth.core import load_state_dict
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("./models/train/FLUX.2-klein-base-4B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
pipe.dit.load_state_dict(state_dict)
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
image.save("image.jpg")

@@ -0,0 +1,20 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
from diffsynth.core import load_state_dict
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("./models/train/FLUX.2-klein-base-9B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
pipe.dit.load_state_dict(state_dict)
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
image.save("image.jpg")

@@ -0,0 +1,18 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-4B_lora/epoch-4.safetensors")
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
image.save("image.jpg")

@@ -0,0 +1,18 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-9B_lora/epoch-4.safetensors")
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
image.save("image.jpg")

@@ -0,0 +1,18 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-base-4B_lora/epoch-4.safetensors")
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
image.save("image.jpg")

@@ -0,0 +1,18 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-base-9B_lora/epoch-4.safetensors")
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
image.save("image.jpg")

@@ -0,0 +1,17 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-2512", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")

@@ -0,0 +1,44 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
dataset_snapshot_download(
"DiffSynth-Studio/example_image_dataset",
allow_file_pattern="qwen_image_edit/*",
local_dir="data/example_image_dataset",
)
prompt = "生成这两个人的合影"
edit_image = [
Image.open("data/example_image_dataset/qwen_image_edit/image1.jpg"),
Image.open("data/example_image_dataset/qwen_image_edit/image2.jpg"),
]
image = pipe(
prompt,
edit_image=edit_image,
seed=1,
num_inference_steps=40,
height=1152,
width=896,
edit_image_auto_resize=True,
zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511
)
image.save("image.jpg")
# Qwen-Image-Edit-2511 is a multi-image editing model.
# Always pass `edit_image` as a list, even if it contains only one image:
# edit_image = [Image.open("image.jpg")]
# Do not pass a single image directly:
# edit_image = Image.open("image.jpg")

@@ -0,0 +1,34 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import snapshot_download
from PIL import Image
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Layered-Control", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
snapshot_download(
model_id="DiffSynth-Studio/Qwen-Image-Layered-Control",
allow_file_pattern="assets/image_1_input.png",
local_dir="data/layered_input"
)
prompt = "A cartoon skeleton character wearing a purple hat and holding a gift box"
input_image = Image.open("data/layered_input/assets/image_1_input.png").convert("RGBA").resize((1024, 1024))
images = pipe(
prompt,
seed=0,
num_inference_steps=30, cfg_scale=4,
height=1024, width=1024,
layer_input_image=input_image,
layer_num=0,
)
images[0].save("image.png")

@@ -0,0 +1,36 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
dataset_snapshot_download(
"DiffSynth-Studio/example_image_dataset",
allow_patterns="layer/image.png",
local_dir="data/example_image_dataset"
)
# A prompt must be provided; the pipeline does not generate one automatically.
prompt = 'A cheerful child with brown hair is waving enthusiastically under a bright blue sky filled with colorful confetti and balloons. The word "HELLO!" is prominently displayed in bold red letters above the child, while "Have a Great Day!" appears in elegant cursive at the bottom right corner. The scene is vibrant and festive, with a mix of pastel colors and dynamic shapes creating a joyful atmosphere.'
# Height and width should match input_image and be evenly divisible by 16.
input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480))
images = pipe(
prompt,
seed=1, num_inference_steps=50,
height=480, width=864,
layer_input_image=input_image, layer_num=3,
)
for i, image in enumerate(images):
    if i == 0: continue  # The first image is the input image.
    image.save(f"image_{i}.png")

@@ -0,0 +1,28 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-2512", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")

@@ -0,0 +1,54 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
dataset_snapshot_download(
"DiffSynth-Studio/example_image_dataset",
allow_file_pattern="qwen_image_edit/*",
local_dir="data/example_image_dataset",
)
prompt = "生成这两个人的合影"
edit_image = [
Image.open("data/example_image_dataset/qwen_image_edit/image1.jpg"),
Image.open("data/example_image_dataset/qwen_image_edit/image2.jpg"),
]
image = pipe(
prompt,
edit_image=edit_image,
seed=1,
num_inference_steps=40,
height=1152,
width=896,
edit_image_auto_resize=True,
zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511
)
image.save("image.jpg")
# Qwen-Image-Edit-2511 is a multi-image editing model.
# Always pass `edit_image` as a list, even if it contains only one image:
# edit_image = [Image.open("image.jpg")]
# Do not pass a single image directly:
# edit_image = Image.open("image.jpg")

@@ -0,0 +1,44 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import snapshot_download
from PIL import Image
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Layered-Control", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
snapshot_download(
model_id="DiffSynth-Studio/Qwen-Image-Layered-Control",
allow_file_pattern="assets/image_1_input.png",
local_dir="data/layered_input"
)
prompt = "A cartoon skeleton character wearing a purple hat and holding a gift box"
input_image = Image.open("data/layered_input/assets/image_1_input.png").convert("RGBA").resize((1024, 1024))
images = pipe(
prompt,
seed=0,
num_inference_steps=30, cfg_scale=4,
height=1024, width=1024,
layer_input_image=input_image,
layer_num=0,
)
images[0].save("image.png")

@@ -0,0 +1,46 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from modelscope import dataset_snapshot_download
from PIL import Image
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
dataset_snapshot_download(
"DiffSynth-Studio/example_image_dataset",
allow_patterns="layer/image.png",
local_dir="data/example_image_dataset"
)
# A prompt must be provided; the pipeline does not generate one automatically.
prompt = 'A cheerful child with brown hair is waving enthusiastically under a bright blue sky filled with colorful confetti and balloons. The word "HELLO!" is prominently displayed in bold red letters above the child, while "Have a Great Day!" appears in elegant cursive at the bottom right corner. The scene is vibrant and festive, with a mix of pastel colors and dynamic shapes creating a joyful atmosphere.'
# Height and width should match input_image and be evenly divisible by 16.
input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480))
images = pipe(
prompt,
seed=1, num_inference_steps=50,
height=480, width=864,
layer_input_image=input_image, layer_num=3,
)
for i, image in enumerate(images):
    if i == 0: continue  # The first image is the input image.
    image.save(f"image_{i}.png")

@@ -0,0 +1,13 @@
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image-2512:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-2512_full" \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters
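
After full training, the checkpoint can be validated the same way as in the FLUX.2 validation scripts above. A hedged sketch, assuming `QwenImagePipeline` accepts the trained DiT weights via `pipe.dit.load_state_dict` just like `Flux2ImagePipeline` (the epoch filename is illustrative):

from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.core import load_state_dict
import torch

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image-2512", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
# Load the fine-tuned DiT weights produced by the training script above (epoch index is illustrative).
state_dict = load_state_dict("./models/train/Qwen-Image-2512_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
pipe.dit.load_state_dict(state_dict)
image = pipe("a dog", seed=0, num_inference_steps=40)
image.save("image.jpg")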

@@ -0,0 +1,16 @@
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
--data_file_keys "image,edit_image" \
--extra_inputs "edit_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2511:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Edit-2511_full" \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--zero_cond_t # This is a special parameter introduced by Qwen-Image-Edit-2511. Please enable it for this model.

@@ -0,0 +1,18 @@
# Example Dataset: https://modelscope.cn/datasets/DiffSynth-Studio/example_image_dataset/tree/master/layer
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset/layer \
--dataset_metadata_path data/example_image_dataset/layer/metadata_layered_control.json \
--data_file_keys "image,layer_input_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "DiffSynth-Studio/Qwen-Image-Layered-Control:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Layered-Control_full" \
--trainable_models "dit" \
--extra_inputs "layer_num,layer_input_image" \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters

@@ -0,0 +1,18 @@
# Example Dataset: https://modelscope.cn/datasets/DiffSynth-Studio/example_image_dataset/tree/master/layer
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset/layer \
--dataset_metadata_path data/example_image_dataset/layer/metadata_layered.json \
--data_file_keys "image,layer_input_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image-Layered:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Layered_full" \
--trainable_models "dit" \
--extra_inputs "layer_num,layer_input_image" \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters

@@ -0,0 +1,16 @@
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image-2512:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-2512_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters

@@ -0,0 +1,19 @@
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
--data_file_keys "image,edit_image" \
--extra_inputs "edit_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2511:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Edit-2511_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters \
--zero_cond_t # This is a special parameter introduced by Qwen-Image-Edit-2511. Please enable it for this model.
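
To validate the trained LoRA, it can be attached to the Qwen-Image-Edit-2511 pipeline shown earlier. A hedged sketch, assuming `QwenImagePipeline` exposes the same `load_lora` helper used in the FLUX.2 LoRA validation scripts (epoch filename illustrative; example images as downloaded in the inference script above):

from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from PIL import Image
import torch

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
)
# Attach the trained LoRA (assumed load_lora API, mirroring the FLUX.2 LoRA validation scripts; epoch index illustrative).
pipe.load_lora(pipe.dit, "./models/train/Qwen-Image-Edit-2511_lora/epoch-4.safetensors")
edit_image = [
    Image.open("data/example_image_dataset/qwen_image_edit/image1.jpg"),
    Image.open("data/example_image_dataset/qwen_image_edit/image2.jpg"),
]
image = pipe(
    "Generate a group photo of these two people",
    edit_image=edit_image,
    seed=1, num_inference_steps=40, height=1152, width=896,
    edit_image_auto_resize=True,
    zero_cond_t=True,  # keep enabled for Qwen-Image-Edit-2511, matching training
)
image.save("image_lora.jpg")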

@@ -0,0 +1,20 @@
# Example Dataset: https://modelscope.cn/datasets/DiffSynth-Studio/example_image_dataset/tree/master/layer
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset/layer \
--dataset_metadata_path data/example_image_dataset/layer/metadata_layered_control.json \
--data_file_keys "image,layer_input_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "DiffSynth-Studio/Qwen-Image-Layered-Control:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Layered-Control_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--extra_inputs "layer_num,layer_input_image" \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters

@@ -0,0 +1,20 @@
# Example Dataset: https://modelscope.cn/datasets/DiffSynth-Studio/example_image_dataset/tree/master/layer
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset/layer \
--dataset_metadata_path data/example_image_dataset/layer/metadata_layered.json \
--data_file_keys "image,layer_input_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image-Layered:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Layered_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--extra_inputs "layer_num,layer_input_image" \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters

@@ -0,0 +1,38 @@
# Due to memory limitations, split training is required to train the model on NPU
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1
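# Stage 1 ("sft:data_process"): only the text encoder and VAE are loaded, and the preprocessed dataset is cached to the output path below.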
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 1 \
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:text_encoder/model*.safetensors,Qwen/Qwen-Image-Edit-2509:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Edit-2509-LoRA-splited-cache" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters \
--task "sft:data_process"
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path "./models/train/Qwen-Image-Edit-2509-LoRA-splited-cache" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Edit-2509-LoRA-splited" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters \
--task "sft:train"

@@ -0,0 +1,38 @@
# Due to memory limitations, split training is required to train the model on NPU
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 1 \
--model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-LoRA-splited-cache" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters \
--task "sft:data_process"
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path "./models/train/Qwen-Image-LoRA-splited-cache" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-LoRA-splited" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters \
--task "sft:train"

Some files were not shown because too many files have changed in this diff.