Compare commits

...

132 Commits

Author SHA1 Message Date
Artiprocher
53890bafa4 update examples 2026-02-05 11:10:55 +08:00
Zhongjie Duan
afd48cd706 Merge pull request #1259 from mi804/multi_controlnet
add example for multiple controlnet
2026-02-04 17:04:11 +08:00
mi804
24b68c2392 add example for multiple controlnet 2026-02-04 16:52:39 +08:00
Zhongjie Duan
280ff7cca6 Merge pull request #1229 from Feng0w0/wan_rope
[bugfix][NPU]:Fix bug that correctly obtains device type
2026-02-04 13:26:00 +08:00
Zhongjie Duan
b4b62e2f7c Merge pull request #1221 from Feng0w0/usp_npu
[NPU]:Support USP feature in NPU
2026-02-04 13:25:24 +08:00
Zhongjie Duan
6d1be405b9 Merge pull request #1242 from mi804/ltx-2
LTX-2
2026-02-03 13:07:41 +08:00
Zhongjie Duan
25c3a3d3e2 Merge branch 'main' into ltx-2 2026-02-03 13:06:44 +08:00
mi804
49bc84f78e add comment for tuple noise_pred 2026-02-03 10:43:25 +08:00
mi804
25a9e75030 final fix for ltx-2 2026-02-03 10:39:35 +08:00
mi804
2a7ac73eb5 minor fix 2026-02-02 20:07:08 +08:00
mi804
f4f991d409 support ltx-2 t2v and i2v 2026-02-02 19:53:07 +08:00
Zhongjie Duan
a781138413 Merge pull request #1245 from modelscope/docs-update
update docs
2026-02-02 17:00:04 +08:00
Artiprocher
91a5623976 update docs 2026-02-02 16:52:12 +08:00
Zhongjie Duan
28cd355aba Merge pull request #1232 from huarzone/main
fix wan i2v/ti2v train bug
2026-02-02 15:26:01 +08:00
Zhongjie Duan
005389fca7 Merge pull request #1244 from modelscope/qwen-image-edit-lightning
Qwen image edit lightning
2026-02-02 15:20:11 +08:00
Artiprocher
a6282056eb fix typo 2026-02-02 15:19:19 +08:00
Zhongjie Duan
21a6eb8e2f Merge pull request #1243 from modelscope/research_tutorial_1
add research tutorial sec 1
2026-02-02 14:29:39 +08:00
Artiprocher
98ab238340 add research tutorial sec 1 2026-02-02 14:28:26 +08:00
mi804
1c8a0f8317 refactor patchify 2026-01-31 13:55:52 +08:00
mi804
9f07d65ebb support ltx2 distilled pipeline 2026-01-30 17:40:30 +08:00
lzws
5f1d5adfce qwen-image-edit-2511-lightning 2026-01-30 17:26:26 +08:00
mi804
4f23caa55f support ltx2 two stage pipeline & vram 2026-01-30 16:55:40 +08:00
Zhongjie Duan
b4f6a4de6c Merge pull request #1240 from modelscope/loader-update
Loader update
2026-01-30 13:51:17 +08:00
Artiprocher
53fe42af1b update version 2026-01-30 13:49:27 +08:00
Artiprocher
ee9a3b4405 support loading models from state dict 2026-01-30 13:47:36 +08:00
mi804
b1a2782ad7 support ltx2 one-stage pipeline 2026-01-29 16:30:15 +08:00
mi804
8d303b47e9 add audio_vae, audio_vocoder, text_encoder, connector and upsampler for ltx2 2026-01-28 16:09:22 +08:00
mi804
00da4b6c4f add video_vae and dit for ltx-2 2026-01-27 19:34:09 +08:00
Zhongjie Duan
22695e9be0 Merge pull request #1233 from modelscope/z-image-release
Z-Image and Z-Image-i2L
2026-01-27 18:41:28 +08:00
Artiprocher
98290190ec update z-image-i2L demo 2026-01-27 13:42:48 +08:00
Artiprocher
3f4de2cc7f update z-image-i2L examples 2026-01-27 12:16:48 +08:00
Kared
8d0df403ca fix wan i2v train bug 2026-01-27 03:55:36 +00:00
Artiprocher
d12bf71bcc support z-image and z-image-i2L 2026-01-27 10:56:15 +08:00
feng0w0
35e0776022 [bugfix][NPU]:Fix bug that correctly obtains device type 2026-01-23 10:45:03 +08:00
Zhongjie Duan
ffb7a138f7 Merge pull request #1228 from modelscope/klein-bugfix
change klein image resize to crop
2026-01-22 10:34:17 +08:00
Artiprocher
548304667f change klein image resize to crop 2026-01-22 10:33:29 +08:00
Zhongjie Duan
273143136c Merge pull request #1227 from modelscope/modelscope-service-patch
update to 2.0.3
2026-01-21 20:23:13 +08:00
Artiprocher
030ebe649a update to 2.0.3 2026-01-21 20:22:43 +08:00
Zhongjie Duan
90921d2293 Merge pull request #1226 from modelscope/klein-train-fix
improve flux2 training performance
2026-01-21 15:44:52 +08:00
Artiprocher
b61131c693 improve flux2 training performance 2026-01-21 15:44:15 +08:00
Zhongjie Duan
37fbb3248a Merge pull request #1222 from modelscope/trainer-update
support auto detact lora target modules
2026-01-21 11:06:19 +08:00
Artiprocher
d13f533f42 support auto detact lora target modules 2026-01-21 11:05:05 +08:00
feng0w0
b3cc652dea [NPU]:Support USP feature in NPU 2026-01-21 10:38:27 +08:00
feng0w0
d879d66c62 [NPU]:Support USP feature in NPU 2026-01-21 10:34:09 +08:00
feng0w0
848bfd6993 [NPU]:Support USP feature in NPU 2026-01-21 10:25:31 +08:00
feng0w0
269da09f6e Merge branch 'main' of https://github.com/modelscope/DiffSynth-Studio into usp_npu 2026-01-21 10:00:08 +08:00
feng0w0
e30514a00c Merge branch 'main' of https://github.com/Feng0w0/DiffSynth-Studio into usp_npu 2026-01-21 09:59:18 +08:00
Zhongjie Duan
3743b1307c Merge pull request #1219 from modelscope/klein-edit
support klein edit
2026-01-20 12:59:12 +08:00
Artiprocher
a835df984c support klein edit 2026-01-20 12:58:18 +08:00
Zhongjie Duan
3e4b47e424 Merge pull request #1207 from Feng0w0/cuda_replace
[NPU]:Replace 'cuda' in the project with abstract interfaces
2026-01-20 10:13:04 +08:00
Zhongjie Duan
dd8d902624 Merge branch 'main' into cuda_replace 2026-01-20 10:12:31 +08:00
Zhongjie Duan
a8b340c098 Merge pull request #1191 from Feng0w0/wan_rope
[model][NPU]:Wan model rope use torch.complex64 in NPU
2026-01-20 10:05:22 +08:00
Zhongjie Duan
88497b5c13 Merge pull request #1217 from modelscope/klein-update
support klein base models
2026-01-19 21:14:47 +08:00
Artiprocher
1e90c72d94 support klein base models 2026-01-19 21:11:58 +08:00
Zhongjie Duan
3dd82a738e Merge pull request #1215 from lzws/main
updata learning rate in wan-vace training scripts
2026-01-19 17:48:42 +08:00
Artiprocher
8ad2d9884b update lr in wan-vace training scripts 2026-01-19 17:43:07 +08:00
Artiprocher
70f531b724 update wan-vace training scripts 2026-01-19 17:37:30 +08:00
Zhongjie Duan
37c2868b61 Merge pull request #1214 from modelscope/klein
Support FLUX.2-klein
2026-01-19 17:36:39 +08:00
Artiprocher
a18e6233b5 updata wan-vace training scripts 2026-01-19 17:35:08 +08:00
Artiprocher
2336d5f6b3 update doc 2026-01-19 17:27:32 +08:00
Artiprocher
b6ccb362b9 support flux.2 klein 2026-01-19 16:56:14 +08:00
Artiprocher
ae52d93694 support klein 4b models 2026-01-16 13:09:41 +08:00
feng0w0
ad91d41601 [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-16 10:28:24 +08:00
feng0w0
dce77ec4d1 [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-15 20:35:41 +08:00
feng0w0
5c0b07d939 [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-15 20:34:52 +08:00
feng0w0
19e429d889 Merge remote-tracking branch 'origin/cuda_replace' into cuda_replace 2026-01-15 20:33:21 +08:00
feng0w0
209a350c0f [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-15 20:33:01 +08:00
feng0w0
a3c2744a43 [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-15 20:04:54 +08:00
Zhongjie Duan
55e8346da3 Blog link (#1202)
* update README
2026-01-15 12:31:55 +08:00
Zhongjie Duan
b7979b2633 Merge pull request #1200 from modelscope/flux-compatibility-fix
fix flux compatibility issues
2026-01-14 20:50:18 +08:00
Artiprocher
c90aaa2798 fix flux compatibility issues 2026-01-14 20:49:36 +08:00
Zhongjie Duan
0c617d5d9e Merge pull request #1194 from lzws/main
wan usp bug fix
2026-01-14 16:34:06 +08:00
lzws
fd87b72754 wan usp bug fix 2026-01-14 16:33:02 +08:00
Zhongjie Duan
db75508ba0 Merge pull request #1199 from modelscope/z-image-bugfix
fix RMSNorm precision
2026-01-14 16:32:33 +08:00
Artiprocher
acba342a63 fix RMSNorm precision 2026-01-14 16:29:43 +08:00
feng0w0
d16877e695 [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-13 11:17:51 +08:00
lzws
e99cdcf3b8 wan usp bug fix 2026-01-12 22:08:48 +08:00
Zhongjie Duan
a236a17f17 Merge pull request #1193 from modelscope/qwen-image-layered-control
support qwen-image-layered-control
2026-01-12 17:24:06 +08:00
Artiprocher
03e530dc39 support qwen-image-layered-control 2026-01-12 17:20:01 +08:00
feng0w0
6be244233a [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-12 11:34:41 +08:00
feng0w0
544c391936 [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-12 11:24:11 +08:00
Feng
f4d06ce3fc Merge branch 'modelscope:main' into wan_rope 2026-01-12 11:21:09 +08:00
Zhongjie Duan
ffedb9eb52 Merge pull request #1187 from jiaqixuac/patch-1
Update package inclusion pattern in pyproject.toml
2026-01-12 10:12:20 +08:00
Zhongjie Duan
381067515c Merge pull request #1176 from Feng0w0/z-image-rope
[model][NPU]: Z-image model support NPU
2026-01-12 10:11:22 +08:00
Zhongjie Duan
00f2d1aa5d Merge pull request #1169 from Feng0w0/sample_add
Docs:Supplement NPU training script samples and documentation instruction
2026-01-12 10:08:38 +08:00
Zhongjie Duan
8cc3bece6d Merge pull request #1167 from Feng0w0/install_env
Docs:Supplement NPU environment installation document
2026-01-12 10:07:30 +08:00
Jiaqi Xu
f4bf592064 Update pyproject.toml
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-10 09:32:35 +08:00
Jiaqi Xu
3235393fb5 Update package inclusion pattern in pyproject.toml
Update to install all the sub-packages inside diffsynth. Otherwise, the installed packages only contain __init__.py
2026-01-10 09:28:45 +08:00
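(For context on this packaging fix: with setuptools, a bare include pattern ships only the top-level package, so an installed copy contains little more than `__init__.py`. A minimal sketch of the difference — our illustration, not the repository's actual configuration:)

```python
# Illustrative sketch (assumption, not repository code): how setuptools package
# discovery behaves with and without a trailing wildcard in the include pattern.
from setuptools import find_packages

# Matches only the exact name; sub-packages such as diffsynth.pipelines are dropped.
print(find_packages(include=["diffsynth"]))   # ['diffsynth']

# The wildcard also matches dotted sub-packages, so everything is installed.
print(find_packages(include=["diffsynth*"]))  # ['diffsynth', 'diffsynth.pipelines', ...]
```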
feng0w0
3b662da31e [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-09 18:11:40 +08:00
feng0w0
19ce3048c1 [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-09 18:06:41 +08:00
Zhongjie Duan
de0aa946f7 Merge pull request #1184 from modelscope/z-image-omni-base-dev
update package version
2026-01-08 17:27:33 +08:00
Artiprocher
f376202a49 update package version 2026-01-08 17:26:29 +08:00
Zhongjie Duan
a13ecfc46b Merge pull request #1183 from modelscope/z-image-omni-base-dev
fix unused parameters in z-image-omni-base
2026-01-08 17:03:20 +08:00
Artiprocher
10a1853eda fix unused parameters in z-image-omni-base 2026-01-08 17:02:41 +08:00
Zhongjie Duan
0efab85674 Support Z-Image-Omni-Base and its related models
Support Z-Image-Omni-Base and its related models.
2026-01-08 13:43:59 +08:00
Artiprocher
f45a0ffd02 support z-image-omni-base vram management 2026-01-08 13:41:00 +08:00
Artiprocher
8ba528a8f6 bugfix 2026-01-08 13:21:33 +08:00
Artiprocher
dd479e5bff support z-image-omni-base-i2L 2026-01-07 20:36:53 +08:00
Artiprocher
bac39b1cd2 support z-image controlnet 2026-01-07 15:56:53 +08:00
feng0w0
c1c9a4853b [model][NPU]:Z-image model support NPU 2026-01-07 11:42:19 +08:00
feng0w0
3ee5f53a36 [model][NPU]:Z-image model support NPU 2026-01-07 11:31:22 +08:00
Artiprocher
32449a6aa0 support z-image-omni-base training 2026-01-05 20:04:00 +08:00
Zhongjie Duan
a6884f6b3a Merge pull request #1171 from YZBPXX/main
Fix issue where LoRa loads on a device different from Dit
2026-01-05 16:39:02 +08:00
Zhongjie Duan
b078666640 Merge pull request #1173 from modelscope/flux-compatibility-patch
flux compatibility patch
2026-01-05 16:20:25 +08:00
Artiprocher
7604ca1e52 flux compatibility patch 2026-01-05 16:04:20 +08:00
feng0w0
62c3d406d9 Docs:Supplement NPU training script samples and documentation instruction 2026-01-05 15:42:55 +08:00
Artiprocher
5745c9f200 support z-image-omni-base 2026-01-05 14:45:01 +08:00
feng0w0
86829120c2 Docs:Supplement NPU training script samples and documentation instruction 2026-01-05 09:59:11 +08:00
yaozhengbing
60ac96525b Fix issue where LoRa loads on a device different from Dit 2025-12-31 21:31:01 +08:00
feng0w0
07b1f5702f Docs:Supplement NPU training script samples and documentation instruction 2025-12-31 10:01:21 +08:00
feng0w0
507e7e5d36 Docs:Supplement NPU training script samples and documentation instruction 2025-12-30 19:58:47 +08:00
Zhongjie Duan
ab8580f77e Merge pull request #1166 from modelscope/qwen-image-2512
support qwen-image-2512
2025-12-30 16:47:07 +08:00
Artiprocher
6454259853 support qwen-image-2512 2025-12-30 16:43:41 +08:00
feng0w0
9cc1697d4d Docs:Supplement NPU environment installation document 2025-12-30 15:57:13 +08:00
feng0w0
c758769a02 Training quick start 2025-12-29 09:25:46 +08:00
feng0w0
a5935e973a Training quick start 2025-12-29 09:23:59 +08:00
feng0w0
9834d72e4d Docs: environment installation quick start 2025-12-27 16:11:27 +08:00
feng0w0
01234e59c0 Docs: environment installation quick start 2025-12-27 15:01:10 +08:00
Zhongjie Duan
8f1d10fb43 Merge pull request #1150 from modelscope/qwen-image-layered
support qwen-image-layered
2025-12-20 14:05:38 +08:00
Artiprocher
20e1aaf908 bugfix 2025-12-20 14:00:22 +08:00
Artiprocher
c6722b3f56 support qwen-image-layered 2025-12-19 19:06:37 +08:00
Zhongjie Duan
11315d7a40 Merge pull request #1147 from modelscope/qwen-image-edit-2511
Qwen image edit 2511
2025-12-18 19:23:44 +08:00
Artiprocher
68d97a9844 update doc 2025-12-18 19:22:22 +08:00
Artiprocher
4629d4cf9e support qwen-image-edit-2511 2025-12-18 19:16:52 +08:00
Zhongjie Duan
3cb5cec906 Merge pull request #1143 from modelscope/readme-update
update README
2025-12-17 16:32:29 +08:00
Artiprocher
b7e16b9034 update README 2025-12-17 16:30:41 +08:00
Zhongjie Duan
83d1e7361f Merge pull request #1136 from modelscope/bugfix-device
bugfix
2025-12-16 16:12:05 +08:00
Artiprocher
1547c3f786 bugfix 2025-12-16 16:09:29 +08:00
Zhongjie Duan
bfaaf12bf4 Merge pull request #1129 from modelscope/ascend
Support Ascend NPU
2025-12-15 19:13:40 +08:00
Zhongjie Duan
47545e1aab Merge pull request #1126 from Leoooo333/main
Fixed: Wan S2V Long video severe quality downgrade
2025-12-15 19:09:39 +08:00
Junming Chen
a4d34d9f3d Append: set video compress quality as original version. 2025-12-14 20:53:26 +00:00
Junming Chen
127cc9007a Fixed: S2V Long video severe quality downgrade 2025-12-14 20:30:34 +00:00
194 changed files with 15485 additions and 425 deletions

View File

@@ -22,7 +22,7 @@ jobs:
     - name: Install wheel
       run: pip install wheel==0.44.0 && pip install -r requirements.txt
     - name: Build DiffSynth
-      run: python setup.py sdist bdist_wheel
+      run: python -m build
     - name: Publish package to PyPI
       run: |
         pip install twine
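(The change above replaces the deprecated `python setup.py sdist bdist_wheel` invocation with the PEP 517 build front end, `python -m build`. A minimal sketch — ours, not repository code — of the equivalent sdist-plus-wheel build driven through the `build` package's Python API:)

```python
# Minimal sketch (not repository code): what `python -m build` does, driven
# programmatically through the `build` package (pip install build).
from build import ProjectBuilder

builder = ProjectBuilder(".")       # project root containing pyproject.toml / setup.py
builder.build("sdist", "dist/")     # writes dist/<name>-<version>.tar.gz
builder.build("wheel", "dist/")     # writes dist/<name>-<version>-*.whl
```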

README.md
View File

@@ -33,7 +33,17 @@ We believe that a well-developed open-source code framework can lower the thresh
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
-- **December 9, 2025** We release a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in terms of generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research.
+- **February 2, 2026** The first document of the Research Tutorial series is now available, guiding you through training a small 0.1B text-to-image model from scratch. For details, see the [documentation](/docs/en/Research_Tutorial/train_from_scratch.md) and [model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). We hope DiffSynth-Studio can evolve into a more powerful training framework for Diffusion models.
- **January 27, 2026**: [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) is released, and our [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) model is released concurrently. You can use it in [ModelScope Studios](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L). For details, see the [documentation](/docs/zh/Model_Details/Z-Image.md).
- **January 19, 2026**: Added support for [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.
- **January 12, 2026**: We trained and open-sourced a text-guided image layer separation model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a textual description, the model isolates the image layer corresponding to the described content. For more details, please refer to our blog post ([Chinese version](https://modelscope.cn/learn/4938), [English version](https://huggingface.co/blog/kelseye/qwen-image-layered-control)).
- **December 24, 2025**: Based on Qwen-Image-Edit-2511, we trained an In-Context Editing LoRA model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)). This model takes three images as input (Image A, Image B, and Image C), and automatically analyzes the transformation from Image A to Image B, then applies the same transformation to Image C to generate Image D. For more details, please refer to our blog post ([Chinese version](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g), [English version](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)).
- **December 9, 2025** We release a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in terms of generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research. For more details, please refer to our [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l).
- **December 4, 2025** DiffSynth-Studio 2.0 released! Many new features online
  - [Documentation](/docs/en/README.md) online: Our documentation is still continuously being optimized and updated
@@ -263,9 +273,14 @@ image.save("image.jpg")
Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image/)
-| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
+|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
</details>
@@ -315,9 +330,13 @@ image.save("image.jpg")
Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)
-| Model ID | Inference | Low-VRAM Inference | LoRA Training | LoRA Training Validation |
+| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
-|-|-|-|-|
+|-|-|-|-|-|-|-|
-|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
+|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
</details>
@@ -396,8 +415,13 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
@@ -508,6 +532,95 @@ Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/)
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
#### LTX-2: [/docs/en/Model_Details/LTX-2.md](/docs/en/Model_Details/LTX-2.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8GB of VRAM.
```python
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
vram_config = {
"offload_dtype": torch.float8_e5m2,
"offload_device": "cpu",
"onload_dtype": torch.float8_e5m2,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e5m2,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = LTX2AudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
negative_prompt = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
height, width, num_frames = 512 * 2, 768 * 2, 121
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=43,
height=height,
width=width,
num_frames=num_frames,
tiled=True,
use_two_stage_pipeline=True,
)
write_video_audio_ltx2(
video=video,
audio=audio,
output_path='ltx2_twostage.mp4',
fps=24,
audio_sample_rate=24000,
)
```
</details>
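(A note on the quick-start configuration above — our reading of it, not text from the README: `vram_config` keeps weights in `torch.float8_e5m2` while offloaded or staged and casts them to `bfloat16` on the GPU only for computation, and `vram_limit` budgets against the device's total memory. A minimal sketch of that limit calculation:)

```python
import torch

# torch.cuda.mem_get_info returns (free_bytes, total_bytes) for the device;
# the pipeline call above takes total memory in GiB and keeps 0.5 GiB of headroom.
free_bytes, total_bytes = torch.cuda.mem_get_info("cuda")
vram_limit_gb = total_bytes / (1024 ** 3) - 0.5
print(f"vram_limit = {vram_limit_gb:.2f} GiB")
```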
<details>
<summary>Examples</summary>
Example code for LTX-2 is available at: [/examples/ltx2/](/examples/ltx2/)
| Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
</details>
#### Wan: [/docs/en/Model_Details/Wan.md](/docs/en/Model_Details/Wan.md)
<details>
@@ -766,4 +879,3 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
</details>

View File

@@ -33,7 +33,17 @@ DiffSynth currently includes two open-source projects:
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
-- **December 9, 2025** We trained a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image to LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research.
+- **February 2, 2026** The first document of the Research Tutorial series is online, walking you through training a small 0.1B text-to-image model from scratch. See the [documentation](/docs/zh/Research_Tutorial/train_from_scratch.md) and [model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). We hope DiffSynth-Studio can become a more powerful training framework for Diffusion models.
- **January 27, 2026** [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) is released, and our [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) model is released alongside it. You can try it directly in [ModelScope Studios](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L); see the [documentation](/docs/zh/Model_Details/Z-Image.md).
- **January 19, 2026** Added support for the [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including full training and inference functionality. The [documentation](/docs/zh/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.
- **January 12, 2026** We trained and open-sourced a text-guided image layer separation model ([model link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a text description, the model separates out the image layer related to the description. For more details, please read our blog ([Chinese version](https://modelscope.cn/learn/4938), [English version](https://huggingface.co/blog/kelseye/qwen-image-layered-control)).
- **December 24, 2025** Based on Qwen-Image-Edit-2511, we trained an In-Context Editing LoRA model ([model link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)). The model takes three images as input (Image A, Image B, and Image C), automatically analyzes the change from Image A to Image B, and applies the same change to Image C to generate Image D. For more details, please read our blog ([Chinese version](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g), [English version](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)).
- **December 9, 2025** We trained a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image to LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research. For more details, see our [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l).
- **December 4, 2025** DiffSynth-Studio 2.0 released! Many new features online
  - [Documentation](/docs/zh/README.md) online: our documentation is still being continuously optimized and updated
@@ -265,7 +275,12 @@ Example code for Z-Image is located at: [/examples/z_image/](/examples/z_image/)
|Model ID|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
</details>
@@ -315,9 +330,13 @@ image.save("image.jpg")
Example code for FLUX.2 is located at: [/examples/flux2/](/examples/flux2/)
-|Model ID|Inference|Low-VRAM Inference|LoRA Training|Validation After LoRA Training|
+|Model ID|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
-|-|-|-|-|
+|-|-|-|-|-|-|-|
-|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
+|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
</details>
@@ -396,8 +415,13 @@ Example code for Qwen-Image is located at: [/examples/qwen_image/](/examples/qwen_image/
|Model ID|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
@@ -508,6 +532,95 @@ Example code for FLUX.1 is located at: [/examples/flux/](/examples/flux/)
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
#### LTX-2: [/docs/zh/Model_Details/LTX-2.md](/docs/zh/Model_Details/LTX-2.md)
<details>
<summary>Quick Start</summary>
Run the following code to quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model and run inference. VRAM management is enabled: the framework automatically controls how model parameters are loaded according to the remaining VRAM, so the pipeline can run with as little as 8GB of VRAM.
```python
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
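# VRAM management config: weights are stored as float8_e5m2 on the CPU and are
# cast to bfloat16 on the GPU only for computation, keeping peak VRAM low.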
vram_config = {
"offload_dtype": torch.float8_e5m2,
"offload_device": "cpu",
"onload_dtype": torch.float8_e5m2,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e5m2,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = LTX2AudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
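# Offloading threshold: total GPU memory in GB, minus a 0.5 GB safety margin.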
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
negative_prompt = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
height, width, num_frames = 512 * 2, 768 * 2, 121
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=43,
height=height,
width=width,
num_frames=num_frames,
tiled=True,
use_two_stage_pipeline=True,
)
write_video_audio_ltx2(
video=video,
audio=audio,
output_path='ltx2_twostage.mp4',
fps=24,
audio_sample_rate=24000,
)
```
</details>
<details>
<summary>Example Code</summary>
Example code for LTX-2 is located at: [/examples/ltx2/](/examples/ltx2/)
|Model ID|Extra Parameters|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
</details>
#### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md)
<details>


@@ -63,6 +63,20 @@ qwen_image_series = [
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel", "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
"extra_kwargs": {"compress_dim": 64, "use_residual": False} "extra_kwargs": {"compress_dim": 64, "use_residual": False}
}, },
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "8dc8cda05de16c73afa755e2c1ce2839",
"model_name": "qwen_image_dit",
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
"extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True}
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "44b39ddc499e027cfb24f7878d7416b9",
"model_name": "qwen_image_vae",
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
"extra_kwargs": {"image_channels": 4}
},
]
wan_series = [
@@ -303,6 +317,13 @@ flux_series = [
"model_class": "diffsynth.models.flux_dit.FluxDiT", "model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
}, },
{
# Supported due to historical reasons.
"model_hash": "605c56eab23e9e2af863ad8f0813a25d",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
@@ -460,6 +481,13 @@ flux_series = [
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
"extra_kwargs": {"disable_guidance_embedder": True}, "extra_kwargs": {"disable_guidance_embedder": True},
}, },
{
# Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors")
"model_hash": "3394f306c4cbf04334b712bf5aaed95f",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
]
flux2_series = [
@@ -482,6 +510,28 @@ flux2_series = [
"model_name": "flux2_vae", "model_name": "flux2_vae",
"model_class": "diffsynth.models.flux2_vae.Flux2VAE", "model_class": "diffsynth.models.flux2_vae.Flux2VAE",
}, },
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
"model_hash": "3bde7b817fec8143028b6825a63180df",
"model_name": "flux2_dit",
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors")
"model_hash": "9195f3ea256fcd0ae6d929c203470754",
"model_name": "z_image_text_encoder",
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
"extra_kwargs": {"model_size": "8B"},
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors")
"model_hash": "39c6fc48f07bebecedbbaa971ff466c8",
"model_name": "flux2_dit",
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
},
]
z_image_series = [
@@ -513,6 +563,104 @@ z_image_series = [
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers", "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
"extra_kwargs": {"use_conv_attention": False}, "extra_kwargs": {"use_conv_attention": False},
}, },
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors")
"model_hash": "aa3563718e5c3ecde3dfbb020ca61180",
"model_name": "z_image_dit",
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
"extra_kwargs": {"siglip_feat_dim": 1152},
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors")
"model_hash": "89d48e420f45cff95115a9f3e698d44a",
"model_name": "siglip_vision_model_428m",
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
},
{
# Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
"model_hash": "1677708d40029ab380a95f6c731a57d7",
"model_name": "z_image_controlnet",
"model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
},
{
# Example: ???
"model_hash": "9510cb8cd1dd34ee0e4f111c24905510",
"model_name": "z_image_image2lora_style",
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
"extra_kwargs": {"compress_dim": 128},
},
{
# Example: ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors")
"model_hash": "1392adecee344136041e70553f875f31",
"model_name": "z_image_text_encoder",
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
"extra_kwargs": {"model_size": "0.6B"},
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
},
]
ltx2_series = [
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_dit",
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_video_vae_encoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_video_vae_decoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_audio_vae_decoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_audio_vocoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
},
# { # not used currently
# # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
# "model_hash": "aca7b0bbf8415e9c98360750268915fc",
# "model_name": "ltx2_audio_vae_encoder",
# "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
# "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
# },
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_text_encoder_post_modules",
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
},
{
# Example: ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors")
"model_hash": "33917f31c4a79196171154cca39f165e",
"model_name": "ltx2_text_encoder",
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "c79c458c6e99e0e14d47e676761732d2",
"model_name": "ltx2_latent_upsampler",
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
},
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series


@@ -13,6 +13,7 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"diffsynth.models.qwen_image_dit.QwenImageDiT": { "diffsynth.models.qwen_image_dit.QwenImageDiT": {
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
@@ -194,4 +195,52 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
}, },
"diffsynth.models.z_image_controlnet.ZImageControlNet": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": {
"transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.ltx2_dit.LTXModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": {
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": {
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": {
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": {
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_audio_vae.LTX2Vocoder": {
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
}


@@ -52,7 +52,7 @@ def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="
if k_pattern != required_in_pattern:
k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
if v_pattern != required_in_pattern:
v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
return q, k, v


@@ -53,12 +53,14 @@ class ToStr(DataProcessingOperator):
class LoadImage(DataProcessingOperator):
def __init__(self, convert_RGB=True, convert_RGBA=False):
self.convert_RGB = convert_RGB
self.convert_RGBA = convert_RGBA
def __call__(self, data: str):
image = Image.open(data)
if self.convert_RGB: image = image.convert("RGB")
if self.convert_RGBA: image = image.convert("RGBA")
return image
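The new `convert_RGBA` flag preserves the alpha channel, which the RGBA-capable layered models above (e.g. the Qwen-Image-Layered VAE with `image_channels: 4`) presumably rely on. A standalone copy of the operator, for illustration only:
```python
from PIL import Image

class LoadImage:
    # Standalone copy of the operator above, for illustration only.
    def __init__(self, convert_RGB=True, convert_RGBA=False):
        self.convert_RGB = convert_RGB
        self.convert_RGBA = convert_RGBA
    def __call__(self, data: str):
        image = Image.open(data)
        if self.convert_RGB: image = image.convert("RGB")
        if self.convert_RGBA: image = image.convert("RGBA")
        return image

# Keep the alpha channel intact: disable the RGB conversion, enable RGBA.
loader = LoadImage(convert_RGB=False, convert_RGBA=True)
# image = loader("layer.png")  # hypothetical file; returns a PIL image in RGBA mode
```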


@@ -10,6 +10,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
data_file_keys=tuple(),
main_data_operator=lambda x: x,
special_operator_map=None,
max_data_items=None,
):
self.base_path = base_path
self.metadata_path = metadata_path
@@ -18,6 +19,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
self.main_data_operator = main_data_operator
self.cached_data_operator = LoadTorchPickle()
self.special_operator_map = {} if special_operator_map is None else special_operator_map
self.max_data_items = max_data_items
self.data = []
self.cached_data = []
self.load_from_cache = metadata_path is None
@@ -97,7 +99,9 @@ class UnifiedDataset(torch.utils.data.Dataset):
return data
def __len__(self):
if self.max_data_items is not None:
return self.max_data_items
elif self.load_from_cache:
return len(self.cached_data) * self.repeat
else:
return len(self.data) * self.repeat
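A toy sketch of what the new `max_data_items` cap does to the dataset length; the class below is a minimal stand-in that mirrors only the `__len__` logic shown above:
```python
import torch

class CappedDataset(torch.utils.data.Dataset):
    # Minimal stand-in mirroring UnifiedDataset.__len__ above.
    def __init__(self, data, max_data_items=None, repeat=1):
        self.data, self.max_data_items, self.repeat = data, max_data_items, repeat
    def __len__(self):
        if self.max_data_items is not None:
            return self.max_data_items  # hard cap, e.g. for quick debugging runs
        return len(self.data) * self.repeat
    def __getitem__(self, i):
        return self.data[i % len(self.data)]

print(len(CappedDataset(list(range(10000)), max_data_items=100)))  # 100
```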


@@ -1 +1,2 @@
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE


@@ -1,5 +1,5 @@
import torch, glob, os
from typing import Optional, Union, Dict
from dataclasses import dataclass
from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download
@@ -23,13 +23,14 @@ class ModelConfig:
computation_device: Optional[Union[str, torch.device]] = None
computation_dtype: Optional[torch.dtype] = None
clear_parameters: bool = False
state_dict: Dict[str, torch.Tensor] = None
def check_input(self):
if self.path is None and self.model_id is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
def parse_original_file_pattern(self):
if self.origin_file_pattern in [None, "", "./"]:
return "*"
elif self.origin_file_pattern.endswith("/"):
return self.origin_file_pattern + "*"
@@ -97,7 +98,8 @@ class ModelConfig:
self.reset_local_model_path()
if self.require_downloading():
self.download()
if self.path is None:
if self.origin_file_pattern in [None, "", "./"]:
self.path = os.path.join(self.local_model_path, self.model_id)
else:
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
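Combined with the `load_model` changes further below, the new `state_dict` field allows passing weights that already live in memory, bypassing the file read. A hedged sketch, reusing the `ModelConfig` import from the LTX-2 quick start; the checkpoint path is hypothetical, and keeping `path` set for model identification is an assumption:
```python
import torch
from diffsynth.pipelines.ltx2_audio_video import ModelConfig

# Hypothetical: a state dict that was already loaded or modified in memory.
state_dict = torch.load("finetuned_dit.pt", map_location="cpu")
config = ModelConfig(
    path="finetuned_dit.pt",  # assumed to still be needed for model identification
    state_dict=state_dict,    # load_model skips re-reading the file when this is set
)
```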


@@ -2,16 +2,25 @@ from safetensors import safe_open
import torch, hashlib
def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
if isinstance(file_path, list):
state_dict = {}
for file_path_ in file_path:
state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
else:
if verbose >= 1:
print(f"Loading file [started]: {file_path}")
print(f"Loading file [started]: {file_path}")
if file_path.endswith(".safetensors"):
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else:
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
# If the state dict is loaded into CPU memory, `pin_memory=True` makes a later `model.to("cuda")` faster.
if pin_memory:
for i in state_dict:
state_dict[i] = state_dict[i].pin_memory()
if verbose >= 1:
print(f"Loading file [done]: {file_path}")
return state_dict
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
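A sketch of the new loader options (`diffsynth.core` re-exports `load_state_dict`, as the pipeline imports below show; the shard paths here are hypothetical). `pin_memory=True` places tensors in page-locked CPU memory so a later `model.to("cuda")` transfer is faster, and `verbose=1` logs each file:
```python
import torch
from diffsynth.core import load_state_dict

state_dict = load_state_dict(
    ["model-00001.safetensors", "model-00002.safetensors"],  # hypothetical shards
    torch_dtype=torch.bfloat16,
    device="cpu",
    pin_memory=True,  # page-locked memory speeds up later host-to-GPU copies
    verbose=1,        # prints "Loading file [started]/[done]" per file
)
```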


@@ -5,7 +5,7 @@ from .file import load_state_dict
import torch
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
config = {} if config is None else config
# Why do we use `skip_model_initialization`?
# It skips the random initialization of model parameters,
@@ -20,7 +20,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]] dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
dtype = [d for d in dtypes if d != "disk"][0] dtype = [d for d in dtypes if d != "disk"][0]
if vram_config["offload_device"] != "disk": if vram_config["offload_device"] != "disk":
state_dict = DiskMap(path, device, torch_dtype=dtype) if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
if state_dict_converter is not None: if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict) state_dict = state_dict_converter(state_dict)
else: else:
@@ -35,7 +35,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
# Sometimes a model file contains multiple models,
# and DiskMap can load only the parameters of a single model,
# avoiding the need to load all parameters in the file.
if state_dict is not None:
pass
elif use_disk_map:
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
else:
state_dict = load_state_dict(path, torch_dtype, device)


@@ -2,7 +2,7 @@ import torch, copy
from typing import Union
from .initialization import skip_model_initialization
from .disk_map import DiskMap
from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE
class AutoTorchModule(torch.nn.Module):
@@ -63,7 +63,8 @@ class AutoTorchModule(torch.nn.Module):
return r
def check_free_vram(self):
device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
return used_memory < self.vram_limit
@@ -309,6 +310,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
self.lora_B_weights = []
self.lora_merger = None
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
self.computation_device_type = parse_device_type(self.computation_device)
if offload_dtype == "disk": if offload_dtype == "disk":
self.disk_map = disk_map self.disk_map = disk_map


@@ -4,9 +4,11 @@ import numpy as np
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput
from ..core.device import get_device_name, IS_NPU_AVAILABLE
class PipelineUnit:
@@ -60,7 +62,7 @@ class BasePipeline(torch.nn.Module):
def __init__(
self,
device=get_device_type(), torch_dtype=torch.float16,
height_division_factor=64, width_division_factor=64,
time_division_factor=None, time_division_remainder=None,
):
@@ -177,7 +179,8 @@ class BasePipeline(torch.nn.Module):
def get_vram(self):
device = self.device if not IS_NPU_AVAILABLE else get_device_name()
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
def get_module(self, model, name):
if "." in name:
@@ -234,6 +237,7 @@ class BasePipeline(torch.nn.Module):
alpha=1,
hotload=None,
state_dict=None,
verbose=1,
):
if state_dict is None:
if isinstance(lora_config, str):
@@ -260,12 +264,13 @@ class BasePipeline(torch.nn.Module):
updated_num += 1
module.lora_A_weights.append(lora[lora_a_name] * alpha)
module.lora_B_weights.append(lora[lora_b_name])
if verbose >= 1:
print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
else:
lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)
def clear_lora(self, verbose=1):
cleared_num = 0
for name, module in self.named_modules():
if isinstance(module, AutoWrappedLinear):
@@ -275,7 +280,8 @@ class BasePipeline(torch.nn.Module):
module.lora_A_weights.clear()
if hasattr(module, "lora_B_weights"):
module.lora_B_weights.clear()
if verbose >= 1:
print(f"{cleared_num} LoRA layers are cleared.")
def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
@@ -290,6 +296,7 @@ class BasePipeline(torch.nn.Module):
vram_config=vram_config,
vram_limit=vram_limit,
clear_parameters=model_config.clear_parameters,
state_dict=model_config.state_dict,
)
return model_pool
@@ -303,10 +310,22 @@ class BasePipeline(torch.nn.Module):
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
if cfg_scale != 1.0:
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
if isinstance(noise_pred_posi, tuple):
# Handle each output latent type separately, e.g., video and audio latents.
noise_pred = tuple(
n_nega + cfg_scale * (n_posi - n_nega)
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
)
else:
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
return noise_pred
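The tuple branch applies the usual classifier-free guidance update to each latent stream independently. A self-contained check of that arithmetic, with toy tensors standing in for video and audio latents:
```python
import torch

cfg_scale = 5.0
noise_pred_posi = (torch.randn(1, 16, 8, 8), torch.randn(1, 4, 32))  # toy (video, audio)
noise_pred_nega = (torch.randn(1, 16, 8, 8), torch.randn(1, 4, 32))

# Same update as above, applied per stream: nega + scale * (posi - nega).
noise_pred = tuple(
    n_nega + cfg_scale * (n_posi - n_nega)
    for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
)
print([t.shape for t in noise_pred])  # shapes are preserved per stream
```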


@@ -4,13 +4,15 @@ from typing_extensions import Literal
class FlowMatchScheduler():
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
self.set_timesteps_fn = {
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
"Wan": FlowMatchScheduler.set_timesteps_wan,
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
}.get(template, FlowMatchScheduler.set_timesteps_flux)
self.num_train_timesteps = 1000
@@ -70,6 +72,28 @@ class FlowMatchScheduler():
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
sigma_min = 0.0
sigma_max = 1.0
num_train_timesteps = 1000
base_shift = math.log(3)
max_shift = math.log(3)
# Sigmas
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
# Mu
if exponential_shift_mu is not None:
mu = exponential_shift_mu
elif dynamic_shift_len is not None:
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift)
else:
mu = 0.8
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
# Timesteps
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def compute_empirical_mu(image_seq_len, num_steps):
a1, b1 = 8.73809524e-05, 1.89833333
@@ -89,13 +113,18 @@ class FlowMatchScheduler():
return float(mu)
@staticmethod
def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):
sigma_min = 1 / num_inference_steps
sigma_max = 1.0
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
if dynamic_shift_len is None:
# If you ask me why I set mu=0.8,
# I can only say that it yields better training results.
mu = 0.8
else:
mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@@ -116,7 +145,35 @@ class FlowMatchScheduler():
timestep_id = torch.argmin((timesteps - timestep).abs())
timesteps[timestep_id] = timestep
return sigmas, timesteps
@staticmethod
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
num_train_timesteps = 1000
if special_case == "stage2":
sigmas = torch.Tensor([0.909375, 0.725, 0.421875])
elif special_case == "ditilled_stage1":
sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875])
else:
dynamic_shift_len = dynamic_shift_len or 4096
sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image(
image_seq_len=dynamic_shift_len,
base_seq_len=1024,
max_seq_len=4096,
base_shift=0.95,
max_shift=2.05,
)
sigma_min = 0.0
sigma_max = 1.0
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1))
# Shift terminal
one_minus_z = 1.0 - sigmas
scale_factor = one_minus_z[-1] / (1 - terminal)
sigmas = 1.0 - (one_minus_z / scale_factor)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
def set_training_weight(self):
steps = 1000
x = self.timesteps
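The "shift terminal" step at the end of `set_timesteps_ltx2` rescales the schedule so that denoising stops at `sigma = terminal` instead of 0. A self-contained check of that arithmetic on a toy linear schedule:
```python
import torch

terminal = 0.1
sigmas = torch.linspace(1.0, 0.0, 9)[:-1]  # toy 8-step schedule
one_minus_z = 1.0 - sigmas
scale_factor = one_minus_z[-1] / (1 - terminal)
sigmas = 1.0 - (one_minus_z / scale_factor)
print(sigmas[0].item(), sigmas[-1].item())  # 1.0 0.1 -> the schedule now ends at 0.1
```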


@@ -10,7 +10,7 @@ class ModelLogger:
self.num_steps = 0
def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):
self.num_steps += 1
if save_steps is not None and self.num_steps % save_steps == 0:
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")


@@ -13,9 +13,16 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep) inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep) training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
if "first_frame_latents" in inputs:
inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
if "first_frame_latents" in inputs:
noise_pred = noise_pred[:, :, 1:]
training_target = training_target[:, :, 1:]
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * pipe.scheduler.training_weight(timestep)
return loss
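For image-to-video training, the first frame is a known condition rather than a denoising target, so it is clamped before the forward pass and dropped from the loss. A toy illustration of that masking, assuming `(B, C, T, H, W)` latent shapes:
```python
import torch

latents = torch.randn(1, 16, 9, 4, 4)             # toy (B, C, T, H, W)
first_frame_latents = torch.randn(1, 16, 1, 4, 4)  # conditioning frame
latents[:, :, 0:1] = first_frame_latents           # clamp frame 0 before the forward pass

noise_pred = torch.randn_like(latents)
training_target = torch.randn_like(latents)
# Frame 0 carries no denoising signal, so it is excluded from the MSE.
loss = torch.nn.functional.mse_loss(
    noise_pred[:, :, 1:].float(), training_target[:, :, 1:].float()
)
```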


@@ -40,7 +40,7 @@ def launch_training_task(
loss = model(data)
accelerator.backward(loss)
optimizer.step()
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
scheduler.step()
if save_steps is None:
model_logger.on_epoch_end(accelerator, model, epoch_id)


@@ -1,4 +1,4 @@
import torch, json, os
from ..core import ModelConfig, load_state_dict
from ..utils.controlnet import ControlNetInput
from peft import LoraConfig, inject_adapter_in_model
@@ -127,16 +127,67 @@ class DiffusionTrainingModule(torch.nn.Module):
if model_id_with_origin_paths is not None:
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
for model_id_with_origin_path in model_id_with_origin_paths:
vram_config = self.parse_vram_config(
fp8=model_id_with_origin_path in fp8_models,
offload=model_id_with_origin_path in offload_models,
device=device
)
config = self.parse_path_or_model_id(model_id_with_origin_path)
model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))
return model_configs
def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):
if model_id_with_origin_path is None:
return default_value
elif os.path.exists(model_id_with_origin_path):
return ModelConfig(path=model_id_with_origin_path)
else:
if ":" not in model_id_with_origin_path:
raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id/origin_file_pattern`.")
split_id = model_id_with_origin_path.rfind(":")
model_id = model_id_with_origin_path[:split_id]
origin_file_pattern = model_id_with_origin_path[split_id + 1:]
return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
def auto_detect_lora_target_modules(
self,
model: torch.nn.Module,
search_for_linear=False,
linear_detector=lambda x: min(x.weight.shape) >= 512,
block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
name_prefix="",
):
lora_target_modules = []
if search_for_linear:
for name, module in model.named_modules():
module_name = name_prefix + ["", "."][name_prefix != ""] + name
if isinstance(module, torch.nn.Linear) and linear_detector(module):
lora_target_modules.append(module_name)
else:
for name, module in model.named_children():
module_name = name_prefix + ["", "."][name_prefix != ""] + name
lora_target_modules += self.auto_detect_lora_target_modules(
module,
search_for_linear=block_list_detector(module),
linear_detector=linear_detector,
block_list_detector=block_list_detector,
name_prefix=module_name,
)
return lora_target_modules
def parse_lora_target_modules(self, model, lora_target_modules):
if lora_target_modules == "":
print("No LoRA target modules specified. The framework will automatically search for them.")
lora_target_modules = self.auto_detect_lora_target_modules(model)
print(f"LoRA will be patched at {lora_target_modules}.")
else:
lora_target_modules = lora_target_modules.split(",")
return lora_target_modules
def switch_pipe_to_training_mode(
self,
pipe,
@@ -166,7 +217,7 @@ class DiffusionTrainingModule(torch.nn.Module):
return
model = self.add_lora_to_model(
getattr(pipe, lora_base_model),
target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
lora_rank=lora_rank,
upcast_dtype=pipe.torch_dtype,
)
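When no LoRA target modules are specified, the detection above recurses until it finds a `ModuleList` with more than one block, then collects every `nn.Linear` inside whose smaller weight dimension is at least 512. A standalone re-implementation of that traversal on a toy model, for illustration only:
```python
import torch

def auto_detect_lora_target_modules(model, search_for_linear=False, name_prefix=""):
    # Mirrors the logic above: recurse until a ModuleList of blocks is found,
    # then collect the names of all sufficiently large Linear layers inside it.
    targets = []
    if search_for_linear:
        for name, module in model.named_modules():
            full_name = name_prefix + ("." if name_prefix else "") + name
            if isinstance(module, torch.nn.Linear) and min(module.weight.shape) >= 512:
                targets.append(full_name)
    else:
        for name, module in model.named_children():
            full_name = name_prefix + ("." if name_prefix else "") + name
            is_block_list = isinstance(module, torch.nn.ModuleList) and len(module) > 1
            targets += auto_detect_lora_target_modules(module, is_block_list, full_name)
    return targets

toy = torch.nn.Module()
toy.blocks = torch.nn.ModuleList(
    [torch.nn.ModuleDict({"to_q": torch.nn.Linear(1024, 1024)}) for _ in range(2)]
)
print(auto_detect_lora_target_modules(toy))  # ['blocks.0.to_q', 'blocks.1.to_q']
```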


@@ -2,6 +2,8 @@ from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch
from ..core.device.npu_compatible_device import get_device_type
class DINOv3ImageEncoder(DINOv3ViTModel):
def __init__(self):
@@ -70,7 +72,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
}
)
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
inputs = self.processor(images=image, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
bool_masked_pos = None


@@ -823,7 +823,13 @@ class Flux2PosEmbed(nn.Module):
class Flux2TimestepGuidanceEmbeddings(nn.Module):
def __init__(
self,
in_channels: int = 256,
embedding_dim: int = 6144,
bias: bool = False,
guidance_embeds: bool = True,
):
super().__init__()
self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -831,20 +837,24 @@ class Flux2TimestepGuidanceEmbeddings(nn.Module):
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
)
if guidance_embeds:
self.guidance_embedder = TimestepEmbedding(
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
)
else:
self.guidance_embedder = None
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
if guidance is not None and self.guidance_embedder is not None:
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
time_guidance_emb = timesteps_emb + guidance_emb
return time_guidance_emb
else:
return timesteps_emb
class Flux2Modulation(nn.Module):
@@ -882,6 +892,7 @@ class Flux2DiT(torch.nn.Module):
axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
rope_theta: int = 2000,
eps: float = 1e-6,
guidance_embeds: bool = True,
):
super().__init__()
self.out_channels = out_channels or in_channels
@@ -892,7 +903,10 @@ class Flux2DiT(torch.nn.Module):
# 2. Combined timestep + guidance embedding
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
in_channels=timestep_guidance_channels,
embedding_dim=self.inner_dim,
bias=False,
guidance_embeds=guidance_embeds,
)
# 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
@@ -953,34 +967,9 @@ class Flux2DiT(torch.nn.Module):
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
Input `hidden_states`.
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# 0. Handle input arguments # 0. Handle input arguments
if joint_attention_kwargs is not None: if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy() joint_attention_kwargs = joint_attention_kwargs.copy()
@@ -992,7 +981,9 @@ class Flux2DiT(torch.nn.Module):
# 1. Calculate timestep embedding and modulation parameters
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
temb = self.time_guidance_embed(timestep, guidance)


@@ -19,7 +19,7 @@ def get_timestep_embedding(
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
if align_dtype_to_timestep:
emb = emb.to(timesteps.dtype)
emb = timesteps[:, None].float() * emb[None, :]
@@ -78,7 +78,7 @@ class DiffusersCompatibleTimestepProj(torch.nn.Module):
class TimestepEmbeddings(torch.nn.Module):
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):
super().__init__()
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
if diffusers_compatible_format:
@@ -87,10 +87,17 @@ class TimestepEmbeddings(torch.nn.Module):
self.timestep_embedder = torch.nn.Sequential(
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
)
self.use_additional_t_cond = use_additional_t_cond
if use_additional_t_cond:
self.addition_t_embedding = torch.nn.Embedding(2, dim_out)
def forward(self, timestep, dtype, addition_t_cond=None):
time_emb = self.time_proj(timestep).to(dtype)
time_emb = self.timestep_embedder(time_emb)
if addition_t_cond is not None:
addition_t_emb = self.addition_t_embedding(addition_t_cond)
addition_t_emb = addition_t_emb.to(dtype=dtype)
time_emb = time_emb + addition_t_emb
return time_emb


@@ -9,6 +9,7 @@ import numpy as np
import torch.nn.functional as F
from einops import rearrange, repeat
from .wan_video_dit import flash_attention
from ..core.device.npu_compatible_device import get_device_type
from ..core.gradient import gradient_checkpoint_forward
@@ -373,7 +374,7 @@ class FinalLayer_FP32(nn.Module):
B, N, C = x.shape
T, _, _ = latent_shape
with amp.autocast(get_device_type(), dtype=torch.float32):
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
x = self.linear(x)
@@ -583,7 +584,7 @@ class LongCatSingleStreamBlock(nn.Module):
T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W. T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.
# compute modulation params in fp32 # compute modulation params in fp32
with amp.autocast(device_type='cuda', dtype=torch.float32): with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
shift_msa, scale_msa, gate_msa, \ shift_msa, scale_msa, gate_msa, \
shift_mlp, scale_mlp, gate_mlp = \ shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C] self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
@@ -602,7 +603,7 @@ class LongCatSingleStreamBlock(nn.Module):
else: else:
x_s = attn_outputs x_s = attn_outputs
with amp.autocast(device_type='cuda', dtype=torch.float32): with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype) x = x.to(x_dtype)
@@ -615,7 +616,7 @@ class LongCatSingleStreamBlock(nn.Module):
# ffn with modulation # ffn with modulation
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C) x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
x_s = self.ffn(x_m) x_s = self.ffn(x_m)
with amp.autocast(device_type='cuda', dtype=torch.float32): with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype) x = x.to(x_dtype)
@@ -797,7 +798,7 @@ class LongCatVideoTransformer3DModel(torch.nn.Module):
hidden_states = self.x_embedder(hidden_states) # [B, N, C] hidden_states = self.x_embedder(hidden_states) # [B, N, C]
with amp.autocast(device_type='cuda', dtype=torch.float32): with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t] t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C] encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
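These hunks replace hard-coded `'cuda'` autocast device types with a runtime lookup so the same code runs on Ascend NPU. The exact contents of `npu_compatible_device` are not shown in this diff; a plausible minimal helper of this kind might look like the following sketch:

# Hedged sketch, not the actual diffsynth.core.device.npu_compatible_device module.
import torch

def get_device_type() -> str:
    # Returns a device-type string suitable for amp.autocast(device_type=...).
    try:
        import torch_npu  # noqa: F401  # present only on Ascend NPU installs
        if hasattr(torch, "npu") and torch.npu.is_available():
            return "npu"
    except ImportError:
        pass
    return "cuda"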

File diff suppressed because it is too large.

View File

@@ -0,0 +1,371 @@
from dataclasses import dataclass
from typing import NamedTuple, Protocol, Tuple
import torch
from torch import nn
from enum import Enum
class VideoPixelShape(NamedTuple):
"""
Shape of the tensor representing the video pixel array. Assumes BGR channel format.
"""
batch: int
frames: int
height: int
width: int
fps: float
class SpatioTemporalScaleFactors(NamedTuple):
"""
Describes the spatiotemporal downscaling between decoded video space and
the corresponding VAE latent grid.
"""
time: int
width: int
height: int
@classmethod
def default(cls) -> "SpatioTemporalScaleFactors":
return cls(time=8, width=32, height=32)
VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
class VideoLatentShape(NamedTuple):
"""
Shape of the tensor representing video in VAE latent space.
The latent representation is a 5D tensor with dimensions ordered as
(batch, channels, frames, height, width). Spatial and temporal dimensions
are downscaled relative to pixel space according to the VAE's scale factors.
"""
batch: int
channels: int
frames: int
height: int
width: int
def to_torch_shape(self) -> torch.Size:
return torch.Size([self.batch, self.channels, self.frames, self.height, self.width])
@staticmethod
def from_torch_shape(shape: torch.Size) -> "VideoLatentShape":
return VideoLatentShape(
batch=shape[0],
channels=shape[1],
frames=shape[2],
height=shape[3],
width=shape[4],
)
def mask_shape(self) -> "VideoLatentShape":
return self._replace(channels=1)
@staticmethod
def from_pixel_shape(
shape: VideoPixelShape,
latent_channels: int = 128,
scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS,
) -> "VideoLatentShape":
        # Use named fields so the (time, width, height) ordering of SpatioTemporalScaleFactors cannot be mixed up.
        frames = (shape.frames - 1) // scale_factors.time + 1
        height = shape.height // scale_factors.height
        width = shape.width // scale_factors.width
return VideoLatentShape(
batch=shape.batch,
channels=latent_channels,
frames=frames,
height=height,
width=width,
)
def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape":
return self._replace(
channels=3,
frames=(self.frames - 1) * scale_factors.time + 1,
height=self.height * scale_factors.height,
width=self.width * scale_factors.width,
)
class AudioLatentShape(NamedTuple):
"""
Shape of audio in VAE latent space: (batch, channels, frames, mel_bins).
mel_bins is the number of frequency bins from the mel-spectrogram encoding.
"""
batch: int
channels: int
frames: int
mel_bins: int
def to_torch_shape(self) -> torch.Size:
return torch.Size([self.batch, self.channels, self.frames, self.mel_bins])
def mask_shape(self) -> "AudioLatentShape":
return self._replace(channels=1, mel_bins=1)
@staticmethod
def from_torch_shape(shape: torch.Size) -> "AudioLatentShape":
return AudioLatentShape(
batch=shape[0],
channels=shape[1],
frames=shape[2],
mel_bins=shape[3],
)
@staticmethod
def from_duration(
batch: int,
duration: float,
channels: int = 8,
mel_bins: int = 16,
sample_rate: int = 16000,
hop_length: int = 160,
audio_latent_downsample_factor: int = 4,
) -> "AudioLatentShape":
latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor)
return AudioLatentShape(
batch=batch,
channels=channels,
frames=round(duration * latents_per_second),
mel_bins=mel_bins,
)
@staticmethod
def from_video_pixel_shape(
shape: VideoPixelShape,
channels: int = 8,
mel_bins: int = 16,
sample_rate: int = 16000,
hop_length: int = 160,
audio_latent_downsample_factor: int = 4,
) -> "AudioLatentShape":
return AudioLatentShape.from_duration(
batch=shape.batch,
duration=float(shape.frames) / float(shape.fps),
channels=channels,
mel_bins=mel_bins,
sample_rate=sample_rate,
hop_length=hop_length,
audio_latent_downsample_factor=audio_latent_downsample_factor,
)
@dataclass(frozen=True)
class LatentState:
"""
State of latents during the diffusion denoising process.
Attributes:
latent: The current noisy latent tensor being denoised.
denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising).
positions: Positional indices for each latent element, used for positional embeddings.
clean_latent: Initial state of the latent before denoising, may include conditioning latents.
"""
latent: torch.Tensor
denoise_mask: torch.Tensor
positions: torch.Tensor
clean_latent: torch.Tensor
def clone(self) -> "LatentState":
return LatentState(
latent=self.latent.clone(),
denoise_mask=self.denoise_mask.clone(),
positions=self.positions.clone(),
clean_latent=self.clean_latent.clone(),
)
class NormType(Enum):
"""Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""
GROUP = "group"
PIXEL = "pixel"
class PixelNorm(nn.Module):
"""
Per-pixel (per-location) RMS normalization layer.
For each element along the chosen dimension, this layer normalizes the tensor
by the root-mean-square of its values across that dimension:
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
"""
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
"""
Args:
dim: Dimension along which to compute the RMS (typically channels).
eps: Small constant added for numerical stability.
"""
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply RMS normalization along the configured dimension.
"""
# Compute mean of squared values along `dim`, keep dimensions for broadcasting.
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
# Normalize by the root-mean-square (RMS).
rms = torch.sqrt(mean_sq + self.eps)
return x / rms
def build_normalization_layer(
in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
) -> nn.Module:
"""
Create a normalization layer based on the normalization type.
Args:
in_channels: Number of input channels
num_groups: Number of groups for group normalization
normtype: Type of normalization: "group" or "pixel"
Returns:
A normalization layer
"""
if normtype == NormType.GROUP:
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if normtype == NormType.PIXEL:
return PixelNorm(dim=1, eps=1e-6)
raise ValueError(f"Invalid normalization type: {normtype}")
def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
"""Root-mean-square (RMS) normalize `x` over its last dimension.
Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
shape and forwards `weight` and `eps`.
"""
return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)
@dataclass(frozen=True)
class Modality:
"""
Input data for a single modality (video or audio) in the transformer.
Bundles the latent tokens, timestep embeddings, positional information,
and text conditioning context for processing by the diffusion transformer.
"""
latent: (
torch.Tensor
) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps
positions: (
torch.Tensor
) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens
context: torch.Tensor
enabled: bool = True
context_mask: torch.Tensor | None = None
def to_denoised(
sample: torch.Tensor,
velocity: torch.Tensor,
sigma: float | torch.Tensor,
calc_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Convert the sample and its denoising velocity to denoised sample.
Returns:
Denoised sample
"""
if isinstance(sigma, torch.Tensor):
sigma = sigma.to(calc_dtype)
return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)
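# Quick numeric check (sketch, not part of the diff): with sample = x_t and the model's
# velocity prediction v, to_denoised computes the denoised estimate x_0 = x_t - sigma * v.
# x_t = torch.tensor([1.0, 2.0]); v = torch.tensor([0.5, 1.0]); sigma = 0.8
# to_denoised(x_t, v, sigma) -> tensor([0.6000, 1.2000]) == x_t - 0.8 * v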
class Patchifier(Protocol):
"""
Protocol for patchifiers that convert latent tensors into patches and assemble them back.
"""
def patchify(
self,
latents: torch.Tensor,
) -> torch.Tensor:
        """
        Convert latent tensors into flattened patch tokens.
        Args:
            latents: Latent tensor to patchify.
        Returns:
            Flattened patch tokens tensor.
        """
        ...
def unpatchify(
self,
latents: torch.Tensor,
output_shape: AudioLatentShape | VideoLatentShape,
) -> torch.Tensor:
"""
Converts latent tensors between spatio-temporal formats and flattened sequence representations.
Args:
latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
VideoLatentShape.
Returns:
Dense latent tensor restored from the flattened representation.
"""
@property
def patch_size(self) -> Tuple[int, int, int]:
        """
        Returns the patch size as a tuple of (temporal, height, width) dimensions.
        """
        ...
def get_patch_grid_bounds(
self,
output_shape: AudioLatentShape | VideoLatentShape,
device: torch.device | None = None,
) -> torch.Tensor:
        """
        Compute metadata describing where each latent patch resides within the
        grid specified by `output_shape`.
        Args:
            output_shape: Target grid layout for the patches.
            device: Target device for the returned tensor.
        Returns:
            Tensor containing patch coordinate metadata such as spatial or temporal intervals.
        """
        ...
def get_pixel_coords(
latent_coords: torch.Tensor,
scale_factors: SpatioTemporalScaleFactors,
causal_fix: bool = False,
) -> torch.Tensor:
"""
Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
Args:
latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied
per axis.
causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
that treat frame zero differently still yield non-negative timestamps.
"""
# Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
broadcast_shape = [1] * latent_coords.ndim
broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width)
scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)
# Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
pixel_coords = latent_coords * scale_tensor
if causal_fix:
# VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
# Shift and clamp to keep the first-frame timestamps causal and non-negative.
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
return pixel_coords
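# Example (sketch): latent-grid bookkeeping for a 121-frame 768x512 clip at 24 fps using the
# defaults above. The values follow directly from from_pixel_shape / from_video_pixel_shape.
# pixel = VideoPixelShape(batch=1, frames=121, height=512, width=768, fps=24.0)
# VideoLatentShape.from_pixel_shape(pixel)
#   -> frames = (121 - 1) // 8 + 1 = 16, height = 512 // 32 = 16, width = 768 // 32 = 24
# AudioLatentShape.from_video_pixel_shape(pixel)
#   -> duration = 121 / 24 ≈ 5.04 s, latents_per_second = 16000 / 160 / 4 = 25, frames = round(5.04 * 25) = 126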

diffsynth/models/ltx2_dit.py (new file, 1451 lines)

File diff suppressed because it is too large.

View File

@@ -0,0 +1,366 @@
import torch
from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer
from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention,
FeedForward)
from .ltx2_common import rms_norm
class LTX2TextEncoder(Gemma3ForConditionalGeneration):
def __init__(self):
config = Gemma3Config(
**{
"architectures": ["Gemma3ForConditionalGeneration"],
"boi_token_index": 255999,
"dtype": "bfloat16",
"eoi_token_index": 256000,
"eos_token_id": [1, 106],
"image_token_index": 262144,
"initializer_range": 0.02,
"mm_tokens_per_image": 256,
"model_type": "gemma3",
"text_config": {
"_sliding_window_pattern": 6,
"attention_bias": False,
"attention_dropout": 0.0,
"attn_logit_softcapping": None,
"cache_implementation": "hybrid",
"dtype": "bfloat16",
"final_logit_softcapping": None,
"head_dim": 256,
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 3840,
"initializer_range": 0.02,
"intermediate_size": 15360,
"layer_types": [
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention"
],
"max_position_embeddings": 131072,
"model_type": "gemma3_text",
"num_attention_heads": 16,
"num_hidden_layers": 48,
"num_key_value_heads": 8,
"query_pre_attn_scalar": 256,
"rms_norm_eps": 1e-06,
"rope_local_base_freq": 10000,
"rope_scaling": {
"factor": 8.0,
"rope_type": "linear"
},
"rope_theta": 1000000,
"sliding_window": 1024,
"sliding_window_pattern": 6,
"use_bidirectional_attention": False,
"use_cache": True,
"vocab_size": 262208
},
"transformers_version": "4.57.3",
"vision_config": {
"attention_dropout": 0.0,
"dtype": "bfloat16",
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 896,
"intermediate_size": 4304,
"layer_norm_eps": 1e-06,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 27,
"patch_size": 14,
"vision_use_head": False
}
})
super().__init__(config)
class LTXVGemmaTokenizer:
"""
Tokenizer wrapper for Gemma models compatible with LTXV processes.
This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders,
ensuring correct settings and output formatting for downstream consumption.
"""
def __init__(self, tokenizer_path: str, max_length: int = 1024):
"""
Initialize the tokenizer.
Args:
tokenizer_path (str): Path to the pretrained tokenizer files or model directory.
max_length (int, optional): Max sequence length for encoding. Defaults to 1024.
"""
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, local_files_only=True, model_max_length=max_length
)
# Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much.
self.tokenizer.padding_side = "left"
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.max_length = max_length
def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]:
"""
Tokenize the given text and return token IDs and attention weights.
Args:
text (str): The input string to tokenize.
return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples.
If False (default), omits the indices.
Returns:
dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]:
A dictionary with a "gemma" key mapping to:
- a list of (token_id, attention_mask) tuples if return_word_ids is False;
- a list of (token_id, attention_mask, index) tuples if return_word_ids is True.
Example:
>>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8)
>>> tokenizer.tokenize_with_weights("hello world")
{'gemma': [(1234, 1), (5678, 1), (2, 0), ...]}
"""
text = text.strip()
encoded = self.tokenizer(
text,
padding="max_length",
max_length=self.max_length,
truncation=True,
return_tensors="pt",
)
input_ids = encoded.input_ids
attention_mask = encoded.attention_mask
tuples = [
(token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0], strict=True))
]
out = {"gemma": tuples}
if not return_word_ids:
# Return only (token_id, attention_mask) pairs, omitting token position
out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()}
return out
class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
"""
Feature extractor module for Gemma models.
This module applies a single linear projection to the input tensor.
It expects a flattened feature tensor of shape (batch_size, 3840*49).
The linear layer maps this to a (batch_size, 3840) embedding.
Attributes:
aggregate_embed (torch.nn.Linear): Linear projection layer.
"""
def __init__(self) -> None:
"""
Initialize the GemmaFeaturesExtractorProjLinear module.
The input dimension is expected to be 3840 * 49, and the output is 3840.
"""
super().__init__()
self.aggregate_embed = torch.nn.Linear(3840 * 49, 3840, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the feature extractor.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, 3840 * 49).
Returns:
torch.Tensor: Output tensor of shape (batch_size, 3840).
"""
return self.aggregate_embed(x)
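# Example (sketch): the projection maps a flattened (batch, 3840 * 49) feature vector,
# i.e. 49 positions of hidden size 3840, to a single (batch, 3840) embedding.
# extractor = GemmaFeaturesExtractorProjLinear()
# features = torch.randn(2, 3840 * 49)
# embedding = extractor(features)   # shape: (2, 3840)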
class _BasicTransformerBlock1D(torch.nn.Module):
def __init__(
self,
dim: int,
heads: int,
dim_head: int,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
):
super().__init__()
self.attn1 = Attention(
query_dim=dim,
heads=heads,
dim_head=dim_head,
rope_type=rope_type,
)
self.ff = FeedForward(
dim,
dim_out=dim,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
pe: torch.Tensor | None = None,
) -> torch.Tensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Normalization Before Self-Attention
norm_hidden_states = rms_norm(hidden_states)
norm_hidden_states = norm_hidden_states.squeeze(1)
# 2. Self-Attention
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 3. Normalization before Feed-Forward
norm_hidden_states = rms_norm(hidden_states)
# 4. Feed-forward
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class Embeddings1DConnector(torch.nn.Module):
"""
Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can
substitute padded positions with learnable registers. The module is highly configurable for head size, number of
layers, and register usage.
Args:
attention_head_dim (int): Dimension of each attention head (default=128).
num_attention_heads (int): Number of attention heads (default=30).
num_layers (int): Number of transformer layers (default=2).
positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0).
positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[4096]; falls back to [1] when None).
causal_temporal_positioning (bool): If True, uses causal attention (default=False).
num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables
register replacement. (default=128)
rope_type (LTXRopeType): The RoPE variant to use (default=LTXRopeType.SPLIT).
double_precision_rope (bool): Use double precision rope calculation (default=True).
"""
_supports_gradient_checkpointing = True
def __init__(
self,
attention_head_dim: int = 128,
num_attention_heads: int = 30,
num_layers: int = 2,
positional_embedding_theta: float = 10000.0,
positional_embedding_max_pos: list[int] | None = [4096],
causal_temporal_positioning: bool = False,
num_learnable_registers: int | None = 128,
rope_type: LTXRopeType = LTXRopeType.SPLIT,
double_precision_rope: bool = True,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.positional_embedding_theta = positional_embedding_theta
self.positional_embedding_max_pos = (
positional_embedding_max_pos if positional_embedding_max_pos is not None else [1]
)
self.rope_type = rope_type
self.double_precision_rope = double_precision_rope
self.transformer_1d_blocks = torch.nn.ModuleList(
[
_BasicTransformerBlock1D(
dim=self.inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rope_type=rope_type,
)
for _ in range(num_layers)
]
)
self.num_learnable_registers = num_learnable_registers
if self.num_learnable_registers:
self.learnable_registers = torch.nn.Parameter(
torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
)
def _replace_padded_with_learnable_registers(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[1] % self.num_learnable_registers == 0, (
f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers "
f"{self.num_learnable_registers}."
)
num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers
learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1))
attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int()
non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
non_zero_nums = non_zero_hidden_states.shape[1]
pad_length = hidden_states.shape[1] - non_zero_nums
adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
flipped_mask = torch.flip(attention_mask_binary, dims=[1])
hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers
attention_mask = torch.full_like(
attention_mask,
0.0,
dtype=attention_mask.dtype,
device=attention_mask.device,
)
return hidden_states, attention_mask
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass of Embeddings1DConnector.
Args:
hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]).
attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states).
Returns:
tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask.
"""
if self.num_learnable_registers:
hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)
indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device)
indices_grid = indices_grid[None, None, :]
freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
freqs_cis = precompute_freqs_cis(
indices_grid=indices_grid,
dim=self.inner_dim,
out_dtype=hidden_states.dtype,
theta=self.positional_embedding_theta,
max_pos=self.positional_embedding_max_pos,
num_attention_heads=self.num_attention_heads,
rope_type=self.rope_type,
freq_grid_generator=freq_grid_generator,
)
for block in self.transformer_1d_blocks:
hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis)
hidden_states = rms_norm(hidden_states)
return hidden_states, attention_mask
class LTX2TextEncoderPostModules(torch.nn.Module):
def __init__(self,):
super().__init__()
self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
self.embeddings_connector = Embeddings1DConnector()
self.audio_embeddings_connector = Embeddings1DConnector()

View File

@@ -0,0 +1,313 @@
import math
from typing import Optional, Tuple
import torch
from einops import rearrange
import torch.nn.functional as F
from .ltx2_video_vae import LTX2VideoEncoder
class PixelShuffleND(torch.nn.Module):
"""
N-dimensional pixel shuffle operation for upsampling tensors.
Args:
dims (int): Number of dimensions to apply pixel shuffle to.
- 1: Temporal (e.g., frames)
- 2: Spatial (e.g., height and width)
- 3: Spatiotemporal (e.g., depth, height, width)
upscale_factors (tuple[int, int, int], optional): Upscaling factors for each dimension.
For dims=1, only the first value is used.
For dims=2, the first two values are used.
For dims=3, all three values are used.
The input tensor is rearranged so that the channel dimension is split into
smaller channels and upscaling factors, and the upscaling factors are moved
into the corresponding spatial/temporal dimensions.
Note:
This operation is equivalent to the patchifier operation used by the models; consider
using this class instead.
"""
def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)):
super().__init__()
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
self.dims = dims
self.upscale_factors = upscale_factors
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.dims == 3:
return rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.upscale_factors[0],
p2=self.upscale_factors[1],
p3=self.upscale_factors[2],
)
elif self.dims == 2:
return rearrange(
x,
"b (c p1 p2) h w -> b c (h p1) (w p2)",
p1=self.upscale_factors[0],
p2=self.upscale_factors[1],
)
elif self.dims == 1:
return rearrange(
x,
"b (c p1) f h w -> b c (f p1) h w",
p1=self.upscale_factors[0],
)
else:
raise ValueError(f"Unsupported dims: {self.dims}")
class ResBlock(torch.nn.Module):
"""
Residual block with two convolutional layers, group normalization, and SiLU activation.
Args:
channels (int): Number of input and output channels.
mid_channels (Optional[int]): Number of channels in the intermediate convolution layer. Defaults to `channels`
if not specified.
dims (int): Dimensionality of the convolution (2 for Conv2d, 3 for Conv3d). Defaults to 3.
"""
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
super().__init__()
if mid_channels is None:
mid_channels = channels
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.conv1 = conv(channels, mid_channels, kernel_size=3, padding=1)
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
self.conv2 = conv(mid_channels, channels, kernel_size=3, padding=1)
self.norm2 = torch.nn.GroupNorm(32, channels)
self.activation = torch.nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = self.activation(x)
x = self.conv2(x)
x = self.norm2(x)
x = self.activation(x + residual)
return x
class BlurDownsample(torch.nn.Module):
"""
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
"""
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
super().__init__()
assert dims in (2, 3)
assert isinstance(stride, int)
assert stride >= 1
assert kernel_size >= 3
assert kernel_size % 2 == 1
self.dims = dims
self.stride = stride
self.kernel_size = kernel_size
# 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from
# the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and
# provides a smooth approximation of a Gaussian filter (often called a "binomial filter").
# The 2D kernel is constructed as the outer product and normalized.
k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
k2d = k[:, None] @ k[None, :]
k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size)
self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.stride == 1:
return x
if self.dims == 2:
return self._apply_2d(x)
else:
# dims == 3: apply per-frame on H,W
b, _, f, _, _ = x.shape
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self._apply_2d(x)
h2, w2 = x.shape[-2:]
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
return x
def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor:
c = x2d.shape[1]
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
x2d = F.conv2d(x2d, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
return x2d
def _rational_for_scale(scale: float) -> Tuple[int, int]:
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
if float(scale) not in mapping:
raise ValueError(f"Unsupported scale {scale}. Choose from {list(mapping.keys())}")
return mapping[float(scale)]
class SpatialRationalResampler(torch.nn.Module):
"""
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
downsample by 'den' using fixed blur + stride. Operates on H,W only.
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
Args:
mid_channels (`int`): Number of intermediate channels for the convolution layer
scale (`float`): Spatial scaling factor. Supported values are:
- 0.75: Downsample by 3/4 (reduce spatial size)
- 1.5: Upsample by 3/2 (increase spatial size)
- 2.0: Upsample by 2x (double spatial size)
- 4.0: Upsample by 4x (quadruple spatial size)
Any other value will raise a ValueError.
"""
def __init__(self, mid_channels: int, scale: float):
super().__init__()
self.scale = float(scale)
self.num, self.den = _rational_for_scale(self.scale)
self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
self.blur_down = BlurDownsample(dims=2, stride=self.den)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, _, f, _, _ = x.shape
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.blur_down(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
return x
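# Example (sketch): scale=1.5 maps to (num, den) = (3, 2), i.e. a 3x PixelShuffle upsample followed
# by an anti-aliased stride-2 blur downsample, applied per frame on H and W only.
# resampler = SpatialRationalResampler(mid_channels=64, scale=1.5)
# x = torch.randn(1, 64, 4, 16, 16)   # (B, C, F, H, W)
# y = resampler(x)                     # shape: (1, 64, 4, 24, 24)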
class LTX2LatentUpsampler(torch.nn.Module):
"""
Model to upsample VAE latents spatially and/or temporally.
Args:
in_channels (`int`): Number of channels in the input latent
mid_channels (`int`): Number of channels in the middle layers
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
dims (`int`): Number of dimensions for convolutions (2 or 3)
spatial_upsample (`bool`): Whether to spatially upsample the latent
temporal_upsample (`bool`): Whether to temporally upsample the latent
spatial_scale (`float`): Scale factor for spatial upsampling
rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling
"""
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 1024,
num_blocks_per_stage: int = 4,
dims: int = 3,
spatial_upsample: bool = True,
temporal_upsample: bool = False,
spatial_scale: float = 2.0,
rational_resampler: bool = True,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.num_blocks_per_stage = num_blocks_per_stage
self.dims = dims
self.spatial_upsample = spatial_upsample
self.temporal_upsample = temporal_upsample
self.spatial_scale = float(spatial_scale)
self.rational_resampler = rational_resampler
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1)
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
self.initial_activation = torch.nn.SiLU()
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
if spatial_upsample and temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(3),
)
elif spatial_upsample:
if rational_resampler:
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale)
else:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(2),
)
elif temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(1),
)
else:
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
self.post_upsample_res_blocks = torch.nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1)
def forward(self, latent: torch.Tensor) -> torch.Tensor:
b, _, f, _, _ = latent.shape
if self.dims == 2:
x = rearrange(latent, "b c f h w -> (b f) c h w")
x = self.initial_conv(x)
x = self.initial_norm(x)
x = self.initial_activation(x)
for block in self.res_blocks:
x = block(x)
x = self.upsampler(x)
for block in self.post_upsample_res_blocks:
x = block(x)
x = self.final_conv(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
else:
x = self.initial_conv(latent)
x = self.initial_norm(x)
x = self.initial_activation(x)
for block in self.res_blocks:
x = block(x)
if self.temporal_upsample:
x = self.upsampler(x)
# remove the first frame after upsampling.
# This is done because the first frame encodes one pixel frame.
x = x[:, :, 1:, :, :]
elif isinstance(self.upsampler, SpatialRationalResampler):
x = self.upsampler(x)
else:
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self.upsampler(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
for block in self.post_upsample_res_blocks:
x = block(x)
x = self.final_conv(x)
return x
def upsample_video(latent: torch.Tensor, video_encoder: LTX2VideoEncoder, upsampler: "LTX2LatentUpsampler") -> torch.Tensor:
"""
Apply upsampling to the latent representation using the provided upsampler,
with normalization and un-normalization based on the video encoder's per-channel statistics.
Args:
latent: Input latent tensor of shape [B, C, F, H, W].
video_encoder: VideoEncoder with per_channel_statistics for normalization.
upsampler: LTX2LatentUpsampler module to perform upsampling.
Returns:
torch.Tensor: Upsampled and re-normalized latent tensor.
"""
latent = video_encoder.per_channel_statistics.un_normalize(latent)
latent = upsampler(latent)
latent = video_encoder.per_channel_statistics.normalize(latent)
return latent
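# Hedged usage sketch: spatially upsample a video latent between LTX-2 decoding stages.
# Constructing and loading the encoder and upsampler is handled by the surrounding pipeline code,
# which is not reproduced here; `video_encoder` below is assumed to be a loaded LTX2VideoEncoder.
# upsampler = LTX2LatentUpsampler(in_channels=128, spatial_upsample=True, spatial_scale=2.0)
# latent = torch.randn(1, 128, 3, 16, 24)                        # (B, C, F, H, W) in VAE latent space
# upsampled = upsample_video(latent, video_encoder, upsampler)   # H and W doubled: (1, 128, 3, 32, 48)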

File diff suppressed because it is too large.

View File

@@ -29,7 +29,7 @@ class ModelPool:
             module_map = None
         return module_map
-    def load_model_file(self, config, path, vram_config, vram_limit=None):
+    def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
         model_class = self.import_model_class(config["model_class"])
         model_config = config.get("extra_kwargs", {})
         if "state_dict_converter" in config:
@@ -43,6 +43,7 @@ class ModelPool:
                 state_dict_converter,
                 use_disk_map=True,
                 vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
+                state_dict=state_dict,
             )
         return model
@@ -59,7 +60,7 @@ class ModelPool:
         }
         return vram_config
-    def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False):
+    def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
         print(f"Loading models from: {json.dumps(path, indent=4)}")
         if vram_config is None:
             vram_config = self.default_vram_config()
@@ -67,7 +68,7 @@ class ModelPool:
         loaded = False
         for config in MODEL_CONFIGS:
             if config["model_hash"] == model_hash:
-                model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit)
+                model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
                 if clear_parameters: self.clear_parameters(model)
                 self.model.append(model)
                 model_name = config["model_name"]
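With this change, weights that are already materialized in memory can be passed straight through `auto_load_model` into `load_model_file` instead of being re-read from disk. A rough sketch of how a caller might use the new parameter, assuming a `ModelPool` instance `model_pool` (file names below are hypothetical; the path is still used to identify the matching model config):

# state_dict = torch.load("converted_weights.pth", map_location="cpu")   # hypothetical pre-loaded weights
# model_pool.auto_load_model(
#     "models/dit/model.safetensors",   # used for model-hash matching and logging
#     state_dict=state_dict,            # new: reuse the in-memory weights
# )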

View File

@@ -583,7 +583,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
         is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
         is_compileable = is_compileable and not self.generation_config.disable_compile
         if is_compileable and (
-            self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
+            self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices
         ):
             os.environ["TOKENIZERS_PARALLELISM"] = "0"
             model_forward = self.get_compiled_call(generation_config.compile_config)

View File

@@ -1,4 +1,4 @@
-import torch, math
+import torch, math, functools
 import torch.nn as nn
 from typing import Tuple, Optional, Union, List
 from einops import rearrange
@@ -225,6 +225,121 @@ class QwenEmbedRope(nn.Module):
         return vid_freqs, txt_freqs
class QwenEmbedLayer3DRope(nn.Module):
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
self.pos_freqs = torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.neg_freqs = torch.cat(
[
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.scale_rope = scale_rope
def rope_params(self, index, dim, theta=10000):
"""
Args:
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
"""
assert dim % 2 == 0
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def forward(self, video_fhw, txt_seq_lens, device):
"""
Args:
    video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video
    txt_seq_lens: [bs] a list of integers representing the length of the text
"""
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
video_fhw = [video_fhw]
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
if not isinstance(video_fhw, list):
video_fhw = [video_fhw]
vid_freqs = []
max_vid_index = 0
layer_num = len(video_fhw) - 1
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
if idx != layer_num:
video_freq = self._compute_video_freqs(frame, height, width, idx)
else:
### For the condition image, we set the layer index to -1
video_freq = self._compute_condition_freqs(frame, height, width)
video_freq = video_freq.to(device)
vid_freqs.append(video_freq)
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
max_vid_index = max(max_vid_index, layer_num)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
@functools.lru_cache(maxsize=None)
def _compute_condition_freqs(self, frame, height, width):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
 class QwenFeedForward(nn.Module):
     def __init__(
         self,
@@ -352,9 +467,38 @@ class QwenImageTransformerBlock(nn.Module):
         self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
         self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)
-    def _modulate(self, x, mod_params):
+    def _modulate(self, x, mod_params, index=None):
         shift, scale, gate = mod_params.chunk(3, dim=-1)
-        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+        if index is not None:
# Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
# So shift, scale, gate have shape [2*actual_batch, d]
actual_batch = shift.size(0) // 2
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
# index: [b, l] where b is actual batch size
# Expand to [b, l, 1] to match feature dimension
index_expanded = index.unsqueeze(-1) # [b, l, 1]
# Expand chunks to [b, 1, d] then broadcast to [b, l, d]
shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
scale_0_exp = scale_0.unsqueeze(1)
scale_1_exp = scale_1.unsqueeze(1)
gate_0_exp = gate_0.unsqueeze(1)
gate_1_exp = gate_1.unsqueeze(1)
# Use torch.where to select based on index
shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
else:
shift_result = shift.unsqueeze(1)
scale_result = scale.unsqueeze(1)
gate_result = gate.unsqueeze(1)
return x * (1 + scale_result) + shift_result, gate_result
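# Sketch of the per-token selection above: `index` decides, for every image token, whether the
# modulation parameters from the first or the second half of the batch dimension are applied.
# Toy shapes: actual_batch=1, seq_len=4, dim=2.
# shift = torch.tensor([[0.0, 0.0], [1.0, 1.0]])   # [2*actual_batch, d] -> halves 0 and 1
# index = torch.tensor([[0, 0, 1, 1]])             # [b, l]
# shift_0, shift_1 = shift[:1].unsqueeze(1), shift[1:].unsqueeze(1)
# selected = torch.where(index.unsqueeze(-1) == 0, shift_0, shift_1)
# selected[0] == [[0, 0], [0, 0], [1, 1], [1, 1]]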
     def forward(
         self,
@@ -364,13 +508,16 @@ class QwenImageTransformerBlock(nn.Module):
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         enable_fp8_attention = False,
+        modulate_index: Optional[List[int]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
+        if modulate_index is not None:
+            temb = torch.chunk(temb, 2, dim=0)[0]
         txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
         img_normed = self.img_norm1(image)
-        img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)
+        img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, index=modulate_index)
         txt_normed = self.txt_norm1(text)
         txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
@@ -387,7 +534,7 @@ class QwenImageTransformerBlock(nn.Module):
         text = text + txt_gate * txt_attn_out
         img_normed_2 = self.img_norm2(image)
-        img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)
+        img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, index=modulate_index)
         txt_normed_2 = self.txt_norm2(text)
         txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
@@ -405,12 +552,17 @@ class QwenImageDiT(torch.nn.Module):
     def __init__(
         self,
         num_layers: int = 60,
+        use_layer3d_rope: bool = False,
+        use_additional_t_cond: bool = False,
     ):
         super().__init__()
-        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
+        if not use_layer3d_rope:
+            self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
+        else:
+            self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
-        self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
+        self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=False, use_additional_t_cond=use_additional_t_cond)
         self.txt_norm = RMSNorm(3584, eps=1e-6)
         self.img_in = nn.Linear(64, 3072)

View File

@@ -366,6 +366,7 @@ class QwenImageEncoder3d(nn.Module):
         temperal_downsample=[True, True, False],
         dropout=0.0,
         non_linearity: str = "silu",
+        image_channels=3
     ):
         super().__init__()
         self.dim = dim
@@ -381,7 +382,7 @@ class QwenImageEncoder3d(nn.Module):
         scale = 1.0
         # init block
-        self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
+        self.conv_in = QwenImageCausalConv3d(image_channels, dims[0], 3, padding=1)
         # downsample blocks
         self.down_blocks = torch.nn.ModuleList([])
@@ -544,6 +545,7 @@ class QwenImageDecoder3d(nn.Module):
         temperal_upsample=[False, True, True],
         dropout=0.0,
         non_linearity: str = "silu",
+        image_channels=3,
     ):
         super().__init__()
         self.dim = dim
@@ -594,7 +596,7 @@ class QwenImageDecoder3d(nn.Module):
         # output blocks
         self.norm_out = QwenImageRMS_norm(out_dim, images=False)
-        self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
+        self.conv_out = QwenImageCausalConv3d(out_dim, image_channels, 3, padding=1)
         self.gradient_checkpointing = False
@@ -647,6 +649,7 @@ class QwenImageVAE(torch.nn.Module):
         attn_scales: List[float] = [],
         temperal_downsample: List[bool] = [False, True, True],
         dropout: float = 0.0,
+        image_channels: int = 3,
     ) -> None:
         super().__init__()
@@ -655,13 +658,13 @@ class QwenImageVAE(torch.nn.Module):
         self.temperal_upsample = temperal_downsample[::-1]
         self.encoder = QwenImageEncoder3d(
-            base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
+            base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, image_channels=image_channels,
         )
         self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
         self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
         self.decoder = QwenImageDecoder3d(
-            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
+            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, image_channels=image_channels,
         )
         mean = [

View File

@@ -1,7 +1,9 @@
 from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
-from transformers import SiglipImageProcessor
+from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
 import torch
+from diffsynth.core.device.npu_compatible_device import get_device_type
 class Siglip2ImageEncoder(SiglipVisionTransformer):
     def __init__(self):
@@ -47,7 +49,7 @@ class Siglip2ImageEncoder(SiglipVisionTransformer):
             }
         )
-    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+    def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
         pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
         pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
         output_attentions = False
@@ -68,3 +70,65 @@ class Siglip2ImageEncoder(SiglipVisionTransformer):
         pooler_output = self.head(last_hidden_state) if self.use_head else None
         return pooler_output
class Siglip2ImageEncoder428M(Siglip2VisionModel):
def __init__(self):
config = Siglip2VisionConfig(
attention_dropout = 0.0,
dtype = "bfloat16",
hidden_act = "gelu_pytorch_tanh",
hidden_size = 1152,
intermediate_size = 4304,
layer_norm_eps = 1e-06,
model_type = "siglip2_vision_model",
num_attention_heads = 16,
num_channels = 3,
num_hidden_layers = 27,
num_patches = 256,
patch_size = 16,
transformers_version = "4.57.1"
)
super().__init__(config)
self.processor = Siglip2ImageProcessorFast(
**{
"data_format": "channels_first",
"default_to_square": True,
"device": None,
"disable_grouping": None,
"do_convert_rgb": None,
"do_normalize": True,
"do_pad": None,
"do_rescale": True,
"do_resize": True,
"image_mean": [
0.5,
0.5,
0.5
],
"image_processor_type": "Siglip2ImageProcessorFast",
"image_std": [
0.5,
0.5,
0.5
],
"input_data_format": None,
"max_num_patches": 256,
"pad_size": None,
"patch_size": 16,
"processor_class": "Siglip2Processor",
"resample": 2,
"rescale_factor": 0.00392156862745098,
"return_tensors": None,
}
)
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
siglip_inputs = self.processor(images=[image], return_tensors="pt").to(device)
shape = siglip_inputs.spatial_shapes[0]
hidden_state = super().forward(**siglip_inputs).last_hidden_state
B, N, C = hidden_state.shape
hidden_state = hidden_state[:, : shape[0] * shape[1]]
hidden_state = hidden_state.view(shape[0], shape[1], C)
hidden_state = hidden_state.to(torch_dtype)
return hidden_state
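The new 428M SigLIP2 encoder wraps a `Siglip2VisionModel` plus its fast image processor and returns a patch grid of features. A hedged usage sketch (weights would normally be loaded from a checkpoint rather than randomly initialized, and the input image is hypothetical):

# from PIL import Image
# encoder = Siglip2ImageEncoder428M().to("cuda").eval()
# image = Image.open("example.png").convert("RGB")
# with torch.no_grad():
#     features = encoder(image)   # (h_patches, w_patches, 1152) in bfloat16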

View File

@@ -1,10 +1,11 @@
 import torch
 from typing import Optional, Union
 from .qwen_image_text_encoder import QwenImageTextEncoder
+from ..core.device.npu_compatible_device import get_device_type, get_torch_device

 class Step1xEditEmbedder(torch.nn.Module):
-    def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"):
+    def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()):
         super().__init__()
         self.max_length = max_length
         self.dtype = dtype
@@ -77,13 +78,13 @@ User Prompt:'''
             self.max_length,
             self.model.config.hidden_size,
             dtype=torch.bfloat16,
-            device=torch.cuda.current_device(),
+            device=get_torch_device().current_device(),
         )
         masks = torch.zeros(
             len(text_list),
             self.max_length,
             dtype=torch.long,
-            device=torch.cuda.current_device(),
+            device=get_torch_device().current_device(),
         )

         def split_string(s):
@@ -158,7 +159,7 @@ User Prompt:'''
             else:
                 token_list.append(token_each)
-        new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
+        new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type())
         new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
@@ -167,15 +168,15 @@ User Prompt:'''
         inputs.input_ids = (
             torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
             .unsqueeze(0)
-            .to("cuda")
+            .to(get_device_type())
         )
-        inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
+        inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type())
         outputs = self.model_forward(
             self.model,
             input_ids=inputs.input_ids,
             attention_mask=inputs.attention_mask,
-            pixel_values=inputs.pixel_values.to("cuda"),
-            image_grid_thw=inputs.image_grid_thw.to("cuda"),
+            pixel_values=inputs.pixel_values.to(get_device_type()),
+            image_grid_thw=inputs.image_grid_thw.to(get_device_type()),
             output_hidden_states=True,
         )
@@ -188,7 +189,7 @@ User Prompt:'''
         masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
             (min(self.max_length, emb.shape[1] - 217)),
             dtype=torch.long,
-            device=torch.cuda.current_device(),
+            device=get_torch_device().current_device(),
         )
         return embs, masks
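
For context, a sketch of what the npu_compatible_device helpers introduced in this diff might look like; the repository's actual implementation may differ:

import torch

def get_device_type() -> str:
    # Prefer an Ascend NPU when torch_npu is importable and available,
    # otherwise fall back to CUDA. Sketch only; error handling omitted.
    try:
        import torch_npu  # noqa: F401
        if torch.npu.is_available():
            return "npu"
    except ImportError:
        pass
    return "cuda"

def get_torch_device():
    # Return the torch device namespace (torch.npu or torch.cuda) so call
    # sites can write get_torch_device().current_device() uniformly.
    return torch.npu if get_device_type() == "npu" else torch.cuda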

View File

@@ -5,6 +5,7 @@ import math
 from typing import Tuple, Optional
 from einops import rearrange
 from .wan_video_camera_controller import SimpleAdapter
+
 try:
     import flash_attn_interface
     FLASH_ATTN_3_AVAILABLE = True
@@ -92,6 +93,7 @@ def rope_apply(x, freqs, num_heads):
     x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
     x_out = torch.view_as_complex(x.to(torch.float64).reshape(
         x.shape[0], x.shape[1], x.shape[2], -1, 2))
+    freqs = freqs.to(torch.complex64) if freqs.device.type == "npu" else freqs
     x_out = torch.view_as_real(x_out * freqs).flatten(2)
     return x_out.to(x.dtype)
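
The added NPU branch keeps freqs in complex64 before the complex multiply. A shape-level sketch of how rope_apply combines the complex rotation table with the packed heads; all sizes are illustrative:

import torch
from einops import rearrange

b, s, n, d = 1, 16, 8, 64                     # batch, sequence, heads, head_dim
x = torch.randn(b, s, n * d)                  # packed multi-head input
freqs = torch.polar(torch.ones(s, 1, d // 2), torch.randn(s, 1, d // 2))  # complex rotations
x_heads = rearrange(x, "b s (n d) -> b s n d", n=n)
x_c = torch.view_as_complex(x_heads.to(torch.float64).reshape(b, s, n, -1, 2))
out = torch.view_as_real(x_c * freqs).flatten(2).to(x.dtype)
print(out.shape)                              # torch.Size([1, 16, 512])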

View File

@@ -0,0 +1,154 @@
from .z_image_dit import ZImageTransformerBlock
from ..core.gradient import gradient_checkpoint_forward
from torch.nn.utils.rnn import pad_sequence
import torch
from torch import nn
class ZImageControlTransformerBlock(ZImageTransformerBlock):
def __init__(
self,
layer_id: int = 1000,
dim: int = 3840,
n_heads: int = 30,
n_kv_heads: int = 30,
norm_eps: float = 1e-5,
qk_norm: bool = True,
modulation = True,
block_id = 0
):
super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
self.block_id = block_id
if block_id == 0:
self.before_proj = nn.Linear(self.dim, self.dim)
self.after_proj = nn.Linear(self.dim, self.dim)
def forward(self, c, x, **kwargs):
if self.block_id == 0:
c = self.before_proj(c) + x
all_c = []
else:
all_c = list(torch.unbind(c))
c = all_c.pop(-1)
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
all_c += [c_skip, c]
c = torch.stack(all_c)
return c
class ZImageControlNet(torch.nn.Module):
def __init__(
self,
control_layers_places=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
control_in_dim=33,
dim=3840,
n_refiner_layers=2,
):
super().__init__()
self.control_layers = nn.ModuleList([ZImageControlTransformerBlock(layer_id=i, block_id=i) for i in control_layers_places])
self.control_all_x_embedder = nn.ModuleDict({"2-1": nn.Linear(1 * 2 * 2 * control_in_dim, dim, bias=True)})
self.control_noise_refiner = nn.ModuleList([ZImageControlTransformerBlock(block_id=layer_id) for layer_id in range(n_refiner_layers)])
self.control_layers_mapping = {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 7, 16: 8, 18: 9, 20: 10, 22: 11, 24: 12, 26: 13, 28: 14}
def forward_layers(
self,
x,
cap_feats,
control_context,
control_context_item_seqlens,
kwargs,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
bsz = len(control_context)
# unified
cap_item_seqlens = [len(_) for _ in cap_feats]
control_context_unified = []
for i in range(bsz):
control_context_len = control_context_item_seqlens[i]
cap_len = cap_item_seqlens[i]
control_context_unified.append(torch.cat([control_context[i][:control_context_len], cap_feats[i][:cap_len]]))
c = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
# arguments
new_kwargs = dict(x=x)
new_kwargs.update(kwargs)
for layer in self.control_layers:
c = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
c=c, **new_kwargs
)
hints = torch.unbind(c)[:-1]
return hints
def forward_refiner(
self,
dit,
x,
cap_feats,
control_context,
kwargs,
t=None,
patch_size=2,
f_patch_size=1,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
# embeddings
bsz = len(control_context)
device = control_context[0].device
(
control_context,
control_context_size,
control_context_pos_ids,
control_context_inner_pad_mask,
) = dit.patchify_controlnet(control_context, patch_size, f_patch_size, cap_feats[0].size(0))
# control_context embed & refine
control_context_item_seqlens = [len(_) for _ in control_context]
assert all(_ % 2 == 0 for _ in control_context_item_seqlens)
control_context_max_item_seqlen = max(control_context_item_seqlens)
control_context = torch.cat(control_context, dim=0)
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)
# Match t_embedder output dtype to control_context for layerwise casting compatibility
adaln_input = t.type_as(control_context)
control_context[torch.cat(control_context_inner_pad_mask)] = dit.x_pad_token.to(dtype=control_context.dtype, device=control_context.device)
control_context = list(control_context.split(control_context_item_seqlens, dim=0))
control_context_freqs_cis = list(dit.rope_embedder(torch.cat(control_context_pos_ids, dim=0)).split(control_context_item_seqlens, dim=0))
control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
control_context_freqs_cis = pad_sequence(control_context_freqs_cis, batch_first=True, padding_value=0.0)
control_context_attn_mask = torch.zeros((bsz, control_context_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(control_context_item_seqlens):
control_context_attn_mask[i, :seq_len] = 1
c = control_context
# arguments
new_kwargs = dict(
x=x,
attn_mask=control_context_attn_mask,
freqs_cis=control_context_freqs_cis,
adaln_input=adaln_input,
)
new_kwargs.update(kwargs)
for layer in self.control_noise_refiner:
c = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
c=c, **new_kwargs
)
hints = torch.unbind(c)[:-1]
control_context = torch.unbind(c)[-1]
return hints, control_context, control_context_item_seqlens
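
The stack/unbind bookkeeping above threads every control layer's skip output ("hint") plus the running activation through a single tensor, so the block interface stays compatible with gradient checkpointing. A toy illustration of that bookkeeping with stand-in blocks:

import torch

def block(c):            # stand-in for ZImageControlTransformerBlock's inner forward
    return c + 1

x = torch.zeros(4, 8)    # toy unified sequence (seq, dim)
c = x.clone()
packed = None
for block_id in range(3):
    if block_id == 0:
        all_c = []
    else:
        all_c = list(torch.unbind(packed))
        c = all_c.pop(-1)         # unpack the running state for this layer
    c = block(c)
    c_skip = c * 0.5              # stand-in for after_proj
    all_c += [c_skip, c]
    packed = torch.stack(all_c)   # pack hints + running state into one tensor
hints = torch.unbind(packed)[:-1]
print(len(hints))                 # 3: one skip connection ("hint") per control layer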

View File

@@ -6,13 +6,15 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
-from torch.nn import RMSNorm
+from .general_modules import RMSNorm
 from ..core.attention import attention_forward
+from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
 from ..core.gradient import gradient_checkpoint_forward

 ADALN_EMBED_DIM = 256
 SEQ_MULTI_OF = 32
+X_PAD_DIM = 64

 class TimestepEmbedder(nn.Module):
@@ -38,7 +40,7 @@ class TimestepEmbedder(nn.Module):
     @staticmethod
     def timestep_embedding(t, dim, max_period=10000):
-        with torch.amp.autocast("cuda", enabled=False):
+        with torch.amp.autocast(get_device_type(), enabled=False):
             half = dim // 2
             freqs = torch.exp(
                 -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
@@ -86,7 +88,7 @@ class Attention(torch.nn.Module):
         self.norm_q = RMSNorm(head_dim, eps=1e-5)
         self.norm_k = RMSNorm(head_dim, eps=1e-5)

-    def forward(self, hidden_states, freqs_cis):
+    def forward(self, hidden_states, freqs_cis, attention_mask):
         query = self.to_q(hidden_states)
         key = self.to_k(hidden_states)
         value = self.to_v(hidden_states)
@@ -103,7 +105,7 @@ class Attention(torch.nn.Module):
         # Apply RoPE
         def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-            with torch.amp.autocast("cuda", enabled=False):
+            with torch.amp.autocast(get_device_type(), enabled=False):
                 x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
                 freqs_cis = freqs_cis.unsqueeze(2)
                 x_out = torch.view_as_real(x * freqs_cis).flatten(3)
@@ -123,6 +125,7 @@ class Attention(torch.nn.Module):
             key,
             value,
             q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
+            attn_mask=attention_mask,
         )

         # Reshape back
@@ -136,6 +139,20 @@ class Attention(torch.nn.Module):
         return output
def select_per_token(
value_noisy: torch.Tensor,
value_clean: torch.Tensor,
noise_mask: torch.Tensor,
seq_len: int,
) -> torch.Tensor:
noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1)
return torch.where(
noise_mask_expanded == 1,
value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
value_clean.unsqueeze(1).expand(-1, seq_len, -1),
)
 class ZImageTransformerBlock(nn.Module):
     def __init__(
         self,
@@ -180,40 +197,53 @@ class ZImageTransformerBlock(nn.Module):
         attn_mask: torch.Tensor,
         freqs_cis: torch.Tensor,
         adaln_input: Optional[torch.Tensor] = None,
+        noise_mask: Optional[torch.Tensor] = None,
+        adaln_noisy: Optional[torch.Tensor] = None,
+        adaln_clean: Optional[torch.Tensor] = None,
     ):
         if self.modulation:
-            assert adaln_input is not None
-            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
-            gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
-            scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
+            seq_len = x.shape[1]
+            if noise_mask is not None:
+                # Per-token modulation: different modulation for noisy/clean tokens
+                mod_noisy = self.adaLN_modulation(adaln_noisy)
+                mod_clean = self.adaLN_modulation(adaln_clean)
+                scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
+                scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
+                gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
+                gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
+                scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
+                scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
+                scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
+                scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
+                gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
+                gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
+            else:
+                # Global modulation: same modulation for all tokens (avoid double select)
+                mod = self.adaLN_modulation(adaln_input)
+                scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
+                gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
+                scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
             # Attention block
             attn_out = self.attention(
-                self.attention_norm1(x) * scale_msa,
-                freqs_cis=freqs_cis,
+                self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
             )
             x = x + gate_msa * self.attention_norm2(attn_out)
             # FFN block
-            x = x + gate_mlp * self.ffn_norm2(
-                self.feed_forward(
-                    self.ffn_norm1(x) * scale_mlp,
-                )
-            )
+            x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
         else:
             # Attention block
-            attn_out = self.attention(
-                self.attention_norm1(x),
-                freqs_cis=freqs_cis,
-            )
+            attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
             x = x + self.attention_norm2(attn_out)
             # FFN block
-            x = x + self.ffn_norm2(
-                self.feed_forward(
-                    self.ffn_norm1(x),
-                )
-            )
+            x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
         return x
@@ -229,9 +259,21 @@ class FinalLayer(nn.Module):
             nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
         )

-    def forward(self, x, c):
-        scale = 1.0 + self.adaLN_modulation(c)
-        x = self.norm_final(x) * scale.unsqueeze(1)
+    def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
+        seq_len = x.shape[1]
+        if noise_mask is not None:
+            # Per-token modulation
+            scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)
+            scale_clean = 1.0 + self.adaLN_modulation(c_clean)
+            scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len)
+        else:
+            # Original global modulation
+            assert c is not None, "Either c or (c_noisy, c_clean) must be provided"
+            scale = 1.0 + self.adaLN_modulation(c)
+            scale = scale.unsqueeze(1)
+        x = self.norm_final(x) * scale
         x = self.linear(x)
         return x
@@ -274,7 +316,10 @@ class RopeEmbedder:
         result = []
         for i in range(len(self.axes_dims)):
             index = ids[:, i]
-            result.append(self.freqs_cis[i][index])
+            if IS_NPU_AVAILABLE:
+                result.append(torch.index_select(self.freqs_cis[i], 0, index))
+            else:
+                result.append(self.freqs_cis[i][index])
         return torch.cat(result, dim=-1)
@@ -299,6 +344,7 @@ class ZImageDiT(nn.Module):
         t_scale=1000.0,
         axes_dims=[32, 48, 48],
         axes_lens=[1024, 512, 512],
+        siglip_feat_dim=None,
     ) -> None:
         super().__init__()
         self.in_channels = in_channels
@@ -359,6 +405,32 @@ class ZImageDiT(nn.Module):
             nn.Linear(cap_feat_dim, dim, bias=True),
         )
# Optional SigLIP components (for Omni variant)
self.siglip_feat_dim = siglip_feat_dim
if siglip_feat_dim is not None:
self.siglip_embedder = nn.Sequential(
RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True)
)
self.siglip_refiner = nn.ModuleList(
[
ZImageTransformerBlock(
2000 + layer_id,
dim,
n_heads,
n_kv_heads,
norm_eps,
qk_norm,
modulation=False,
)
for layer_id in range(n_refiner_layers)
]
)
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim)))
else:
self.siglip_embedder = None
self.siglip_refiner = None
self.siglip_pad_token = None
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
@@ -375,22 +447,57 @@ class ZImageDiT(nn.Module):
         self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)

-    def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
+    def unpatchify(
+        self,
+        x: List[torch.Tensor],
+        size: List[Tuple],
+        patch_size = 2,
+        f_patch_size = 1,
+        x_pos_offsets: Optional[List[Tuple[int, int]]] = None,
+    ) -> List[torch.Tensor]:
         pH = pW = patch_size
         pF = f_patch_size
         bsz = len(x)
         assert len(size) == bsz
-        for i in range(bsz):
-            F, H, W = size[i]
-            ori_len = (F // pF) * (H // pH) * (W // pW)
-            # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
-            x[i] = (
-                x[i][:ori_len]
-                .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
-                .permute(6, 0, 3, 1, 4, 2, 5)
-                .reshape(self.out_channels, F, H, W)
-            )
-        return x
+        if x_pos_offsets is not None:
+            # Omni: extract target image from unified sequence (cond_images + target)
+            result = []
+            for i in range(bsz):
+                unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]]
+                cu_len = 0
+                x_item = None
+                for j in range(len(size[i])):
+                    if size[i][j] is None:
+                        ori_len = 0
+                        pad_len = SEQ_MULTI_OF
+                        cu_len += pad_len + ori_len
+                    else:
+                        F, H, W = size[i][j]
+                        ori_len = (F // pF) * (H // pH) * (W // pW)
+                        pad_len = (-ori_len) % SEQ_MULTI_OF
+                        x_item = (
+                            unified_x[cu_len : cu_len + ori_len]
+                            .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
+                            .permute(6, 0, 3, 1, 4, 2, 5)
+                            .reshape(self.out_channels, F, H, W)
+                        )
+                        cu_len += ori_len + pad_len
+                result.append(x_item)  # Return only the last (target) image
+            return result
+        else:
+            # Original mode: simple unpatchify
+            for i in range(bsz):
+                F, H, W = size[i]
+                ori_len = (F // pF) * (H // pH) * (W // pW)
+                # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
+                x[i] = (
+                    x[i][:ori_len]
+                    .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
+                    .permute(6, 0, 3, 1, 4, 2, 5)
+                    .reshape(self.out_channels, F, H, W)
+                )
+            return x

     @staticmethod
     def create_coordinate_grid(size, start=None, device=None):
@@ -405,8 +512,8 @@ class ZImageDiT(nn.Module):
         self,
         all_image: List[torch.Tensor],
         all_cap_feats: List[torch.Tensor],
-        patch_size: int,
-        f_patch_size: int,
+        patch_size: int = 2,
+        f_patch_size: int = 1,
     ):
         pH = pW = patch_size
         pF = f_patch_size
@@ -490,90 +597,487 @@ class ZImageDiT(nn.Module):
             image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
             all_image_out.append(image_padded_feat)
return all_image_out, all_cap_feats_out, {
"x_size": all_image_size,
"x_pos_ids": all_image_pos_ids,
"cap_pos_ids": all_cap_pos_ids,
"x_pad_mask": all_image_pad_mask,
"cap_pad_mask": all_cap_pad_mask
}
# (
# all_img_out,
# all_cap_out,
# all_img_size,
# all_img_pos_ids,
# all_cap_pos_ids,
# all_img_pad_mask,
# all_cap_pad_mask,
# )
def patchify_controlnet(
self,
all_image: List[torch.Tensor],
patch_size: int = 2,
f_patch_size: int = 1,
cap_padding_len: int = None,
):
pH = pW = patch_size
pF = f_patch_size
device = all_image[0].device
all_image_out = []
all_image_size = []
all_image_pos_ids = []
all_image_pad_mask = []
for i, image in enumerate(all_image):
### Process Image
C, F, H, W = image.size()
all_image_size.append((F, H, W))
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
image_ori_len = len(image)
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
image_ori_pos_ids = self.create_coordinate_grid(
size=(F_tokens, H_tokens, W_tokens),
start=(cap_padding_len + 1, 0, 0),
device=device,
).flatten(0, 2)
image_padding_pos_ids = (
self.create_coordinate_grid(
size=(1, 1, 1),
start=(0, 0, 0),
device=device,
)
.flatten(0, 2)
.repeat(image_padding_len, 1)
)
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
all_image_pos_ids.append(image_padded_pos_ids)
# pad mask
all_image_pad_mask.append(
torch.cat(
[
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
],
dim=0,
)
)
# padded feature
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
all_image_out.append(image_padded_feat)
         return (
             all_image_out,
-            all_cap_feats_out,
             all_image_size,
             all_image_pos_ids,
-            all_cap_pos_ids,
             all_image_pad_mask,
-            all_cap_pad_mask,
         )
def _prepare_sequence(
self,
feats: List[torch.Tensor],
pos_ids: List[torch.Tensor],
inner_pad_mask: List[torch.Tensor],
pad_token: torch.nn.Parameter,
noise_mask: Optional[List[List[int]]] = None,
device: torch.device = None,
):
"""Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask."""
item_seqlens = [len(f) for f in feats]
max_seqlen = max(item_seqlens)
bsz = len(feats)
# Pad token
feats_cat = torch.cat(feats, dim=0)
feats_cat[torch.cat(inner_pad_mask)] = pad_token.to(dtype=feats_cat.dtype, device=feats_cat.device)
feats = list(feats_cat.split(item_seqlens, dim=0))
# RoPE
freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0))
# Pad to batch
feats = pad_sequence(feats, batch_first=True, padding_value=0.0)
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
# Attention mask
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(item_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None
if noise_mask is not None:
noise_mask_tensor = pad_sequence(
[torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask],
batch_first=True,
padding_value=0,
)[:, : feats.shape[1]]
return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor
def _build_unified_sequence(
self,
x: torch.Tensor,
x_freqs: torch.Tensor,
x_seqlens: List[int],
x_noise_mask: Optional[List[List[int]]],
cap: torch.Tensor,
cap_freqs: torch.Tensor,
cap_seqlens: List[int],
cap_noise_mask: Optional[List[List[int]]],
siglip: Optional[torch.Tensor],
siglip_freqs: Optional[torch.Tensor],
siglip_seqlens: Optional[List[int]],
siglip_noise_mask: Optional[List[List[int]]],
omni_mode: bool,
device: torch.device,
):
"""Build unified sequence: x, cap, and optionally siglip.
Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip]
"""
bsz = len(x_seqlens)
unified = []
unified_freqs = []
unified_noise_mask = []
for i in range(bsz):
x_len, cap_len = x_seqlens[i], cap_seqlens[i]
if omni_mode:
# Omni: [cap, x, siglip]
if siglip is not None and siglip_seqlens is not None:
sig_len = siglip_seqlens[i]
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]]))
unified_freqs.append(
torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]])
)
unified_noise_mask.append(
torch.tensor(
cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device
)
)
else:
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]]))
unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]]))
unified_noise_mask.append(
torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device)
)
else:
# Basic: [x, cap]
unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]]))
unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]]))
# Compute unified seqlens
if omni_mode:
if siglip is not None and siglip_seqlens is not None:
unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)]
else:
unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)]
else:
unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)]
max_seqlen = max(unified_seqlens)
# Pad to batch
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
# Attention mask
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None
if omni_mode:
noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[
:, : unified.shape[1]
]
return unified, unified_freqs, attn_mask, noise_mask_tensor
def _pad_with_ids(
self,
feat: torch.Tensor,
pos_grid_size: Tuple,
pos_start: Tuple,
device: torch.device,
noise_mask_val: Optional[int] = None,
):
"""Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
ori_len = len(feat)
pad_len = (-ori_len) % SEQ_MULTI_OF
total_len = ori_len + pad_len
# Pos IDs
ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
if pad_len > 0:
pad_pos_ids = (
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
.flatten(0, 2)
.repeat(pad_len, 1)
)
pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
pad_mask = torch.cat(
[
torch.zeros(ori_len, dtype=torch.bool, device=device),
torch.ones(pad_len, dtype=torch.bool, device=device),
]
)
else:
pos_ids = ori_pos_ids
padded_feat = feat
pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
return padded_feat, pos_ids, pad_mask, total_len, noise_mask
def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
"""Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
pH, pW, pF = patch_size, patch_size, f_patch_size
C, F, H, W = image.size()
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
def patchify_and_embed_omni(
self,
all_x: List[List[torch.Tensor]],
all_cap_feats: List[List[torch.Tensor]],
all_siglip_feats: List[List[torch.Tensor]],
patch_size: int = 2,
f_patch_size: int = 1,
images_noise_mask: List[List[int]] = None,
):
"""Patchify for omni mode: multiple images per batch item with noise masks."""
bsz = len(all_x)
device = all_x[0][-1].device
dtype = all_x[0][-1].dtype
all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], []
all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], []
all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], []
for i in range(bsz):
num_images = len(all_x[i])
cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], []
cap_end_pos = []
cap_cu_len = 1
# Process captions
for j, cap_item in enumerate(all_cap_feats[i]):
noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1
cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids(
cap_item,
(len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1),
(cap_cu_len, 0, 0),
device,
noise_val,
)
cap_feats_list.append(cap_out)
cap_pos_list.append(cap_pos)
cap_mask_list.append(cap_mask)
cap_lens.append(cap_len)
cap_noise.extend(cap_nm)
cap_cu_len += len(cap_item)
cap_end_pos.append(cap_cu_len)
cap_cu_len += 2 # for image vae and siglip tokens
all_cap_out.append(torch.cat(cap_feats_list, dim=0))
all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0))
all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0))
all_cap_len.append(cap_lens)
all_cap_noise_mask.append(cap_noise)
# Process images
x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], []
for j, x_item in enumerate(all_x[i]):
noise_val = images_noise_mask[i][j]
if x_item is not None:
x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size)
x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids(
x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val
)
x_size.append(size)
else:
x_len = SEQ_MULTI_OF
x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device)
x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1)
x_mask = torch.ones(x_len, dtype=torch.bool, device=device)
x_nm = [noise_val] * x_len
x_size.append(None)
x_feats_list.append(x_out)
x_pos_list.append(x_pos)
x_mask_list.append(x_mask)
x_lens.append(x_len)
x_noise.extend(x_nm)
all_x_out.append(torch.cat(x_feats_list, dim=0))
all_x_pos_ids.append(torch.cat(x_pos_list, dim=0))
all_x_pad_mask.append(torch.cat(x_mask_list, dim=0))
all_x_size.append(x_size)
all_x_len.append(x_lens)
all_x_noise_mask.append(x_noise)
# Process siglip
if all_siglip_feats[i] is None:
all_sig_len.append([0] * num_images)
all_sig_out.append(None)
else:
sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], []
for j, sig_item in enumerate(all_siglip_feats[i]):
noise_val = images_noise_mask[i][j]
if sig_item is not None:
sig_H, sig_W, sig_C = sig_item.size()
sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C)
sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids(
sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val
)
# Scale position IDs to match x resolution
if x_size[j] is not None:
sig_pos = sig_pos.float()
sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1)
sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1)
sig_pos = sig_pos.to(torch.int32)
else:
sig_len = SEQ_MULTI_OF
sig_out = torch.zeros((sig_len, self.siglip_feat_dim), dtype=dtype, device=device)
sig_pos = (
self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1)
)
sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device)
sig_nm = [noise_val] * sig_len
sig_feats_list.append(sig_out)
sig_pos_list.append(sig_pos)
sig_mask_list.append(sig_mask)
sig_lens.append(sig_len)
sig_noise.extend(sig_nm)
all_sig_out.append(torch.cat(sig_feats_list, dim=0))
all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0))
all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0))
all_sig_len.append(sig_lens)
all_sig_noise_mask.append(sig_noise)
# Compute x position offsets
all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)]
return (
all_x_out,
all_cap_out,
all_sig_out,
all_x_size,
all_x_pos_ids,
all_cap_pos_ids,
all_sig_pos_ids,
all_x_pad_mask,
all_cap_pad_mask,
all_sig_pad_mask,
all_x_pos_offsets,
all_x_noise_mask,
all_cap_noise_mask,
all_sig_noise_mask,
)
return all_x_out, all_cap_out, all_sig_out, {
"x_size": x_size,
"x_pos_ids": all_x_pos_ids,
"cap_pos_ids": all_cap_pos_ids,
"sig_pos_ids": all_sig_pos_ids,
"x_pad_mask": all_x_pad_mask,
"cap_pad_mask": all_cap_pad_mask,
"sig_pad_mask": all_sig_pad_mask,
"x_pos_offsets": all_x_pos_offsets,
"x_noise_mask": all_x_noise_mask,
"cap_noise_mask": all_cap_noise_mask,
"sig_noise_mask": all_sig_noise_mask,
}
     def forward(
         self,
         x: List[torch.Tensor],
         t,
         cap_feats: List[torch.Tensor],
+        siglip_feats = None,
+        image_noise_mask = None,
         patch_size=2,
         f_patch_size=1,
         use_gradient_checkpointing=False,
         use_gradient_checkpointing_offload=False,
     ):
-        assert patch_size in self.all_patch_size
-        assert f_patch_size in self.all_f_patch_size
-        bsz = len(x)
-        device = x[0].device
-        t = t * self.t_scale
-        t = self.t_embedder(t)
-        adaln_input = t
-        (
-            x,
-            cap_feats,
-            x_size,
-            x_pos_ids,
-            cap_pos_ids,
-            x_inner_pad_mask,
-            cap_inner_pad_mask,
-        ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
+        assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size
+        omni_mode = isinstance(x[0], list)
+        device = x[0][-1].device if omni_mode else x[0].device
+        if omni_mode:
+            # Dual embeddings: noisy (t) and clean (t=1)
+            t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1])
+            t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1])
+            adaln_input = None
+        else:
+            # Single embedding for all tokens
+            adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0])
+            t_noisy = t_clean = None
+        # Patchify
+        if omni_mode:
+            (
+                x,
+                cap_feats,
+                siglip_feats,
+                x_size,
+                x_pos_ids,
+                cap_pos_ids,
+                siglip_pos_ids,
+                x_pad_mask,
+                cap_pad_mask,
+                siglip_pad_mask,
+                x_pos_offsets,
+                x_noise_mask,
+                cap_noise_mask,
+                siglip_noise_mask,
+            ) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask)
+        else:
+            (
+                x,
+                cap_feats,
+                x_size,
+                x_pos_ids,
+                cap_pos_ids,
+                x_pad_mask,
+                cap_pad_mask,
+            ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
+            x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None
         # x embed & refine
-        x_item_seqlens = [len(_) for _ in x]
-        assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
-        x_max_item_seqlen = max(x_item_seqlens)
-        x = torch.cat(x, dim=0)
-        x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
-        x[torch.cat(x_inner_pad_mask)] = self.x_pad_token.to(dtype=x.dtype, device=x.device)
-        x = list(x.split(x_item_seqlens, dim=0))
-        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
-        x = pad_sequence(x, batch_first=True, padding_value=0.0)
-        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
-        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
-        for i, seq_len in enumerate(x_item_seqlens):
-            x_attn_mask[i, :seq_len] = 1
+        x_seqlens = [len(xi) for xi in x]
+        x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0))  # embed
+        x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence(
+            list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device
+        )
         for layer in self.noise_refiner:
             x = gradient_checkpoint_forward(
                 layer,
                 use_gradient_checkpointing=use_gradient_checkpointing,
                 use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
-                x=x,
-                attn_mask=x_attn_mask,
-                freqs_cis=x_freqs_cis,
-                adaln_input=adaln_input,
+                x=x, attn_mask=x_mask, freqs_cis=x_freqs, adaln_input=adaln_input, noise_mask=x_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean,
             )
-        # cap embed & refine
-        cap_item_seqlens = [len(_) for _ in cap_feats]
-        assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
-        cap_max_item_seqlen = max(cap_item_seqlens)
-        cap_feats = torch.cat(cap_feats, dim=0)
-        cap_feats = self.cap_embedder(cap_feats)
-        cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token.to(dtype=x.dtype, device=x.device)
-        cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
-        cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
-        cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
-        cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
-        cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
-        for i, seq_len in enumerate(cap_item_seqlens):
-            cap_attn_mask[i, :seq_len] = 1
+        # Cap embed & refine
+        cap_seqlens = [len(ci) for ci in cap_feats]
+        cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0))  # embed
+        cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence(
+            list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device
+        )
         for layer in self.context_refiner:
             cap_feats = gradient_checkpoint_forward(
@@ -581,41 +1085,68 @@ class ZImageDiT(nn.Module):
                 use_gradient_checkpointing=use_gradient_checkpointing,
                 use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                 x=cap_feats,
-                attn_mask=cap_attn_mask,
-                freqs_cis=cap_freqs_cis,
+                attn_mask=cap_mask,
+                freqs_cis=cap_freqs,
             )
-        # unified
-        unified = []
-        unified_freqs_cis = []
-        for i in range(bsz):
-            x_len = x_item_seqlens[i]
-            cap_len = cap_item_seqlens[i]
-            unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
-            unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
-        unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
-        assert unified_item_seqlens == [len(_) for _ in unified]
-        unified_max_item_seqlen = max(unified_item_seqlens)
-        unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
-        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
-        unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
-        for i, seq_len in enumerate(unified_item_seqlens):
-            unified_attn_mask[i, :seq_len] = 1
-        for layer in self.layers:
+        # Siglip embed & refine
+        siglip_seqlens = siglip_freqs = None
+        if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None:
+            siglip_seqlens = [len(si) for si in siglip_feats]
+            siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0))  # embed
+            siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence(
+                list(siglip_feats.split(siglip_seqlens, dim=0)),
+                siglip_pos_ids,
+                siglip_pad_mask,
+                self.siglip_pad_token,
+                None,
+                device,
+            )
+            for layer in self.siglip_refiner:
+                siglip_feats = gradient_checkpoint_forward(
+                    layer,
+                    use_gradient_checkpointing=use_gradient_checkpointing,
+                    use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+                    x=siglip_feats, attn_mask=siglip_mask, freqs_cis=siglip_freqs,
+                )
+        # Unified sequence
+        unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence(
+            x,
+            x_freqs,
+            x_seqlens,
+            x_noise_mask,
+            cap_feats,
+            cap_freqs,
+            cap_seqlens,
+            cap_noise_mask,
+            siglip_feats,
+            siglip_freqs,
+            siglip_seqlens,
+            siglip_noise_mask,
+            omni_mode,
+            device,
+        )
+        # Main transformer layers
+        for layer_idx, layer in enumerate(self.layers):
             unified = gradient_checkpoint_forward(
                 layer,
                 use_gradient_checkpointing=use_gradient_checkpointing,
                 use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
-                x=unified,
-                attn_mask=unified_attn_mask,
-                freqs_cis=unified_freqs_cis,
-                adaln_input=adaln_input,
+                x=unified, attn_mask=unified_mask, freqs_cis=unified_freqs, adaln_input=adaln_input, noise_mask=unified_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean
             )
-        unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
-        unified = list(unified.unbind(dim=0))
-        x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
-        return x, {}
+        unified = (
+            self.all_final_layer[f"{patch_size}-{f_patch_size}"](
+                unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean
+            )
+            if omni_mode
+            else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input)
+        )
+        # Unpatchify
+        x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets)
+        return x
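
A self-contained check of the per-token modulation used throughout this file: select_per_token routes tokens flagged 1 in the noise mask to the noisy-timestep modulation and tokens flagged 0 to the clean one. Toy sizes only:

import torch

batch, seq_len, dim = 2, 6, 4
value_noisy = torch.ones(batch, dim)      # per-sample modulation for noisy tokens
value_clean = torch.zeros(batch, dim)     # per-sample modulation for clean tokens
noise_mask = torch.tensor([[1, 1, 0, 0, 0, 0],
                           [1, 0, 1, 0, 1, 0]])
selected = torch.where(
    noise_mask.unsqueeze(-1) == 1,
    value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
    value_clean.unsqueeze(1).expand(-1, seq_len, -1),
)
print(selected[0, :, 0])  # tensor([1., 1., 0., 0., 0., 0.])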

View File

@@ -0,0 +1,189 @@
import torch
from .qwen_image_image2lora import ImageEmbeddingToLoraMatrix, SequencialMLP
class LoRATrainerBlock(torch.nn.Module):
def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024, prefix="transformer_blocks"):
super().__init__()
self.prefix = prefix
self.lora_patterns = lora_patterns
self.block_id = block_id
self.layers = []
for name, lora_a_dim, lora_b_dim in self.lora_patterns:
self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank))
self.layers = torch.nn.ModuleList(self.layers)
if use_residual:
self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)
else:
self.proj_residual = None
def forward(self, x, residual=None):
lora = {}
if self.proj_residual is not None: residual = self.proj_residual(residual)
for lora_pattern, layer in zip(self.lora_patterns, self.layers):
name = lora_pattern[0]
lora_a, lora_b = layer(x, residual=residual)
lora[f"{self.prefix}.{self.block_id}.{name}.lora_A.default.weight"] = lora_a
lora[f"{self.prefix}.{self.block_id}.{name}.lora_B.default.weight"] = lora_b
return lora
class ZImageImage2LoRAComponent(torch.nn.Module):
def __init__(self, lora_patterns, prefix, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):
super().__init__()
self.lora_patterns = lora_patterns
self.num_blocks = num_blocks
self.blocks = []
for lora_patterns in self.lora_patterns:
for block_id in range(self.num_blocks):
self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim, prefix=prefix))
self.blocks = torch.nn.ModuleList(self.blocks)
self.residual_scale = 0.05
self.use_residual = use_residual
def forward(self, x, residual=None):
if residual is not None:
if self.use_residual:
residual = residual * self.residual_scale
else:
residual = None
lora = {}
for block in self.blocks:
lora.update(block(x, residual))
return lora
class ZImageImage2LoRAModel(torch.nn.Module):
def __init__(self, use_residual=False, compress_dim=64, rank=4, residual_length=64+7, residual_mid_dim=1024):
super().__init__()
lora_patterns = [
[
("attention.to_q", 3840, 3840),
("attention.to_k", 3840, 3840),
("attention.to_v", 3840, 3840),
("attention.to_out.0", 3840, 3840),
],
[
("feed_forward.w1", 3840, 10240),
("feed_forward.w2", 10240, 3840),
("feed_forward.w3", 3840, 10240),
],
]
config = {
"lora_patterns": lora_patterns,
"use_residual": use_residual,
"compress_dim": compress_dim,
"rank": rank,
"residual_length": residual_length,
"residual_mid_dim": residual_mid_dim,
}
self.layers_lora = ZImageImage2LoRAComponent(
prefix="layers",
num_blocks=30,
**config,
)
self.context_refiner_lora = ZImageImage2LoRAComponent(
prefix="context_refiner",
num_blocks=2,
**config,
)
self.noise_refiner_lora = ZImageImage2LoRAComponent(
prefix="noise_refiner",
num_blocks=2,
**config,
)
def forward(self, x, residual=None):
lora = {}
lora.update(self.layers_lora(x, residual=residual))
lora.update(self.context_refiner_lora(x, residual=residual))
lora.update(self.noise_refiner_lora(x, residual=residual))
return lora
def initialize_weights(self):
state_dict = self.state_dict()
for name in state_dict:
if ".proj_a." in name:
state_dict[name] = state_dict[name] * 0.3
elif ".proj_b.proj_out." in name:
state_dict[name] = state_dict[name] * 0
elif ".proj_residual.proj_out." in name:
state_dict[name] = state_dict[name] * 0.3
self.load_state_dict(state_dict)
class ImageEmb2LoRAWeightCompressed(torch.nn.Module):
def __init__(self, in_dim, out_dim, emb_dim, rank):
super().__init__()
self.lora_a = torch.nn.Parameter(torch.randn((rank, in_dim)))
self.lora_b = torch.nn.Parameter(torch.randn((out_dim, rank)))
self.proj = torch.nn.Linear(emb_dim, rank * rank, bias=True)
self.rank = rank
def forward(self, x):
x = self.proj(x).view(self.rank, self.rank)
lora_a = x @ self.lora_a
lora_b = self.lora_b
return lora_a, lora_b
class ZImageImage2LoRAModelCompressed(torch.nn.Module):
def __init__(self, emb_dim=1536+4096, rank=32):
super().__init__()
target_layers = [
("attention.to_q", 3840, 3840),
("attention.to_k", 3840, 3840),
("attention.to_v", 3840, 3840),
("attention.to_out.0", 3840, 3840),
("feed_forward.w1", 3840, 10240),
("feed_forward.w2", 10240, 3840),
("feed_forward.w3", 3840, 10240),
]
self.lora_patterns = [
{
"prefix": "layers",
"num_layers": 30,
"target_layers": target_layers,
},
{
"prefix": "context_refiner",
"num_layers": 2,
"target_layers": target_layers,
},
{
"prefix": "noise_refiner",
"num_layers": 2,
"target_layers": target_layers,
},
]
module_dict = {}
for lora_pattern in self.lora_patterns:
prefix, num_layers, target_layers = lora_pattern["prefix"], lora_pattern["num_layers"], lora_pattern["target_layers"]
for layer_id in range(num_layers):
for layer_name, in_dim, out_dim in target_layers:
name = f"{prefix}.{layer_id}.{layer_name}".replace(".", "___")
model = ImageEmb2LoRAWeightCompressed(in_dim, out_dim, emb_dim, rank)
module_dict[name] = model
self.module_dict = torch.nn.ModuleDict(module_dict)
def forward(self, x, residual=None):
lora = {}
for name, module in self.module_dict.items():
name = name.replace("___", ".")
name_a, name_b = f"{name}.lora_A.default.weight", f"{name}.lora_B.default.weight"
lora_a, lora_b = module(x)
lora[name_a] = lora_a
lora[name_b] = lora_b
return lora
def initialize_weights(self):
state_dict = self.state_dict()
for name in state_dict:
if "lora_b" in name:
state_dict[name] = state_dict[name] * 0
elif "lora_a" in name:
state_dict[name] = state_dict[name] * 0.2
elif "proj.weight" in name:
state_dict[name] = state_dict[name] * 0.2
self.load_state_dict(state_dict)
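
The compressed variant above predicts only a rank x rank mixing matrix per target layer and reuses shared low-rank bases, which keeps the per-layer head tiny. A shape walk-through using the module's own dimensions:

import torch

in_dim, out_dim, emb_dim, rank = 3840, 3840, 1536 + 4096, 32
proj = torch.nn.Linear(emb_dim, rank * rank)   # per-layer head
base_a = torch.randn(rank, in_dim)             # shared lora_a basis
base_b = torch.randn(out_dim, rank)            # shared lora_b basis
emb = torch.randn(emb_dim)                     # pooled image embedding
mix = proj(emb).view(rank, rank)               # image-conditioned mixing matrix
lora_a = mix @ base_a                          # (rank, in_dim)
lora_b = base_b                                # (out_dim, rank), image-independent
delta_w = lora_b @ lora_a                      # (out_dim, in_dim) low-rank update
print(lora_a.shape, lora_b.shape, delta_w.shape)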

View File

@@ -3,38 +3,101 @@ import torch
 class ZImageTextEncoder(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, model_size="4B"):
         super().__init__()
-        config = Qwen3Config(**{
-            "architectures": [
-                "Qwen3ForCausalLM"
-            ],
-            "attention_bias": False,
-            "attention_dropout": 0.0,
-            "bos_token_id": 151643,
-            "eos_token_id": 151645,
-            "head_dim": 128,
-            "hidden_act": "silu",
-            "hidden_size": 2560,
-            "initializer_range": 0.02,
-            "intermediate_size": 9728,
-            "max_position_embeddings": 40960,
-            "max_window_layers": 36,
-            "model_type": "qwen3",
-            "num_attention_heads": 32,
-            "num_hidden_layers": 36,
-            "num_key_value_heads": 8,
-            "rms_norm_eps": 1e-06,
-            "rope_scaling": None,
-            "rope_theta": 1000000,
-            "sliding_window": None,
-            "tie_word_embeddings": True,
-            "torch_dtype": "bfloat16",
-            "transformers_version": "4.51.0",
-            "use_cache": True,
-            "use_sliding_window": False,
-            "vocab_size": 151936
-        })
+        config_dict = {
+            "0.6B": Qwen3Config(**{
+                "architectures": [
+                    "Qwen3ForCausalLM"
+                ],
+                "attention_bias": False,
+                "attention_dropout": 0.0,
+                "bos_token_id": 151643,
+                "eos_token_id": 151645,
+                "head_dim": 128,
+                "hidden_act": "silu",
+                "hidden_size": 1024,
+                "initializer_range": 0.02,
+                "intermediate_size": 3072,
+                "max_position_embeddings": 40960,
+                "max_window_layers": 28,
+                "model_type": "qwen3",
+                "num_attention_heads": 16,
+                "num_hidden_layers": 28,
+                "num_key_value_heads": 8,
+                "rms_norm_eps": 1e-06,
+                "rope_scaling": None,
+                "rope_theta": 1000000,
+                "sliding_window": None,
+                "tie_word_embeddings": True,
+                "torch_dtype": "bfloat16",
+                "transformers_version": "4.51.0",
+                "use_cache": True,
+                "use_sliding_window": False,
+                "vocab_size": 151936
+            }),
"4B": Qwen3Config(**{
"architectures": [
"Qwen3ForCausalLM"
],
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2560,
"initializer_range": 0.02,
"intermediate_size": 9728,
"max_position_embeddings": 40960,
"max_window_layers": 36,
"model_type": "qwen3",
"num_attention_heads": 32,
"num_hidden_layers": 36,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 1000000,
"sliding_window": None,
"tie_word_embeddings": True,
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0",
"use_cache": True,
"use_sliding_window": False,
"vocab_size": 151936
}),
"8B": Qwen3Config(**{
"architectures": [
"Qwen3ForCausalLM"
],
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"dtype": "bfloat16",
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 12288,
"max_position_embeddings": 40960,
"max_window_layers": 36,
"model_type": "qwen3",
"num_attention_heads": 32,
"num_hidden_layers": 36,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 1000000,
"sliding_window": None,
"tie_word_embeddings": False,
"transformers_version": "4.56.1",
"use_cache": True,
"use_sliding_window": False,
"vocab_size": 151936
})
}
config = config_dict[model_size]
self.model = Qwen3Model(config)
def forward(self, *args, **kwargs):
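
With this change the encoder skeleton can be sized to match different Qwen3 checkpoints. A minimal sketch; the import path is assumed from the relative imports elsewhere in this diff, and weight loading is out of scope:

from diffsynth.models.z_image_text_encoder import ZImageTextEncoder  # path assumed

text_encoder = ZImageTextEncoder(model_size="0.6B")  # 28 layers, hidden size 1024
text_encoder = ZImageTextEncoder()                   # default "4B": 36 layers, hidden size 2560
text_encoder = ZImageTextEncoder(model_size="8B")    # 36 layers, hidden size 4096, untied embeddings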

View File

@@ -1,4 +1,4 @@
-import torch, math
+import torch, math, torchvision
 from PIL import Image
 from typing import Union
 from tqdm import tqdm
@@ -6,25 +6,28 @@ from einops import rearrange
 import numpy as np
 from typing import Union, List, Optional, Tuple
+from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
-from transformers import AutoProcessor
+from transformers import AutoProcessor, AutoTokenizer
 from ..models.flux2_text_encoder import Flux2TextEncoder
 from ..models.flux2_dit import Flux2DiT
 from ..models.flux2_vae import Flux2VAE
+from ..models.z_image_text_encoder import ZImageTextEncoder

 class Flux2ImagePipeline(BasePipeline):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__(
             device=device, torch_dtype=torch_dtype,
             height_division_factor=16, width_division_factor=16,
         )
         self.scheduler = FlowMatchScheduler("FLUX.2")
         self.text_encoder: Flux2TextEncoder = None
+        self.text_encoder_qwen3: ZImageTextEncoder = None
         self.dit: Flux2DiT = None
         self.vae: Flux2VAE = None
         self.tokenizer: AutoProcessor = None
@@ -32,8 +35,10 @@ class Flux2ImagePipeline(BasePipeline):
         self.units = [
             Flux2Unit_ShapeChecker(),
             Flux2Unit_PromptEmbedder(),
+            Flux2Unit_Qwen3PromptEmbedder(),
             Flux2Unit_NoiseInitializer(),
             Flux2Unit_InputImageEmbedder(),
+            Flux2Unit_EditImageEmbedder(),
             Flux2Unit_ImageIDs(),
         ]
         self.model_fn = model_fn_flux2
@@ -42,7 +47,7 @@ class Flux2ImagePipeline(BasePipeline):
     @staticmethod
     def from_pretrained(
         torch_dtype: torch.dtype = torch.bfloat16,
-        device: Union[str, torch.device] = "cuda",
+        device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
         vram_limit: float = None,
@@ -53,11 +58,12 @@ class Flux2ImagePipeline(BasePipeline):
         # Fetch models
         pipe.text_encoder = model_pool.fetch_model("flux2_text_encoder")
+        pipe.text_encoder_qwen3 = model_pool.fetch_model("z_image_text_encoder")
         pipe.dit = model_pool.fetch_model("flux2_dit")
         pipe.vae = model_pool.fetch_model("flux2_vae")
         if tokenizer_config is not None:
             tokenizer_config.download_if_necessary()
-            pipe.tokenizer = AutoProcessor.from_pretrained(tokenizer_config.path)
+            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)

         # VRAM Management
         pipe.vram_management_enabled = pipe.check_vram_management_state()
@@ -75,6 +81,9 @@ class Flux2ImagePipeline(BasePipeline):
         # Image
         input_image: Image.Image = None,
         denoising_strength: float = 1.0,
+        # Edit
+        edit_image: Union[Image.Image, List[Image.Image]] = None,
+        edit_image_auto_resize: bool = True,
         # Shape
         height: int = 1024,
         width: int = 1024,
@@ -98,6 +107,7 @@ class Flux2ImagePipeline(BasePipeline):
         inputs_shared = {
             "cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
             "input_image": input_image, "denoising_strength": denoising_strength,
+            "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
             "height": height, "width": width,
             "seed": seed, "rand_device": rand_device,
             "num_inference_steps": num_inference_steps,
@@ -275,6 +285,10 @@ class Flux2Unit_PromptEmbedder(PipelineUnit):
         return prompt_embeds, text_ids

     def process(self, pipe: Flux2ImagePipeline, prompt):
+        # Skip if Qwen3 text encoder is available (handled by Qwen3PromptEmbedder)
+        if pipe.text_encoder_qwen3 is not None:
+            return {}
         pipe.load_models_to_device(self.onload_model_names)
         prompt_embeds, text_ids = self.encode_prompt(
             pipe.text_encoder, pipe.tokenizer, prompt,
@@ -283,6 +297,136 @@ class Flux2Unit_PromptEmbedder(PipelineUnit):
         return {"prompt_embeds": prompt_embeds, "text_ids": text_ids}
class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_emb", "prompt_emb_mask"),
onload_model_names=("text_encoder_qwen3",)
)
self.hidden_states_layers = (9, 18, 27) # Qwen3 layers
def get_qwen3_prompt_embeds(
self,
text_encoder: ZImageTextEncoder,
tokenizer: AutoTokenizer,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
all_input_ids = []
all_attention_masks = []
for single_prompt in prompt:
messages = [{"role": "user", "content": single_prompt}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
all_input_ids.append(inputs["input_ids"])
all_attention_masks.append(inputs["attention_mask"])
input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
# Forward pass through the model
with torch.inference_mode():
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
def prepare_text_ids(
self,
x: torch.Tensor, # (B, L, D) or (L, D)
t_coord: Optional[torch.Tensor] = None,
):
B, L, _ = x.shape
out_ids = []
for i in range(B):
t = torch.arange(1) if t_coord is None else t_coord[i]
h = torch.arange(1)
w = torch.arange(1)
l = torch.arange(L)
coords = torch.cartesian_prod(t, h, w, l)
out_ids.append(coords)
return torch.stack(out_ids)
def encode_prompt(
self,
text_encoder: ZImageTextEncoder,
tokenizer: AutoTokenizer,
prompt: Union[str, List[str]],
dtype = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_embeds = self.get_qwen3_prompt_embeds(
text_encoder=text_encoder,
tokenizer=tokenizer,
prompt=prompt,
dtype=dtype,
device=device,
max_sequence_length=max_sequence_length,
)
batch_size, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
text_ids = self.prepare_text_ids(prompt_embeds)
text_ids = text_ids.to(device)
return prompt_embeds, text_ids
def process(self, pipe: Flux2ImagePipeline, prompt):
# Check if Qwen3 text encoder is available
if pipe.text_encoder_qwen3 is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
prompt_embeds, text_ids = self.encode_prompt(
pipe.text_encoder_qwen3, pipe.tokenizer, prompt,
dtype=pipe.torch_dtype, device=pipe.device,
)
return {"prompt_embeds": prompt_embeds, "text_ids": text_ids}
class Flux2Unit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
@@ -318,6 +462,75 @@ class Flux2Unit_InputImageEmbedder(PipelineUnit):
return {"latents": latents, "input_latents": input_latents} return {"latents": latents, "input_latents": input_latents}
class Flux2Unit_EditImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "edit_image_auto_resize"),
output_params=("edit_latents", "edit_image_ids"),
onload_model_names=("vae",)
)
def calculate_dimensions(self, target_area, ratio):
import math
width = math.sqrt(target_area * ratio)
height = width / ratio
width = round(width / 32) * 32
height = round(height / 32) * 32
return width, height
def crop_and_resize(self, image, target_height, target_width):
width, height = image.size
scale = max(target_width / width, target_height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
return image
def edit_image_auto_resize(self, edit_image):
calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
return self.crop_and_resize(edit_image, calculated_height, calculated_width)
def process_image_ids(self, image_latents, scale=10):
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
t_coords = [t.view(-1) for t in t_coords]
image_latent_ids = []
for x, t in zip(image_latents, t_coords):
x = x.squeeze(0)
_, height, width = x.shape
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
image_latent_ids.append(x_ids)
image_latent_ids = torch.cat(image_latent_ids, dim=0)
image_latent_ids = image_latent_ids.unsqueeze(0)
return image_latent_ids
def process(self, pipe: Flux2ImagePipeline, edit_image, edit_image_auto_resize):
if edit_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
if isinstance(edit_image, Image.Image):
edit_image = [edit_image]
resized_edit_image, edit_latents = [], []
for image in edit_image:
# Preprocess
if edit_image_auto_resize is None or edit_image_auto_resize:
image = self.edit_image_auto_resize(image)
resized_edit_image.append(image)
# Encode
image = pipe.preprocess_image(image)
latents = pipe.vae.encode(image)
edit_latents.append(latents)
edit_image_ids = self.process_image_ids(edit_latents).to(pipe.device)
edit_latents = torch.concat([rearrange(latents, "B C H W -> B (H W) C") for latents in edit_latents], dim=1)
return {"edit_latents": edit_latents, "edit_image_ids": edit_image_ids}
class Flux2Unit_ImageIDs(PipelineUnit):
def __init__(self):
super().__init__(
@@ -352,10 +565,17 @@ def model_fn_flux2(
prompt_embeds=None,
text_ids=None,
image_ids=None,
edit_latents=None,
edit_image_ids=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
image_seq_len = latents.shape[1]
if edit_latents is not None:
image_seq_len = latents.shape[1]
latents = torch.concat([latents, edit_latents], dim=1)
image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
model_output = dit(
hidden_states=latents,
@@ -367,4 +587,5 @@ def model_fn_flux2(
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
model_output = model_output[:, :image_seq_len]
return model_output
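model_fn_flux2 handles editing purely by token concatenation: edit latents are appended to the target-image tokens along the sequence dimension (with their own positional ids), and the output is sliced back down so only target tokens are denoised. A minimal, self-contained sketch of that concat-then-slice pattern, with toy tensor sizes standing in for the real model:

import torch

B, L_img, L_edit, D = 1, 4096, 1024, 64          # toy sizes, not the real dims
latents = torch.randn(B, L_img, D)               # target-image tokens
edit_latents = torch.randn(B, L_edit, D)         # reference-image tokens

image_seq_len = latents.shape[1]
tokens = torch.concat([latents, edit_latents], dim=1)   # (B, L_img + L_edit, D)
model_output = tokens                                    # stand-in for dit(...)
model_output = model_output[:, :image_seq_len]           # keep only target tokens
assert model_output.shape == (B, L_img, D)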

View File

@@ -6,6 +6,7 @@ from einops import rearrange, repeat
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -55,7 +56,7 @@ class MultiControlNet(torch.nn.Module):
class FluxImagePipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -117,7 +118,7 @@ class FluxImagePipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"),
tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"),
@@ -377,7 +378,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit):
text_encoder_2,
prompt,
positive=True,
device=get_device_type(),
t5_sequence_length=512,
):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -558,7 +559,7 @@ class FluxImageUnit_EntityControl(PipelineUnit):
text_encoder_2,
prompt,
positive=True,
device=get_device_type(),
t5_sequence_length=512,
):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -793,7 +794,7 @@ class FluxImageUnit_ValueControl(PipelineUnit):
class InfinitYou(torch.nn.Module):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__()
from facexlib.recognition import init_recognition_model
from insightface.app import FaceAnalysis

View File

@@ -0,0 +1,550 @@
import torch, types
import numpy as np
from PIL import Image
from einops import rearrange, repeat
from tqdm import tqdm
from typing import Optional, Union
from transformers import AutoImageProcessor, Gemma3Processor
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
from ..models.ltx2_dit import LTXModel
from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier
from ..models.ltx2_upsampler import LTX2LatentUpsampler
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
from ..utils.data.media_io_ltx2 import ltx2_preprocess
class LTX2AudioVideoPipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device,
torch_dtype=torch_dtype,
height_division_factor=32,
width_division_factor=32,
time_division_factor=8,
time_division_remainder=1,
)
self.scheduler = FlowMatchScheduler("LTX-2")
self.text_encoder: LTX2TextEncoder = None
self.tokenizer: LTXVGemmaTokenizer = None
self.processor: Gemma3Processor = None
self.text_encoder_post_modules: LTX2TextEncoderPostModules = None
self.dit: LTXModel = None
self.video_vae_encoder: LTX2VideoEncoder = None
self.video_vae_decoder: LTX2VideoDecoder = None
self.audio_vae_encoder: LTX2AudioEncoder = None
self.audio_vae_decoder: LTX2AudioDecoder = None
self.audio_vocoder: LTX2Vocoder = None
self.upsampler: LTX2LatentUpsampler = None
self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)
self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)
self.in_iteration_models = ("dit",)
self.units = [
LTX2AudioVideoUnit_PipelineChecker(),
LTX2AudioVideoUnit_ShapeChecker(),
LTX2AudioVideoUnit_PromptEmbedder(),
LTX2AudioVideoUnit_NoiseInitializer(),
LTX2AudioVideoUnit_InputVideoEmbedder(),
LTX2AudioVideoUnit_InputImagesEmbedder(),
]
self.model_fn = model_fn_ltx2
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
stage2_lora_config: Optional[ModelConfig] = None,
vram_limit: float = None,
):
# Initialize pipeline
pipe = LTX2AudioVideoPipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("ltx2_text_encoder")
tokenizer_config.download_if_necessary()
pipe.tokenizer = LTXVGemmaTokenizer(tokenizer_path=tokenizer_config.path)
image_processor = AutoImageProcessor.from_pretrained(tokenizer_config.path, local_files_only=True)
pipe.processor = Gemma3Processor(image_processor=image_processor, tokenizer=pipe.tokenizer.tokenizer)
pipe.text_encoder_post_modules = model_pool.fetch_model("ltx2_text_encoder_post_modules")
pipe.dit = model_pool.fetch_model("ltx2_dit")
pipe.video_vae_encoder = model_pool.fetch_model("ltx2_video_vae_encoder")
pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder")
pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
# Stage 2
if stage2_lora_config is not None:
stage2_lora_config.download_if_necessary()
pipe.stage2_lora_path = stage2_lora_config.path
# Optional, currently not used
# pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
if inputs_shared["use_two_stage_pipeline"]:
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
self.load_models_to_device(['upsampler'])
latent = self.upsampler(latent)
latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
self.scheduler.set_timesteps(special_case="stage2")
inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
denoise_mask_video = 1.0
if inputs_shared.get("input_images", None) is not None:
latent, denoise_mask_video, initial_latents = self.apply_input_images_to_latents(
latent, inputs_shared.pop("input_latents"), inputs_shared["input_images_indexes"],
inputs_shared["input_images_strength"], latent.clone())
inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video})
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[
"video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latent
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (
1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
self.load_models_to_device(self.in_iteration_models)
if not inputs_shared["use_distilled_pipeline"]:
self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
self.model_fn, 1.0, inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id,
noise_pred=noise_pred_video, inpaint_mask=inputs_shared.get("denoise_mask_video", None),
input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
noise_pred=noise_pred_audio, **inputs_shared)
return inputs_shared
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: Optional[str] = "",
# Image-to-video
denoising_strength: float = 1.0,
input_images: Optional[list[Image.Image]] = None,
input_images_indexes: Optional[list[int]] = None,
input_images_strength: Optional[float] = 1.0,
# Randomness
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
# Shape
height: Optional[int] = 512,
width: Optional[int] = 768,
num_frames=121,
# Classifier-free guidance
cfg_scale: Optional[float] = 3.0,
cfg_merge: Optional[bool] = False,
# Scheduler
num_inference_steps: Optional[int] = 40,
# VAE tiling
tiled: Optional[bool] = True,
tile_size_in_pixels: Optional[int] = 512,
tile_overlap_in_pixels: Optional[int] = 128,
tile_size_in_frames: Optional[int] = 128,
tile_overlap_in_frames: Optional[int] = 24,
# Special Pipelines
use_two_stage_pipeline: Optional[bool] = False,
use_distilled_pipeline: Optional[bool] = False,
# progress_bar
progress_bar_cmd=tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength,
special_case="ditilled_stage1" if use_distilled_pipeline else None)
# Inputs
inputs_posi = {
"prompt": prompt,
}
inputs_nega = {
"negative_prompt": negative_prompt,
}
inputs_shared = {
"input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
"seed": seed, "rand_device": rand_device,
"height": height, "width": width, "num_frames": num_frames,
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline,
"video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise Stage 1
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video,
inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
noise_pred=noise_pred_audio, **inputs_shared)
# Denoise Stage 2
inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd)
# Decode
self.load_models_to_device(['video_vae_decoder'])
video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels,
tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
video = self.vae_output_to_video(video)
self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
return video, decoded_audio
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None, num_frames=121):
b, _, f, h, w = latents.shape
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device)
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
for idx, input_latent in zip(input_indexes, input_latents):
idx = min(max(1 + (idx-1) // 8, 0), f - 1)
input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
latents = latents * denoise_mask + initial_latents * (1.0 - denoise_mask)
return latents, denoise_mask, initial_latents
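# A worked example (illustrative) of the frame-index mapping above, with the
# pipeline's temporal factor of 8 (time_division_factor=8):
#   idx = min(max(1 + (idx - 1) // 8, 0), f - 1)
#   pixel frame 0   -> latent frame 0
#   pixel frame 1   -> latent frame 1
#   pixel frame 9   -> latent frame 2
#   pixel frame 120 -> latent frame 15 (clamped to f - 1 for shorter latents)
# With input_strength = 1.0 the denoise mask on those frames becomes 0, freezing
# the conditioning frames; smaller strengths let them be partially re-noised.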
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
input_params=("use_distilled_pipeline", "use_two_stage_pipeline"),
output_params=("use_two_stage_pipeline", "cfg_scale")
)
def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
if inputs_shared.get("use_distilled_pipeline", False):
inputs_shared["use_two_stage_pipeline"] = True
inputs_shared["cfg_scale"] = 1.0
print(f"Distilled pipeline requested, setting use_two_stage_pipeline to True, disable CFG by setting cfg_scale to 1.0.")
if inputs_shared.get("use_two_stage_pipeline", False):
# The distilled pipeline also uses two stages, but it does not need the stage-2 LoRA
if not inputs_shared.get("use_distilled_pipeline", False):
if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
return inputs_shared, inputs_posi, inputs_nega
class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
"""
For two-stage pipelines, the resolution must be divisible by 64.
For one-stage pipelines, the resolution must be divisible by 32.
"""
def __init__(self):
super().__init__(
input_params=("height", "width", "num_frames", "use_two_stage_pipeline"),
output_params=("height", "width", "num_frames"),
)
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False):
# The division factors live on the pipeline, so temporarily raise them to 64 for two-stage runs
if use_two_stage_pipeline:
pipe.width_division_factor = 64
pipe.height_division_factor = 64
height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
if use_two_stage_pipeline:
pipe.width_division_factor = 32
pipe.height_division_factor = 32
return {"height": height, "width": width, "num_frames": num_frames}
class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("video_context", "audio_context"),
onload_model_names=("text_encoder", "text_encoder_post_modules"),
)
def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return (attention_mask - 1).to(dtype).reshape(
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
def _run_connectors(self, pipe, encoded_input: torch.Tensor,
attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype)
encoded, encoded_connector_attention_mask = pipe.text_encoder_post_modules.embeddings_connector(
encoded_input,
connector_attention_mask,
)
# restore the mask values to int64
attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64)
attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
encoded = encoded * attention_mask
encoded_for_audio, _ = pipe.text_encoder_post_modules.audio_embeddings_connector(
encoded_input, connector_attention_mask)
return encoded, encoded_for_audio, attention_mask.squeeze(-1)
def _norm_and_concat_padded_batch(
self,
encoded_text: torch.Tensor,
sequence_lengths: torch.Tensor,
padding_side: str = "right",
) -> torch.Tensor:
"""Normalize and flatten multi-layer hidden states, respecting padding.
Performs per-batch, per-layer normalization using masked mean and range,
then concatenates across the layer dimension.
Args:
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
sequence_lengths: Number of valid (non-padded) tokens per batch item.
padding_side: Whether padding is on "left" or "right".
Returns:
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
with padded positions zeroed out.
"""
b, t, d, l = encoded_text.shape # noqa: E741
device = encoded_text.device
# Build mask: [B, T, 1, 1]
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
if padding_side == "right":
# For right padding, valid tokens are from 0 to sequence_length-1
mask = token_indices < sequence_lengths[:, None] # [B, T]
elif padding_side == "left":
# For left padding, valid tokens are from (T - sequence_length) to T-1
start_indices = t - sequence_lengths[:, None] # [B, 1]
mask = token_indices >= start_indices # [B, T]
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
mask = rearrange(mask, "b t -> b t 1 1")
eps = 1e-6
# Compute masked mean: [B, 1, 1, L]
masked = encoded_text.masked_fill(~mask, 0.0)
denom = (sequence_lengths * d).view(b, 1, 1, 1)
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
# Compute masked min/max: [B, 1, 1, L]
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
range_ = x_max - x_min
# Normalize only the valid tokens
normed = 8 * (encoded_text - mean) / (range_ + eps)
# concat to be [Batch, T, D * L] - this preserves the original structure
normed = normed.reshape(b, t, -1) # [B, T, D * L]
# Apply mask to preserve original padding (set padded positions to 0)
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
normed = normed.masked_fill(~mask_flattened, 0.0)
return normed
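# Note on the scaling above (illustrative): because mean is the masked mean and all
# valid values lie in [min, max], 8 * (x - mean) / (range + eps) maps valid tokens
# into an interval of total width 8; e.g. min=0, max=2, mean=1 gives values in [-4, 4].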
def _run_feature_extractor(self,
pipe,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
padding_side: str = "right") -> torch.Tensor:
encoded_text_features = torch.stack(hidden_states, dim=-1)
encoded_text_features_dtype = encoded_text_features.dtype
sequence_lengths = attention_mask.sum(dim=-1)
normed_concated_encoded_text_features = self._norm_and_concat_padded_batch(encoded_text_features,
sequence_lengths,
padding_side=padding_side)
return pipe.text_encoder_post_modules.feature_extractor_linear(
normed_concated_encoded_text_features.to(encoded_text_features_dtype))
def _preprocess_text(
self,
pipe,
text: str,
padding_side: str = "left",
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""
Encode a given string into feature tensors suitable for downstream tasks.
Args:
text (str): Input string to encode.
Returns:
tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask.
"""
token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device)
attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device)
outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
projected = self._run_feature_extractor(pipe,
hidden_states=outputs.hidden_states,
attention_mask=attention_mask,
padding_side=padding_side)
return projected, attention_mask
def encode_prompt(self, pipe, text, padding_side="left"):
encoded_inputs, attention_mask = self._preprocess_text(pipe, text, padding_side)
video_encoding, audio_encoding, attention_mask = self._run_connectors(pipe, encoded_inputs, attention_mask)
return video_encoding, audio_encoding, attention_mask
def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
pipe.load_models_to_device(self.onload_model_names)
video_context, audio_context, _ = self.encode_prompt(pipe, prompt)
return {"video_context": video_context, "audio_context": audio_context}
class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "num_frames", "seed", "rand_device", "use_two_stage_pipeline"),
output_params=("video_noise", "audio_noise",),
)
def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
video_positions = video_positions.to(pipe.torch_dtype)
audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape)
audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
return {
"video_noise": video_noise,
"audio_noise": audio_noise,
"video_positions": video_positions,
"audio_positions": audio_positions,
"video_latent_shape": video_latent_shape,
"audio_latent_shape": audio_latent_shape
}
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0, use_two_stage_pipeline=False):
if use_two_stage_pipeline:
stage1_dict = self.process_stage(pipe, height // 2, width // 2, num_frames, seed, rand_device, frame_rate)
stage2_dict = self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
initial_dict = stage1_dict
initial_dict.update({"stage2_" + k: v for k, v in stage2_dict.items()})
return initial_dict
else:
return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
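# Illustrative note: in two-stage mode this unit returns stage-1 tensors at half
# resolution plus the same keys prefixed with "stage2_" at full resolution;
# stage2_denoise later promotes them via
#   inputs_shared.update({k.replace("stage2_", ""): v
#                         for k, v in inputs_shared.items() if k.startswith("stage2_")})
# so the same model_fn runs unchanged at both resolutions.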
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_video", "video_noise", "audio_noise", "tiled", "tile_size", "tile_stride"),
output_params=("video_latents", "audio_latents"),
onload_model_names=("video_vae_encoder")
)
def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, audio_noise, tiled, tile_size, tile_stride):
if input_video is None:
return {"video_latents": video_noise, "audio_latents": audio_noise}
else:
# TODO: implement video-to-video
raise NotImplementedError("Video-to-video not implemented yet.")
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "num_frames", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
output_params=("video_latents"),
onload_model_names=("video_vae_encoder")
)
def get_image_latent(self, pipe, input_image, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
image = ltx2_preprocess(np.array(input_image.resize((width, height))))
image = torch.Tensor(np.array(image, dtype=np.float32)).to(dtype=pipe.torch_dtype, device=pipe.device)
image = image / 127.5 - 1.0
image = repeat(image, "H W C -> B C F H W", B=1, F=1)
latent = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
return latent
def process(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, video_latents, height, width, num_frames, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=False):
if input_images is None or len(input_images) == 0:
return {"video_latents": video_latents}
else:
pipe.load_models_to_device(self.onload_model_names)
output_dicts = {}
stage1_height = height // 2 if use_two_stage_pipeline else height
stage1_width = width // 2 if use_two_stage_pipeline else width
stage1_latents = [
self.get_image_latent(pipe, img, stage1_height, stage1_width, tiled, tile_size_in_pixels,
tile_overlap_in_pixels) for img in input_images
]
video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength, num_frames=num_frames)
output_dicts.update({"video_latents": video_latents, "denoise_mask_video": denoise_mask_video, "input_latents_video": initial_latents})
if use_two_stage_pipeline:
stage2_latents = [
self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels,
tile_overlap_in_pixels) for img in input_images
]
output_dicts.update({"stage2_input_latents": stage2_latents})
return output_dicts
def model_fn_ltx2(
dit: LTXModel,
video_latents=None,
video_context=None,
video_positions=None,
video_patchifier=None,
audio_latents=None,
audio_context=None,
audio_positions=None,
audio_patchifier=None,
timestep=None,
denoise_mask_video=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
timestep = timestep.float() / 1000.
# patchify
b, c_v, f, h, w = video_latents.shape
video_latents = video_patchifier.patchify(video_latents)
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
if denoise_mask_video is not None:
video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps
_, c_a, _, mel_bins = audio_latents.shape
audio_latents = audio_patchifier.patchify(audio_latents)
audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
#TODO: support gradient checkpointing in training
vx, ax = dit(
video_latents=video_latents,
video_positions=video_positions,
video_context=video_context,
video_timesteps=video_timesteps,
audio_latents=audio_latents,
audio_positions=audio_positions,
audio_context=audio_context,
audio_timesteps=audio_timesteps,
)
# unpatchify
vx = video_patchifier.unpatchify_video(vx, f, h, w)
ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins)
return vx, ax
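Putting the pieces together, the pipeline is constructed with from_pretrained and driven by a single call that returns both the video frames and the decoded audio waveform. A hedged usage sketch; the module path and every ModelConfig value below are placeholders for illustration, not verified repository names:

import torch
from diffsynth.core import ModelConfig                                    # assumed export location
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline   # hypothetical module path

pipe = LTX2AudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="<ltx2-repo>", origin_file_pattern="<pattern>")],  # placeholders
)
video, audio = pipe(
    prompt="a cat playing piano, cinematic lighting",
    height=512, width=768, num_frames=121,
    num_inference_steps=40, cfg_scale=3.0, seed=0,
)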

View File

@@ -4,7 +4,9 @@ from typing import Union
from tqdm import tqdm
from einops import rearrange
import numpy as np
from math import prod
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -21,7 +23,7 @@ from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel
class QwenImagePipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -47,6 +49,7 @@ class QwenImagePipeline(BasePipeline):
QwenImageUnit_InputImageEmbedder(),
QwenImageUnit_Inpaint(),
QwenImageUnit_EditImageEmbedder(),
QwenImageUnit_LayerInputImageEmbedder(),
QwenImageUnit_ContextImageEmbedder(),
QwenImageUnit_PromptEmbedder(),
QwenImageUnit_EntityControl(),
@@ -58,7 +61,7 @@ class QwenImagePipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
processor_config: ModelConfig = None,
@@ -125,6 +128,11 @@ class QwenImagePipeline(BasePipeline):
edit_image: Image.Image = None,
edit_image_auto_resize: bool = True,
edit_rope_interpolation: bool = False,
# Qwen-Image-Edit-2511
zero_cond_t: bool = False,
# Qwen-Image-Layered
layer_input_image: Image.Image = None,
layer_num: int = None,
# In-context control
context_image: Image.Image = None,
# Tile
@@ -156,6 +164,9 @@ class QwenImagePipeline(BasePipeline):
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation, "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation,
"context_image": context_image, "context_image": context_image,
"zero_cond_t": zero_cond_t,
"layer_input_image": layer_input_image,
"layer_num": layer_num,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -175,7 +186,10 @@ class QwenImagePipeline(BasePipeline):
# Decode
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if layer_num is None:
image = self.vae_output_to_image(image)
else:
image = [self.vae_output_to_image(i, pattern="C H W") for i in image]
self.load_models_to_device([])
return image
@@ -226,12 +240,15 @@ class QwenImageUnit_ShapeChecker(PipelineUnit):
class QwenImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device", "layer_num"),
output_params=("noise",),
)
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device, layer_num):
if layer_num is None:
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
else:
noise = pipe.generate_noise((layer_num + 1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
@@ -248,8 +265,15 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
if isinstance(input_image, list):
input_latents = []
for image in input_image:
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents.append(pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride))
input_latents = torch.concat(input_latents, dim=0)
else:
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
@@ -257,6 +281,22 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
return {"latents": latents, "input_latents": input_latents} return {"latents": latents, "input_latents": input_latents}
class QwenImageUnit_LayerInputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("layer_input_image", "tiled", "tile_size", "tile_stride"),
output_params=("layer_input_latents",),
onload_model_names=("vae",)
)
def process(self, pipe: QwenImagePipeline, layer_input_image, tiled, tile_size, tile_stride):
if layer_input_image is None:
return {}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(layer_input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return {"layer_input_latents": latents}
class QwenImageUnit_Inpaint(PipelineUnit):
def __init__(self):
@@ -673,18 +713,26 @@ def model_fn_qwen_image(
entity_prompt_emb_mask=None,
entity_masks=None,
edit_latents=None,
layer_input_latents=None,
layer_num=None,
context_latents=None,
enable_fp8_attention=False,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
edit_rope_interpolation=False,
zero_cond_t=False,
**kwargs
):
if layer_num is None:
layer_num = 1
img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)]
else:
layer_num = layer_num + 1
img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)] * layer_num
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
timestep = timestep / 1000
image = rearrange(latents, "(B N) C (H P) (W Q) -> B (N H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2, N=layer_num)
image_seq_len = image.shape[1]
if context_latents is not None:
@@ -696,9 +744,27 @@ def model_fn_qwen_image(
img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list]
edit_image = [rearrange(e, "B C (H P) (W Q) -> B (H W) (C P Q)", H=e.shape[2]//2, W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list]
image = torch.cat([image] + edit_image, dim=1)
if layer_input_latents is not None:
layer_num = layer_num + 1
img_shapes += [(layer_input_latents.shape[0], layer_input_latents.shape[2]//2, layer_input_latents.shape[3]//2)]
layer_input_latents = rearrange(layer_input_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
image = torch.cat([image, layer_input_latents], dim=1)
image = dit.img_in(image)
if zero_cond_t:
timestep = torch.cat([timestep, timestep * 0], dim=0)
modulate_index = torch.tensor(
[[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in [img_shapes]],
device=timestep.device,
dtype=torch.int,
)
else:
modulate_index = None
conditioning = dit.time_text_embed(
timestep,
image.dtype,
addition_t_cond=None if not dit.time_text_embed.use_additional_t_cond else torch.tensor([0]).to(device=image.device, dtype=torch.long)
)
if entity_prompt_emb is not None:
text, image_rotary_emb, attention_mask = dit.process_entity_masks(
@@ -728,6 +794,7 @@ def model_fn_qwen_image(
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
enable_fp8_attention=enable_fp8_attention,
modulate_index=modulate_index,
)
if blockwise_controlnet_conditioning is not None:
image_slice = image[:, :image_seq_len].clone()
@@ -738,9 +805,11 @@ def model_fn_qwen_image(
)
image[:, :image_seq_len] = image_slice + controlnet_output
if zero_cond_t:
conditioning = conditioning.chunk(2, dim=0)[0]
image = dit.norm_out(image, conditioning)
image = dit.proj_out(image)
image = image[:, :image_seq_len]
latents = rearrange(image, "B (N H W) (C P Q) -> (B N) C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2, B=1)
return latents
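The layered rearranges above pack every layer's latent into a single token sequence (so all layers attend jointly) and split them back out along the batch dimension afterwards. A minimal shape check, independent of the pipeline, with toy sizes:

import torch
from einops import rearrange

N, C, H, W = 3, 16, 32, 32                    # layers, channels, patch-grid height/width (toy)
latents = torch.randn(N, C, H * 2, W * 2)     # (B*N, C, H*P, W*Q) with B=1

tokens = rearrange(latents, "(B N) C (H P) (W Q) -> B (N H W) (C P Q)",
                   H=H, W=W, P=2, Q=2, N=N)
assert tokens.shape == (1, N * H * W, C * 4)  # one joint sequence over all layers

restored = rearrange(tokens, "B (N H W) (C P Q) -> (B N) C (H P) (W Q)",
                     H=H, W=W, P=2, Q=2, B=1)
assert restored.shape == latents.shape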

View File

@@ -11,6 +11,7 @@ from typing import Optional
from typing_extensions import Literal
from transformers import Wav2Vec2Processor
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
@@ -30,7 +31,7 @@ from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
class WanVideoPipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
@@ -98,7 +99,7 @@ class WanVideoPipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
audio_processor_config: ModelConfig = None,
@@ -122,11 +123,15 @@ class WanVideoPipeline(BasePipeline):
model_config.model_id = redirect_dict[model_config.origin_file_pattern][0]
model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1]
if use_usp:
from ..utils.xfuser import initialize_usp
initialize_usp(device)
import torch.distributed as dist
from ..core.device.npu_compatible_device import get_device_name
if dist.is_available() and dist.is_initialized():
device = get_device_name()
# Initialize pipeline
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
@@ -241,6 +246,7 @@ class WanVideoPipeline(BasePipeline):
tea_cache_model_id: Optional[str] = "",
# progress_bar
progress_bar_cmd=tqdm,
output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
@@ -320,9 +326,11 @@ class WanVideoPipeline(BasePipeline):
# Decode
self.load_models_to_device(['vae'])
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if output_type == "quantized":
video = self.vae_output_to_video(video)
elif output_type == "floatpoint":
pass
self.load_models_to_device([])
return video
@@ -823,9 +831,9 @@ class WanVideoUnit_S2V(PipelineUnit):
pipe.load_models_to_device(["vae"]) pipe.load_models_to_device(["vae"])
motion_frames = 73 motion_frames = 73
kwargs = {} kwargs = {}
if motion_video is not None and len(motion_video) > 0: if motion_video is not None:
assert len(motion_video) == motion_frames, f"motion video must have {motion_frames} frames, but got {len(motion_video)}" assert motion_video.shape[2] == motion_frames, f"motion video must have {motion_frames} frames, but got {motion_video.shape[2]}"
motion_latents = pipe.preprocess_video(motion_video) motion_latents = motion_video
kwargs["drop_motion_frames"] = False kwargs["drop_motion_frames"] = False
else: else:
motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device) motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
@@ -957,7 +965,7 @@ class WanVideoUnit_AnimateInpaint(PipelineUnit):
onload_model_names=("vae",) onload_model_names=("vae",)
) )
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"): def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()):
if mask_pixel_values is None: if mask_pixel_values is None:
msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device) msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
else: else:

View File

@@ -4,21 +4,29 @@ from typing import Union
from tqdm import tqdm
from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple, Iterable, Dict
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.data.operators import ImageCropAndResize
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from ..utils.lora import merge_lora
from transformers import AutoTokenizer
from ..models.z_image_text_encoder import ZImageTextEncoder
from ..models.z_image_dit import ZImageDiT
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M
from ..models.z_image_controlnet import ZImageControlNet
from ..models.siglip2_image_encoder import Siglip2ImageEncoder
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
from ..models.z_image_image2lora import ZImageImage2LoRAModel
class ZImagePipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -28,13 +36,22 @@ class ZImagePipeline(BasePipeline):
self.dit: ZImageDiT = None
self.vae_encoder: FluxVAEEncoder = None
self.vae_decoder: FluxVAEDecoder = None
self.image_encoder: Siglip2ImageEncoder428M = None
self.controlnet: ZImageControlNet = None
self.siglip2_image_encoder: Siglip2ImageEncoder = None
self.dinov3_image_encoder: DINOv3ImageEncoder = None
self.image2lora_style: ZImageImage2LoRAModel = None
self.tokenizer: AutoTokenizer = None
self.in_iteration_models = ("dit", "controlnet")
self.units = [
ZImageUnit_ShapeChecker(),
ZImageUnit_PromptEmbedder(),
ZImageUnit_NoiseInitializer(),
ZImageUnit_InputImageEmbedder(),
ZImageUnit_EditImageAutoResize(),
ZImageUnit_EditImageEmbedderVAE(),
ZImageUnit_EditImageEmbedderSiglip(),
ZImageUnit_PAIControlNet(),
]
self.model_fn = model_fn_z_image
@@ -42,7 +59,7 @@ class ZImagePipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
vram_limit: float = None,
@@ -56,6 +73,11 @@ class ZImagePipeline(BasePipeline):
pipe.dit = model_pool.fetch_model("z_image_dit")
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m")
pipe.controlnet = model_pool.fetch_model("z_image_controlnet")
pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder")
pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder")
pipe.image2lora_style = model_pool.fetch_model("z_image_image2lora_style")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
@@ -75,6 +97,9 @@ class ZImagePipeline(BasePipeline):
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Edit
edit_image: Image.Image = None,
edit_image_auto_resize: bool = True,
# Shape
height: int = 1024,
width: int = 1024,
@@ -83,11 +108,17 @@ class ZImagePipeline(BasePipeline):
rand_device: str = "cpu", rand_device: str = "cpu",
# Steps # Steps
num_inference_steps: int = 8, num_inference_steps: int = 8,
sigma_shift: float = None,
# ControlNet
controlnet_inputs: List[ControlNetInput] = None,
# Image to LoRA
image2lora_images: List[Image.Image] = None,
positive_only_lora: Dict[str, torch.Tensor] = None,
# Progress bar
progress_bar_cmd = tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
# Parameters
inputs_posi = {
@@ -102,6 +133,9 @@ class ZImagePipeline(BasePipeline):
"height": height, "width": width, "height": height, "width": width,
"seed": seed, "rand_device": rand_device, "seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps, "num_inference_steps": num_inference_steps,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
"controlnet_inputs": controlnet_inputs,
"image2lora_images": image2lora_images, "positive_only_lora": positive_only_lora,
} }
for unit in self.units: for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -143,12 +177,13 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
seperate_cfg=True, seperate_cfg=True,
input_params=("edit_image",),
input_params_posi={"prompt": "prompt"}, input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"}, input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds",), output_params=("prompt_embeds",),
onload_model_names=("text_encoder",) onload_model_names=("text_encoder",)
) )
def encode_prompt( def encode_prompt(
self, self,
pipe, pipe,
@@ -194,10 +229,81 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
return embeddings_list return embeddings_list
def encode_prompt_omni(
self,
pipe,
prompt: Union[str, List[str]],
edit_image=None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
if isinstance(prompt, str):
prompt = [prompt]
def process(self, pipe: ZImagePipeline, prompt): if edit_image is None:
num_condition_images = 0
elif isinstance(edit_image, list):
num_condition_images = len(edit_image)
else:
num_condition_images = 1
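# Wrap each prompt in the chat template used by the text encoder; condition (edit) images
# are marked with <|vision_start|>/<|vision_end|> placeholder segments.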
for i, prompt_item in enumerate(prompt):
if num_condition_images == 0:
prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
elif num_condition_images > 0:
prompt_list = ["<|im_start|>user\n<|vision_start|>"]
prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1)
prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"]
prompt_list += ["<|vision_end|><|im_end|>"]
prompt[i] = prompt_list
flattened_prompt = []
prompt_list_lengths = []
for i in range(len(prompt)):
prompt_list_lengths.append(len(prompt[i]))
flattened_prompt.extend(prompt[i])
text_inputs = pipe.tokenizer(
flattened_prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_masks = text_inputs.attention_mask.to(device).bool()
prompt_embeds = pipe.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-2]
embeddings_list = []
start_idx = 0
for i in range(len(prompt_list_lengths)):
batch_embeddings = []
end_idx = start_idx + prompt_list_lengths[i]
for j in range(start_idx, end_idx):
batch_embeddings.append(prompt_embeds[j][prompt_masks[j]])
embeddings_list.append(batch_embeddings)
start_idx = end_idx
return embeddings_list
def process(self, pipe: ZImagePipeline, prompt, edit_image):
pipe.load_models_to_device(self.onload_model_names) pipe.load_models_to_device(self.onload_model_names)
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device) if hasattr(pipe, "dit") and pipe.dit.siglip_embedder is not None:
# Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods.
# We determine which encoding method to use based on the model architecture.
# If you are using two-stage split training,
# please use `--offload_models` instead of skipping the DiT model loading.
prompt_embeds = self.encode_prompt_omni(pipe, prompt, edit_image, pipe.device)
else:
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
return {"prompt_embeds": prompt_embeds} return {"prompt_embeds": prompt_embeds}
@@ -234,24 +340,330 @@ class ZImageUnit_InputImageEmbedder(PipelineUnit):
return {"latents": latents, "input_latents": input_latents} return {"latents": latents, "input_latents": input_latents}
class ZImageUnit_EditImageAutoResize(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "edit_image_auto_resize"),
output_params=("edit_image",),
)
def process(self, pipe: ZImagePipeline, edit_image, edit_image_auto_resize):
if edit_image is None:
return {}
if edit_image_auto_resize is None or not edit_image_auto_resize:
return {}
operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)
if not isinstance(edit_image, list):
edit_image = [edit_image]
edit_image = [operator(i) for i in edit_image]
return {"edit_image": edit_image}
class ZImageUnit_EditImageEmbedderSiglip(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image",),
output_params=("image_embeds",),
onload_model_names=("image_encoder",)
)
def process(self, pipe: ZImagePipeline, edit_image):
if edit_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
if not isinstance(edit_image, list):
edit_image = [edit_image]
image_emb = []
for image_ in edit_image:
image_emb.append(pipe.image_encoder(image_, device=pipe.device))
return {"image_embeds": image_emb}
class ZImageUnit_EditImageEmbedderVAE(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image",),
output_params=("image_latents",),
onload_model_names=("vae_encoder",)
)
def process(self, pipe: ZImagePipeline, edit_image):
if edit_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
if not isinstance(edit_image, list):
edit_image = [edit_image]
image_latents = []
for image_ in edit_image:
image_ = pipe.preprocess_image(image_)
image_latents.append(pipe.vae_encoder(image_))
return {"image_latents": image_latents}
class ZImageUnit_PAIControlNet(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("controlnet_inputs", "height", "width"),
output_params=("control_context", "control_scale"),
onload_model_names=("vae_encoder",)
)
def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width):
if controlnet_inputs is None:
return {}
if len(controlnet_inputs) != 1:
print("Z-Image ControlNet doesn't support multi-ControlNet. Only one image will be used.")
controlnet_input = controlnet_inputs[0]
pipe.load_models_to_device(self.onload_model_names)
control_image = controlnet_input.image
if control_image is not None:
control_image = pipe.preprocess_image(control_image)
control_latents = pipe.vae_encoder(control_image)
else:
control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1
inpaint_mask = controlnet_input.inpaint_mask
if inpaint_mask is not None:
inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1)
inpaint_image = controlnet_input.inpaint_image
inpaint_image = pipe.preprocess_image(inpaint_image)
inpaint_image = inpaint_image * (inpaint_mask < 0.5)
inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1]
else:
inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device)
inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device)
inpaint_latent = pipe.vae_encoder(inpaint_image)
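# control_context channels: control-image latents, inverted and downsampled inpaint mask, inpaint-image latents.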
control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1)
control_context = rearrange(control_context, "B C H W -> B C 1 H W")
return {"control_context": control_context, "control_scale": controlnet_input.scale}
def model_fn_z_image( def model_fn_z_image(
dit: ZImageDiT, dit: ZImageDiT,
controlnet: ZImageControlNet = None,
latents=None, latents=None,
timestep=None, timestep=None,
prompt_embeds=None, prompt_embeds=None,
image_embeds=None,
image_latents=None,
use_gradient_checkpointing=False, use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False, use_gradient_checkpointing_offload=False,
**kwargs, **kwargs,
): ):
# Due to the complex and verbose codebase of Z-Image,
# we are temporarily using this inelegant structure.
# We will refactor this part in the future (if time permits).
if dit.siglip_embedder is None:
return model_fn_z_image_turbo(
dit,
controlnet=controlnet,
latents=latents,
timestep=timestep,
prompt_embeds=prompt_embeds,
image_embeds=image_embeds,
image_latents=image_latents,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
**kwargs,
)
latents = [rearrange(latents, "B C H W -> C B H W")] latents = [rearrange(latents, "B C H W -> C B H W")]
if dit.siglip_embedder is not None:
if image_latents is not None:
image_latents = [rearrange(image_latent, "B C H W -> C B H W") for image_latent in image_latents]
latents = [image_latents + latents]
image_noise_mask = [[0] * len(image_latents) + [1]]
else:
latents = [latents]
image_noise_mask = [[1]]
image_embeds = [image_embeds]
else:
image_noise_mask = None
timestep = (1000 - timestep) / 1000 timestep = (1000 - timestep) / 1000
model_output = dit( model_output = dit(
latents, latents,
timestep, timestep,
prompt_embeds, prompt_embeds,
siglip_feats=image_embeds,
image_noise_mask=image_noise_mask,
use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)[0][0] )[0]
model_output = -model_output model_output = -model_output
model_output = rearrange(model_output, "C B H W -> B C H W") model_output = rearrange(model_output, "C B H W -> B C H W")
return model_output return model_output
class ZImageUnit_Image2LoRAEncode(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("image2lora_images",),
output_params=("image2lora_x",),
onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder",),
)
from ..core.data.operators import ImageCropAndResize
self.processor_highres = ImageCropAndResize(height=1024, width=1024)
def encode_images_using_siglip2(self, pipe: ZImagePipeline, images: list[Image.Image]):
pipe.load_models_to_device(["siglip2_image_encoder"])
embs = []
for image in images:
image = self.processor_highres(image)
embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype))
embs = torch.stack(embs)
return embs
def encode_images_using_dinov3(self, pipe: ZImagePipeline, images: list[Image.Image]):
pipe.load_models_to_device(["dinov3_image_encoder"])
embs = []
for image in images:
image = self.processor_highres(image)
embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype))
embs = torch.stack(embs)
return embs
def encode_images(self, pipe: ZImagePipeline, images: list[Image.Image]):
if images is None:
return {}
if not isinstance(images, list):
images = [images]
embs_siglip2 = self.encode_images_using_siglip2(pipe, images)
embs_dinov3 = self.encode_images_using_dinov3(pipe, images)
x = torch.concat([embs_siglip2, embs_dinov3], dim=-1)
return x
def process(self, pipe: ZImagePipeline, image2lora_images):
if image2lora_images is None:
return {}
x = self.encode_images(pipe, image2lora_images)
return {"image2lora_x": x}
class ZImageUnit_Image2LoRADecode(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("image2lora_x",),
output_params=("lora",),
onload_model_names=("image2lora_style",),
)
def process(self, pipe: ZImagePipeline, image2lora_x):
if image2lora_x is None:
return {}
loras = []
if pipe.image2lora_style is not None:
pipe.load_models_to_device(["image2lora_style"])
for x in image2lora_x:
loras.append(pipe.image2lora_style(x=x, residual=None))
lora = merge_lora(loras, alpha=1 / len(image2lora_x))
return {"lora": lora}
def model_fn_z_image_turbo(
dit: ZImageDiT,
controlnet: ZImageControlNet = None,
latents=None,
timestep=None,
prompt_embeds=None,
image_embeds=None,
image_latents=None,
control_context=None,
control_scale=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
while isinstance(prompt_embeds, list):
prompt_embeds = prompt_embeds[0]
while isinstance(latents, list):
latents = latents[0]
while isinstance(image_embeds, list):
image_embeds = image_embeds[0]
# Timestep
timestep = 1000 - timestep
t_noisy = dit.t_embedder(timestep)
t_clean = dit.t_embedder(torch.ones_like(timestep) * 1000)
# Patchify
latents = rearrange(latents, "B C H W -> C B H W")
x, cap_feats, patch_metadata = dit.patchify_and_embed([latents], [prompt_embeds])
x = x[0]
cap_feats = cap_feats[0]
# Noise refine
x = dit.all_x_embedder["2-1"](x)
x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device)
x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0))
x = rearrange(x, "L C -> 1 L C")
x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C")
if control_context is not None:
kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)
refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(
dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1,
use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
for layer_id, layer in enumerate(dit.noise_refiner):
x = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=x,
attn_mask=None,
freqs_cis=x_freqs_cis,
adaln_input=t_noisy,
)
if control_context is not None:
x = x + refiner_hints[layer_id] * control_scale
# Prompt refine
cap_feats = dit.cap_embedder(cap_feats)
cap_feats[torch.cat(patch_metadata.get("cap_pad_mask"))] = dit.cap_pad_token.to(dtype=x.dtype, device=x.device)
cap_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("cap_pos_ids"), dim=0))
cap_feats = rearrange(cap_feats, "L C -> 1 L C")
cap_freqs_cis = rearrange(cap_freqs_cis, "L C -> 1 L C")
for layer in dit.context_refiner:
cap_feats = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=cap_feats,
attn_mask=None,
freqs_cis=cap_freqs_cis,
)
# Unified
unified = torch.cat([x, cap_feats], dim=1)
unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1)
if control_context is not None:
kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)
hints = controlnet.forward_layers(
unified, cap_feats, control_context, control_context_item_seqlens, kwargs,
use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
for layer_id, layer in enumerate(dit.layers):
unified = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=unified,
attn_mask=None,
freqs_cis=unified_freqs_cis,
adaln_input=t_noisy,
)
if control_context is not None:
if layer_id in controlnet.control_layers_mapping:
unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale
# Output
unified = dit.all_final_layer["2-1"](unified, t_noisy)
x = dit.unpatchify([unified[0]], patch_metadata.get("x_size"))[0]
x = rearrange(x, "C B H W -> B C H W")
x = -x
return x

View File

@@ -1,12 +1,13 @@
from typing_extensions import Literal, TypeAlias from typing_extensions import Literal, TypeAlias
from diffsynth.core.device.npu_compatible_device import get_device_type
Processor_id: TypeAlias = Literal[ Processor_id: TypeAlias = Literal[
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint" "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
] ]
class Annotator: class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False): def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device=get_device_type(), skip_processor=False):
if not skip_processor: if not skip_processor:
if processor_id == "canny": if processor_id == "canny":
from controlnet_aux.processor import CannyDetector from controlnet_aux.processor import CannyDetector

View File

@@ -9,5 +9,6 @@ class ControlNetInput:
start: float = 1.0 start: float = 1.0
end: float = 0.0 end: float = 0.0
image: Image.Image = None image: Image.Image = None
inpaint_image: Image.Image = None
inpaint_mask: Image.Image = None inpaint_mask: Image.Image = None
processor_id: str = None processor_id: str = None

View File

@@ -0,0 +1,149 @@
from fractions import Fraction
import torch
import av
from tqdm import tqdm
from PIL import Image
import numpy as np
from io import BytesIO
from collections.abc import Generator, Iterator
def _resample_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
cc = audio_stream.codec_context
# Use the encoder's format/layout/rate as the *target*
target_format = cc.format or "fltp" # AAC → usually fltp
target_layout = cc.layout or "stereo"
target_rate = cc.sample_rate or frame_in.sample_rate
audio_resampler = av.audio.resampler.AudioResampler(
format=target_format,
layout=target_layout,
rate=target_rate,
)
audio_next_pts = 0
for rframe in audio_resampler.resample(frame_in):
if rframe.pts is None:
rframe.pts = audio_next_pts
audio_next_pts += rframe.samples
rframe.sample_rate = frame_in.sample_rate
container.mux(audio_stream.encode(rframe))
# flush audio encoder
for packet in audio_stream.encode():
container.mux(packet)
def _write_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
if samples.ndim == 1:
samples = samples[:, None]
if samples.shape[1] != 2 and samples.shape[0] == 2:
samples = samples.T
if samples.shape[1] != 2:
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
if samples.dtype != torch.int16:
samples = torch.clip(samples, -1.0, 1.0)
samples = (samples * 32767.0).to(torch.int16)
frame_in = av.AudioFrame.from_ndarray(
samples.contiguous().reshape(1, -1).cpu().numpy(),
format="s16",
layout="stereo",
)
frame_in.sample_rate = audio_sample_rate
_resample_audio(container, audio_stream, frame_in)
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
"""
Prepare the audio stream for writing.
"""
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
audio_stream.codec_context.sample_rate = audio_sample_rate
audio_stream.codec_context.layout = "stereo"
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
return audio_stream
def write_video_audio_ltx2(
video: list[Image.Image],
audio: torch.Tensor | None,
output_path: str,
fps: int = 24,
audio_sample_rate: int | None = 24000,
) -> None:
width, height = video[0].size
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
if audio is not None:
if audio_sample_rate is None:
raise ValueError("audio_sample_rate is required when audio is provided")
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
for frame in tqdm(video, total=len(video)):
frame = av.VideoFrame.from_image(frame)
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():
container.mux(packet)
if audio is not None:
_write_audio(container, audio_stream, audio, audio_sample_rate)
container.close()
def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None:
container = av.open(output_file, "w", format="mp4")
try:
stream = container.add_stream("libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"})
# Round to nearest multiple of 2 for compatibility with video codecs
height = image_array.shape[0] // 2 * 2
width = image_array.shape[1] // 2 * 2
image_array = image_array[:height, :width]
stream.height = height
stream.width = width
av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(format="yuv420p")
container.mux(stream.encode(av_frame))
container.mux(stream.encode())
finally:
container.close()
def decode_single_frame(video_file: str) -> np.array:
container = av.open(video_file)
try:
stream = next(s for s in container.streams if s.type == "video")
frame = next(container.decode(stream))
finally:
container.close()
return frame.to_ndarray(format="rgb24")
def ltx2_preprocess(image: np.array, crf: float = 33) -> np.array:
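# Round-trip the frame through a single-frame H.264 encode/decode at the given CRF,
# applying codec compression to the image (skipped when crf == 0).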
if crf == 0:
return image
with BytesIO() as output_file:
encode_single_frame(output_file, image, crf)
video_bytes = output_file.getvalue()
with BytesIO(video_bytes) as video_file:
image_array = decode_single_frame(video_file)
return image_array

View File

@@ -149,6 +149,8 @@ class FluxLoRALoader(GeneralLoRALoader):
dtype=state_dict_[name].dtype) dtype=state_dict_[name].dtype)
else: else:
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn.")) state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
mlp = mlp.to(device=state_dict_[name].device)
if 'lora_A' in name: if 'lora_A' in name:
param = torch.concat([ param = torch.concat([
state_dict_.pop(name), state_dict_.pop(name),

View File

@@ -89,4 +89,109 @@ def FluxDiTStateDictConverter(state_dict):
state_dict_[rename] = state_dict[original_name] state_dict_[rename] = state_dict[original_name]
else: else:
pass pass
return state_dict_
def FluxDiTStateDictConverterFromDiffusers(state_dict):
global_rename_dict = {
"context_embedder": "context_embedder",
"x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "final_norm_out.linear",
"proj_out": "final_proj_out",
}
rename_dict = {
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
}
rename_dict_single = {
"attn.to_q": "a_to_q",
"attn.to_k": "a_to_k",
"attn.to_v": "a_to_v",
"attn.norm_q": "norm_q_a",
"attn.norm_k": "norm_k_a",
"norm.linear": "norm.linear",
"proj_mlp": "proj_in_besides_attn",
"proj_out": "proj_out",
}
state_dict_ = {}
for name in state_dict:
param = state_dict[name]
if name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in global_rename_dict:
if global_rename_dict[prefix] == "final_norm_out.linear":
param = torch.concat([param[3072:], param[:3072]], dim=0)
state_dict_[global_rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
elif prefix.startswith("single_transformer_blocks."):
names = prefix.split(".")
names[0] = "single_blocks"
middle = ".".join(names[2:])
if middle in rename_dict_single:
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
state_dict_[name_] = param
else:
pass
else:
pass
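# Fuse the separate q/k/v (and, for single blocks, the MLP-in) projections into the
# combined to_qkv_mlp / to_qkv parameters used by this module layout.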
for name in list(state_dict_.keys()):
if "single_blocks." in name and ".a_to_q." in name:
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
if mlp is None:
mlp = torch.zeros(4 * state_dict_[name].shape[0],
*state_dict_[name].shape[1:],
dtype=state_dict_[name].dtype)
else:
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
param = torch.concat([
state_dict_.pop(name),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
mlp,
], dim=0)
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
state_dict_[name_] = param
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
param = torch.concat([
state_dict_[name],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name)
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
return state_dict_ return state_dict_

View File

@@ -0,0 +1,32 @@
def LTX2AudioEncoderStateDictConverter(state_dict):
# Not used
state_dict_ = {}
for name in state_dict:
if name.startswith("audio_vae.encoder."):
new_name = name.replace("audio_vae.encoder.", "")
state_dict_[new_name] = state_dict[name]
elif name.startswith("audio_vae.per_channel_statistics."):
new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.")
state_dict_[new_name] = state_dict[name]
return state_dict_
def LTX2AudioDecoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("audio_vae.decoder."):
new_name = name.replace("audio_vae.decoder.", "")
state_dict_[new_name] = state_dict[name]
elif name.startswith("audio_vae.per_channel_statistics."):
new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.")
state_dict_[new_name] = state_dict[name]
return state_dict_
def LTX2VocoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("vocoder."):
new_name = name.replace("vocoder.", "")
state_dict_[new_name] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,9 @@
def LTXModelStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("model.diffusion_model."):
new_name = name.replace("model.diffusion_model.", "")
if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."):
continue
state_dict_[new_name] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,31 @@
def LTX2TextEncoderStateDictConverter(state_dict):
state_dict_ = {}
for key in state_dict:
if key.startswith("language_model.model."):
new_key = key.replace("language_model.model.", "model.language_model.")
elif key.startswith("vision_tower."):
new_key = key.replace("vision_tower.", "model.vision_tower.")
elif key.startswith("multi_modal_projector."):
new_key = key.replace("multi_modal_projector.", "model.multi_modal_projector.")
elif key.startswith("language_model.lm_head."):
new_key = key.replace("language_model.lm_head.", "lm_head.")
else:
continue
state_dict_[new_key] = state_dict[key]
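# Reuse the token embedding weights for lm_head (tied weights).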
state_dict_["lm_head.weight"] = state_dict_.get("model.language_model.embed_tokens.weight")
return state_dict_
def LTX2TextEncoderPostModulesStateDictConverter(state_dict):
state_dict_ = {}
for key in state_dict:
if key.startswith("text_embedding_projection."):
new_key = key.replace("text_embedding_projection.", "feature_extractor_linear.")
elif key.startswith("model.diffusion_model.video_embeddings_connector."):
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "embeddings_connector.")
elif key.startswith("model.diffusion_model.audio_embeddings_connector."):
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "audio_embeddings_connector.")
else:
continue
state_dict_[new_key] = state_dict[key]
return state_dict_

View File

@@ -0,0 +1,22 @@
def LTX2VideoEncoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("vae.encoder."):
new_name = name.replace("vae.encoder.", "")
state_dict_[new_name] = state_dict[name]
elif name.startswith("vae.per_channel_statistics."):
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
state_dict_[new_name] = state_dict[name]
return state_dict_
def LTX2VideoDecoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("vae.decoder."):
new_name = name.replace("vae.decoder.", "")
state_dict_[new_name] = state_dict[name]
elif name.startswith("vae.per_channel_statistics."):
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
state_dict_[new_name] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,6 @@
def ZImageTextEncoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name != "lm_head.weight":
state_dict_[name] = state_dict[name]
return state_dict_

View File

@@ -1,10 +1,13 @@
import torch import torch
from typing import Optional from typing import Optional
from einops import rearrange from einops import rearrange
from yunchang.kernels import AttnType
from xfuser.core.distributed import (get_sequence_parallel_rank, from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size, get_sequence_parallel_world_size,
get_sp_group) get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ... import IS_NPU_AVAILABLE
from ...core.device import parse_nccl_backend, parse_device_type from ...core.device import parse_nccl_backend, parse_device_type
@@ -30,13 +33,16 @@ def sinusoidal_embedding_1d(dim, position):
def pad_freqs(original_tensor, target_len): def pad_freqs(original_tensor, target_len):
seq_len, s1, s2 = original_tensor.shape seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len pad_size = target_len - seq_len
original_tensor_device = original_tensor.device
if original_tensor.device.type == "npu":
original_tensor = original_tensor.cpu()
padding_tensor = torch.ones( padding_tensor = torch.ones(
pad_size, pad_size,
s1, s1,
s2, s2,
dtype=original_tensor.dtype, dtype=original_tensor.dtype,
device=original_tensor.device) device=original_tensor.device)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0).to(device=original_tensor_device)
return padded_tensor return padded_tensor
def rope_apply(x, freqs, num_heads): def rope_apply(x, freqs, num_heads):
@@ -50,7 +56,7 @@ def rope_apply(x, freqs, num_heads):
sp_rank = get_sequence_parallel_rank() sp_rank = get_sequence_parallel_rank()
freqs = pad_freqs(freqs, s_per_rank * sp_size) freqs = pad_freqs(freqs, s_per_rank * sp_size)
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device.type == "npu" else freqs_rank
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
return x_out.to(x.dtype) return x_out.to(x.dtype)
@@ -133,7 +139,12 @@ def usp_attn_forward(self, x, freqs):
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads) k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads) v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
x = xFuserLongContextAttention()( attn_type = AttnType.FA
ring_impl_type = "basic"
if IS_NPU_AVAILABLE:
attn_type = AttnType.NPU
ring_impl_type = "basic_npu"
x = xFuserLongContextAttention(attn_type=attn_type, ring_impl_type=ring_impl_type)(
None, None,
query=q, query=q,
key=k, key=k,

View File

@@ -2,6 +2,15 @@
FLUX.2 is an image generation model trained and open-sourced by Black Forest Labs. FLUX.2 is an image generation model trained and open-sourced by Black Forest Labs.
## Model Lineage
```mermaid
graph LR;
FLUX.2-Series-->black-forest-labs/FLUX.2-dev;
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-4B;
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-9B;
```
## Installation ## Installation
Before using this project for model inference and training, please install DiffSynth-Studio first. Before using this project for model inference and training, please install DiffSynth-Studio first.
@@ -50,16 +59,20 @@ image.save("image.jpg")
## Model Overview ## Model Overview
| Model ID | Inference | Low VRAM Inference | LoRA Training | Validation After LoRA Training | | Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | | - | - | - | - | - | - | - |
| [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) | [code](/examples/flux2/model_inference/FLUX.2-dev.py) | [code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py) | [code](/examples/flux2/model_training/lora/FLUX.2-dev.sh) | [code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py) | |[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
Special Training Scripts: Special Training Scripts:
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/flux/model_training/special/differential_training/) * Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md)
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/flux/model_training/special/fp8_training/) * FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md)
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/flux/model_training/special/split_training/) * Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md)
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh) * End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md)
## Model Inference ## Model Inference
@@ -135,4 +148,4 @@ We have built a sample image dataset for your testing. You can download this dat
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
``` ```
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/). We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).

View File

@@ -0,0 +1,109 @@
# LTX-2
LTX-2 is a series of audio-video generation models developed by Lightricks.
## Installation
Before using this project for model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information about installation, please refer to [Installation Dependencies](/docs/en/Pipeline_Usage/Setup.md).
## Quick Start
Run the following code to quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model and perform inference. VRAM management is enabled, and the framework automatically controls model parameter loading based on the remaining VRAM; it can run with as little as 8 GB of VRAM.
```python
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
vram_config = {
"offload_dtype": torch.float8_e5m2,
"offload_device": "cpu",
"onload_dtype": torch.float8_e5m2,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e5m2,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = LTX2AudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
height, width, num_frames = 512, 768, 121
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=43,
height=height,
width=width,
num_frames=num_frames,
tiled=True,
)
write_video_audio_ltx2(
video=video,
audio=audio,
output_path='ltx2_onestage.mp4',
fps=24,
audio_sample_rate=24000,
)
```
## Model Overview
|Model ID|Additional Parameters|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
## Model Inference
Models are loaded through `LTX2AudioVideoPipeline.from_pretrained`; see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models) for details.
Input parameters for `LTX2AudioVideoPipeline` inference include:
* `prompt`: Prompt describing the content appearing in the video.
* `negative_prompt`: Negative prompt describing content that should not appear in the video, default value is `""`.
* `cfg_scale`: Classifier-free guidance parameter, default value is 3.0.
* `input_images`: List of input images for image-to-video generation.
* `input_images_indexes`: Frame index list of input images in the video.
* `input_images_strength`: Strength of input images, default value is 1.0.
* `denoising_strength`: Denoising strength, range is 0 to 1, default value is 1.0.
* `seed`: Random seed. Default is `None`, which means completely random.
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different results will be generated on different GPUs.
* `height`: Video height, must be a multiple of 32 (single-stage) or 64 (two-stage).
* `width`: Video width, must be a multiple of 32 (single-stage) or 64 (two-stage).
* `num_frames`: Number of video frames, default value is 121; must be of the form 8k + 1, i.e. `num_frames - 1` must be divisible by 8.
* `num_inference_steps`: Number of inference steps, default value is 40.
* `tiled`: Whether to enable tiled VAE inference, default is `True`. When set to `True`, it can significantly reduce VRAM usage during the VAE encoding/decoding stages, at the cost of slight numerical differences and a small increase in inference time.
* `tile_size_in_pixels`: Pixel tiling size during VAE encoding/decoding stages, default is 512.
* `tile_overlap_in_pixels`: Pixel tiling overlap size during VAE encoding/decoding stages, default is 128.
* `tile_size_in_frames`: Frame tiling size during VAE encoding/decoding stages, default is 128.
* `tile_overlap_in_frames`: Frame tiling overlap size during VAE encoding/decoding stages, default is 24.
* `use_two_stage_pipeline`: Whether to use two-stage pipeline, default is `False`.
* `use_distilled_pipeline`: Whether to use distilled pipeline, default is `False`.
* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be set to `lambda x:x` to hide the progress bar.
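As a minimal sketch of the image-to-audio-video parameters above (reusing the `pipe`, `prompt`, and `negative_prompt` from the Quick Start; the file name is a placeholder and the exact conditioning behavior may differ between pipeline variants):

```python
from PIL import Image

first_frame = Image.open("first_frame.jpg")  # placeholder path for the conditioning image
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    input_images=[first_frame],    # condition images for image-to-video generation
    input_images_indexes=[0],      # place the condition image at frame 0
    input_images_strength=1.0,
    seed=43,
    height=512, width=768, num_frames=121,
)
write_video_audio_ltx2(video=video, audio=audio, output_path="ltx2_i2av.mp4", fps=24, audio_sample_rate=24000)
```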
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code; see the table in the "Model Overview" section above.
## Model Training
Training is not yet supported for the LTX-2 series models. We will add support for it as soon as possible.

View File

@@ -81,8 +81,13 @@ graph LR;
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | | Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - | | - | - | - | - | - | - | - |
| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) | | [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) | | [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) | | [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | | [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | | [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) | | [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |

View File

@@ -50,9 +50,14 @@ image.save("image.jpg")
## Model Overview ## Model Overview
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | |Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
| - | - | - | - | - | - | - | |-|-|-|-|-|-|-|
| [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) | [code](/examples/z_image/model_inference/Z-Image-Turbo.py) | [code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py) | [code](/examples/z_image/model_training/full/Z-Image-Turbo.sh) | [code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py) | [code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh) | [code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py) | |[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
Special Training Scripts:
@@ -75,6 +80,9 @@ Input parameters for `ZImagePipeline` inference include:
* `seed`: Random seed. Default is `None`, meaning completely random.
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results.
* `num_inference_steps`: Number of inference steps, default value is 8.
* `controlnet_inputs`: Inputs for ControlNet models.
* `edit_image`: Input image(s) for image editing models; multiple images are supported.
* `positive_only_lora`: LoRA weights that are applied only on the positive prompt branch.
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code; see the table in the "Model Overview" section above.

View File

@@ -13,7 +13,7 @@ All sample code provided by this project supports NVIDIA GPUs by default, requir
AMD provides PyTorch packages based on ROCm, so most models can run without code changes. A small number of models may not be compatible due to their reliance on CUDA-specific instructions.
## Ascend NPU
### Inference
When using Ascend NPU, you need to replace `"cuda"` with `"npu"` in your code.
For example, here is the inference code for **Wan2.1-T2V-1.3B**, modified for Ascend NPU:
@@ -22,6 +22,7 @@ For example, here is the inference code for **Wan2.1-T2V-1.3B**, modified for As
import torch
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from diffsynth.core.device.npu_compatible_device import get_device_name
vram_config = {
"offload_dtype": "disk",
@@ -46,7 +47,7 @@ pipe = WanVideoPipeline.from_pretrained(
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, - vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
+ vram_limit=torch.npu.mem_get_info("npu")[1] / (1024 ** 3) - 2, + vram_limit=torch.npu.mem_get_info(get_device_name())[1] / (1024 ** 3) - 2,
)
video = pipe(
@@ -56,3 +57,36 @@ video = pipe(
)
save_video(video, "video.mp4", fps=15, quality=5)
```
#### USP (Unified Sequence Parallel)
If you want to use this feature on NPU, please install additional third-party libraries as follows:
```shell
pip install git+https://github.com/feifeibear/long-context-attention.git
pip install git+https://github.com/xdit-project/xDiT.git
```
### Training
NPU startup script samples have been added for each type of model. The scripts are stored under `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`.
The NPU training scripts add NPU-specific environment variables that can improve performance, and enable the relevant parameters for specific models.
#### Environment variables
```shell
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
```
`expandable_segments:<value>`: Enables the expandable-segments feature of the memory pool, i.e. the virtual memory feature.
```shell
export CPU_AFFINITY_CONF=1
```
* `0` or unset: core binding is disabled
* `1`: enables coarse-grained core binding
* `2`: enables fine-grained core binding
#### Parameters for specific models
| Model | Parameter | Note |
|----------------|---------------------------|-------------------|
| Wan 14B series | --initialize_model_on_cpu | The 14B model needs to be initialized on the CPU |
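For reference, a minimal launch sketch is shown below; the provided NPU scripts already set these environment variables, so this simply mirrors what they do, using the Wan2.2 script path mentioned above.
```shell
# Export the NPU-specific environment variables described above,
# then run one of the provided NPU training scripts.
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1
bash examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh
```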

View File

@@ -30,11 +30,16 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6
* **Ascend NPU**
1. Install [CANN](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/softwareinst/instg/instg_quick.html?Mode=PmIns&InstallType=local&OS=openEuler&Software=cannToolKit) through the official documentation.
2. Install from source:
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
# aarch64/ARM
pip install -e .[npu_aarch64] --extra-index-url "https://download.pytorch.org/whl/cpu"
# x86
pip install -e .[npu]
```
When using Ascend NPU, please replace `"cuda"` with `"npu"` in your Python code. For details, see [NPU Support](/docs/en/Pipeline_Usage/GPU_support.md#ascend-npu).

View File

@@ -77,7 +77,7 @@ This section introduces the independent core module `diffsynth.core` in `DiffSyn
This section introduces how to use `DiffSynth-Studio` to train new models, helping researchers explore new model technologies.
* [Training models from scratch](/docs/en/Research_Tutorial/train_from_scratch.md)
* Inference improvement techniques 【coming soon】
* Designing controllable generation models 【coming soon】
* Creating new training paradigms 【coming soon】

View File

@@ -0,0 +1,476 @@
# Training Models from Scratch
DiffSynth-Studio's training engine supports training foundation models from scratch. This article introduces how to train a small text-to-image model with only 0.1B parameters from scratch.
## 1. Building Model Architecture
### 1.1 Diffusion Model
From UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) to DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206), mainstream Diffusion model architectures have gone through several generations of evolution. Typically, a Diffusion model's inputs include:
* Image tensor (`latents`): The encoding of images, generated by the VAE model, containing partial noise
* Text tensor (`prompt_embeds`): The encoding of text, generated by the text encoder
* Timestep (`timestep`): A scalar used to mark which stage of the Diffusion process we are currently at
The model's output is a tensor with the same shape as the image tensor, representing the denoising direction predicted by the model. For details about Diffusion model theory, please refer to [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md). In this article, we build a DiT model with only 0.1B parameters: `AAADiT`.
<details>
<summary>Model Architecture Code</summary>
```python
import torch, accelerate
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange, repeat
from transformers import AutoProcessor, AutoTokenizer
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
from diffsynth.models.general_modules import TimestepEmbeddings
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
from diffsynth.models.flux2_vae import Flux2VAE
class AAAPositionalEmbedding(torch.nn.Module):
def __init__(self, height=16, width=16, dim=1024):
super().__init__()
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
def forward(self, image, text):
height, width = image.shape[-2:]
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
emb = torch.concat([image_emb, text_emb], dim=1)
return emb
class AAABlock(torch.nn.Module):
def __init__(self, dim=1024, num_heads=32):
super().__init__()
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.to_q = torch.nn.Linear(dim, dim)
self.to_k = torch.nn.Linear(dim, dim)
self.to_v = torch.nn.Linear(dim, dim)
self.to_out = torch.nn.Linear(dim, dim)
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.ff = torch.nn.Sequential(
torch.nn.Linear(dim, dim*3),
torch.nn.SiLU(),
torch.nn.Linear(dim*3, dim),
)
self.to_gate = torch.nn.Linear(dim, dim * 2)
self.num_heads = num_heads
def attention(self, emb, pos_emb):
emb = self.norm_attn(emb + pos_emb)
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
emb = attention_forward(
q, k, v,
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
dims={"n": self.num_heads},
)
emb = self.to_out(emb)
return emb
def feed_forward(self, emb, pos_emb):
emb = self.norm_mlp(emb + pos_emb)
emb = self.ff(emb)
return emb
def forward(self, emb, pos_emb, t_emb):
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
return emb
class AAADiT(torch.nn.Module):
def __init__(self, dim=1024):
super().__init__()
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
self.timestep_embedder = TimestepEmbeddings(256, dim)
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
self.proj_out = torch.nn.Linear(dim, 128)
def forward(
self,
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
pos_emb = self.pos_embedder(latents, prompt_embeds)
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
text = self.text_embedder(prompt_embeds)
emb = torch.concat([image, text], dim=1)
for block_id, block in enumerate(self.blocks):
emb = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
emb=emb,
pos_emb=pos_emb,
t_emb=t_emb,
)
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
emb = self.proj_out(emb)
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
return emb
```
</details>
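Before wiring the model into a pipeline, a quick smoke test helps confirm the input and output shapes described above. The following is a sketch that assumes the class definitions in the collapsed block (and their `diffsynth` imports) have been executed:

```python
import torch

# 0.1B-parameter DiT defined above
dit = AAADiT(dim=1024)

latents = torch.randn(1, 128, 16, 16)      # a 256x256 image becomes a 16x16 latent grid with 128 channels
prompt_embeds = torch.randn(1, 128, 1024)  # 128 text tokens with hidden size 1024
timestep = torch.tensor([500.0])           # scalar diffusion timestep

noise_pred = dit(latents, prompt_embeds, timestep)
print(noise_pred.shape)  # torch.Size([1, 128, 16, 16]) -- same shape as `latents`
```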
### 1.2 Encoder-Decoder Models
Besides the Diffusion model used for denoising, we also need two other models:
* Text Encoder: Used to encode text into tensors. We adopt the [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) model.
* VAE Encoder-Decoder: The encoder part is used to encode images into tensors, and the decoder part is used to decode image tensors into images. We adopt the VAE model from [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B).
The architectures of these two models are already integrated in DiffSynth-Studio, located at [/diffsynth/models/z_image_text_encoder.py](/diffsynth/models/z_image_text_encoder.py) and [/diffsynth/models/flux2_vae.py](/diffsynth/models/flux2_vae.py), so we don't need to modify any code.
## 2. Building Pipeline
We introduced how to build a model Pipeline in the document [Integrating Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md). For the model in this article, we also need to build a Pipeline to connect the text encoder, Diffusion model, and VAE encoder-decoder.
<details>
<summary>Pipeline Code</summary>
```python
class AAAImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("FLUX.2")
self.text_encoder: ZImageTextEncoder = None
self.dit: AAADiT = None
self.vae: Flux2VAE = None
self.tokenizer: AutoProcessor = None
self.in_iteration_models = ("dit",)
self.units = [
AAAUnit_PromptEmbedder(),
AAAUnit_NoiseInitializer(),
AAAUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_aaa
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = None,
vram_limit: float = None,
):
# Initialize pipeline
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
pipe.dit = model_pool.fetch_model("aaa_dit")
pipe.vae = model_pool.fetch_model("flux2_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 30,
# Progress bar
progress_bar_cmd = tqdm,
):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
# Parameters
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"cfg_scale": cfg_scale,
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"])
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class AAAUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds",),
onload_model_names=("text_encoder",)
)
self.hidden_states_layers = (-1,)
def process(self, pipe: AAAImagePipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
text = pipe.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
return {"prompt_embeds": prompt_embeds}
class AAAUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class AAAUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: AAAImagePipeline, input_image, noise):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(image)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": input_latents}
def model_fn_aaa(
dit: AAADiT,
latents=None,
prompt_embeds=None,
timestep=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
model_output = dit(
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
return model_output
```
</details>
## 3. Preparing Dataset
To quickly verify training effectiveness, we use the dataset [Pokemon-First Generation](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1), which is reproduced from the open-source project [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh), containing 151 first-generation Pokemon from Bulbasaur to Mew. If you want to use other datasets, please refer to the document [Preparing Datasets](/docs/en/Pipeline_Usage/Model_Training.md#preparing-datasets) and [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md).
```shell
modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data
```
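If you want a quick look at the data before training, the following sketch inspects the metadata file; it assumes `pandas` is installed and that the CSV provides at least the `image` and `prompt` columns consumed by the training code below.
```python
# Sketch: inspect the downloaded metadata (column names are an assumption,
# inferred from the keys the training module reads).
import pandas as pd

metadata = pd.read_csv("data/metadata_merged.csv")
print(metadata.columns.tolist())   # e.g. ["image", "prompt", ...]
print(len(metadata))               # number of training samples
print(metadata.iloc[0]["prompt"])  # caption of the first Pokemon image
```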
## 4. Start Training
The training process can be implemented quickly on top of the Pipeline. We have placed the complete code at [/docs/en/Research_Tutorial/train_from_scratch.py](/docs/en/Research_Tutorial/train_from_scratch.py); single-GPU training can be started directly with `python docs/en/Research_Tutorial/train_from_scratch.py`.
To enable multi-GPU parallel training, run `accelerate config` to set the relevant parameters, then start training with `accelerate launch docs/en/Research_Tutorial/train_from_scratch.py`; the commands are summarized below.
This training script has no stopping condition; please stop it manually when needed. The model converges after approximately 60,000 training steps, which takes 10-20 hours on a single GPU.
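The launch commands described above, collected for convenience:
```shell
# Single-GPU training
python docs/en/Research_Tutorial/train_from_scratch.py

# Multi-GPU training: configure accelerate once, then launch
accelerate config
accelerate launch docs/en/Research_Tutorial/train_from_scratch.py
```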
<details>
<summary>Training Code</summary>
```python
class AAATrainingModule(DiffusionTrainingModule):
def __init__(self, device):
super().__init__()
self.pipe = AAAImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device=device,
model_configs=[
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
self.pipe.freeze_except(["dit"])
self.pipe.scheduler.set_timesteps(1000, training=True)
def forward(self, data):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
inputs_shared = {
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
"cfg_scale": 1,
"use_gradient_checkpointing": False,
"use_gradient_checkpointing_offload": False,
}
for unit in self.pipe.units:
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
return loss
if __name__ == "__main__":
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
dataset = UnifiedDataset(
base_path="data/images",
metadata_path="data/metadata_merged.csv",
max_data_items=10000000,
data_file_keys=("image",),
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
)
model = AAATrainingModule(device=accelerator.device)
model_logger = ModelLogger(
"models/AAA/v1",
remove_prefix_in_ckpt="pipe.dit.",
)
launch_training_task(
accelerator, dataset, model, model_logger,
learning_rate=2e-4,
num_workers=4,
save_steps=50000,
num_epochs=999999,
)
```
</details>
## 5. Verifying Training Results
If you don't want to wait for the model training to complete, you can directly download [our pre-trained model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel).
```shell
modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel
```
Load the model:
```python
from diffsynth import load_model
pipe = AAAImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda")
```
Run inference to generate the first-generation Pokemon "starter trio". At this point, the images generated by the model closely match the training data.
```python
for seed, prompt in enumerate([
"green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws",
"orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws",
"blue, beige, brown, turtle, water type, shell, big eyes, short limbs, curled tail",
]):
image = pipe(
prompt=prompt,
negative_prompt=" ",
num_inference_steps=30,
cfg_scale=10,
seed=seed,
height=256, width=256,
)
image.save(f"image_{seed}.jpg")
```
|![Image](https://github.com/user-attachments/assets/3c620fbf-5d28-4a1a-b887-519d85ac7d1c)|![Image](https://github.com/user-attachments/assets/909efd4c-9e61-4b33-9321-39da0e499b00)|![Image](https://github.com/user-attachments/assets/f3474bcd-b474-4a90-a1ea-579f67e161e3)|
|-|-|-|
Run inference to generate Pokemon with "sharp claws". Here, different random seeds produce different images.
```python
for seed, prompt in enumerate([
"sharp claws",
"sharp claws",
"sharp claws",
]):
image = pipe(
prompt=prompt,
negative_prompt=" ",
num_inference_steps=30,
cfg_scale=10,
seed=seed+4,
height=256, width=256,
)
image.save(f"image_sharp_claws_{seed}.jpg")
```
|![Image](https://github.com/user-attachments/assets/94862edd-96ae-4276-a38f-795249f11a13)|![Image](https://github.com/user-attachments/assets/b2291f23-20ba-42de-8bfd-76cb4afc6eea)|![Image](https://github.com/user-attachments/assets/f2aab9a4-85ec-498e-8039-648b1289796e)|
|-|-|-|
We now have a small 0.1B text-to-image model. It can already generate the 151 Pokemon, but cannot generate other image content. If you scale up the data, model parameters, and number of GPUs from here, you can train a more powerful text-to-image model!

View File

@@ -0,0 +1,341 @@
import torch, accelerate
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange, repeat
from transformers import AutoProcessor, AutoTokenizer
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
from diffsynth.models.general_modules import TimestepEmbeddings
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
from diffsynth.models.flux2_vae import Flux2VAE
class AAAPositionalEmbedding(torch.nn.Module):
def __init__(self, height=16, width=16, dim=1024):
super().__init__()
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
def forward(self, image, text):
height, width = image.shape[-2:]
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
emb = torch.concat([image_emb, text_emb], dim=1)
return emb
class AAABlock(torch.nn.Module):
def __init__(self, dim=1024, num_heads=32):
super().__init__()
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.to_q = torch.nn.Linear(dim, dim)
self.to_k = torch.nn.Linear(dim, dim)
self.to_v = torch.nn.Linear(dim, dim)
self.to_out = torch.nn.Linear(dim, dim)
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.ff = torch.nn.Sequential(
torch.nn.Linear(dim, dim*3),
torch.nn.SiLU(),
torch.nn.Linear(dim*3, dim),
)
self.to_gate = torch.nn.Linear(dim, dim * 2)
self.num_heads = num_heads
def attention(self, emb, pos_emb):
emb = self.norm_attn(emb + pos_emb)
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
emb = attention_forward(
q, k, v,
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
dims={"n": self.num_heads},
)
emb = self.to_out(emb)
return emb
def feed_forward(self, emb, pos_emb):
emb = self.norm_mlp(emb + pos_emb)
emb = self.ff(emb)
return emb
def forward(self, emb, pos_emb, t_emb):
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
return emb
class AAADiT(torch.nn.Module):
def __init__(self, dim=1024):
super().__init__()
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
self.timestep_embedder = TimestepEmbeddings(256, dim)
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
self.proj_out = torch.nn.Linear(dim, 128)
def forward(
self,
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
pos_emb = self.pos_embedder(latents, prompt_embeds)
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
text = self.text_embedder(prompt_embeds)
emb = torch.concat([image, text], dim=1)
for block_id, block in enumerate(self.blocks):
emb = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
emb=emb,
pos_emb=pos_emb,
t_emb=t_emb,
)
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
emb = self.proj_out(emb)
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
return emb
class AAAImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("FLUX.2")
self.text_encoder: ZImageTextEncoder = None
self.dit: AAADiT = None
self.vae: Flux2VAE = None
self.tokenizer: AutoProcessor = None
self.in_iteration_models = ("dit",)
self.units = [
AAAUnit_PromptEmbedder(),
AAAUnit_NoiseInitializer(),
AAAUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_aaa
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = None,
vram_limit: float = None,
):
# Initialize pipeline
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
pipe.dit = model_pool.fetch_model("aaa_dit")
pipe.vae = model_pool.fetch_model("flux2_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 30,
# Progress bar
progress_bar_cmd = tqdm,
):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
# Parameters
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"cfg_scale": cfg_scale,
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"])
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class AAAUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds",),
onload_model_names=("text_encoder",)
)
self.hidden_states_layers = (-1,)
def process(self, pipe: AAAImagePipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
text = pipe.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
return {"prompt_embeds": prompt_embeds}
class AAAUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class AAAUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: AAAImagePipeline, input_image, noise):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(image)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": input_latents}
def model_fn_aaa(
dit: AAADiT,
latents=None,
prompt_embeds=None,
timestep=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
model_output = dit(
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
return model_output
class AAATrainingModule(DiffusionTrainingModule):
def __init__(self, device):
super().__init__()
self.pipe = AAAImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device=device,
model_configs=[
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
self.pipe.freeze_except(["dit"])
self.pipe.scheduler.set_timesteps(1000, training=True)
def forward(self, data):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
inputs_shared = {
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
"cfg_scale": 1,
"use_gradient_checkpointing": False,
"use_gradient_checkpointing_offload": False,
}
for unit in self.pipe.units:
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
return loss
if __name__ == "__main__":
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
dataset = UnifiedDataset(
base_path="data/images",
metadata_path="data/metadata_merged.csv",
max_data_items=10000000,
data_file_keys=("image",),
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
)
model = AAATrainingModule(device=accelerator.device)
model_logger = ModelLogger(
"models/AAA/v1",
remove_prefix_in_ckpt="pipe.dit.",
)
launch_training_task(
accelerator, dataset, model, model_logger,
learning_rate=2e-4,
num_workers=4,
save_steps=50000,
num_epochs=999999,
)

View File

@@ -6,7 +6,7 @@ This document introduces the basic principles of Diffusion models to help you un
Diffusion models generate clear images or video content through iterative denoising. We start by explaining the generation process of a data sample $x_0$. Intuitively, in a complete round of denoising, we start from random Gaussian noise $x_T$ and iteratively obtain $x_{T-1}$, $x_{T-2}$, $x_{T-3}$, $\cdots$, gradually reducing the noise content at each step until we finally obtain the noise-free data sample $x_0$.
![Image](https://github.com/user-attachments/assets/6471ae4c-a635-4924-8b36-b0bd4d42043d)
This process is intuitive, but to understand the details, we need to answer several questions:
@@ -28,7 +28,7 @@ As for the intermediate values $\sigma_{T-1}$, $\sigma_{T-2}$, $\cdots$, $\sigma
At an intermediate step, we can directly synthesize noisy data samples $x_t=(1-\sigma_t)x_0+\sigma_t x_T$.
![Image](https://github.com/user-attachments/assets/e25a2f71-123c-4e18-8b34-3a066af15667)
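As a tiny numeric illustration of the formula above (a sketch; `x_0` stands in for a clean data sample and `x_T` for pure Gaussian noise):
```python
# Sketch of x_t = (1 - sigma_t) * x_0 + sigma_t * x_T at a single noise level.
import torch

x_0 = torch.randn(1, 4, 64, 64)   # clean data sample (stand-in)
x_T = torch.randn(1, 4, 64, 64)   # pure Gaussian noise
sigma_t = 0.3                     # noise level at timestep t
x_t = (1 - sigma_t) * x_0 + sigma_t * x_T
```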
## How is the iterative denoising computation performed?
@@ -40,8 +40,6 @@ Before understanding the iterative denoising computation, we need to clarify wha
Among these, the guidance condition $c$ is a newly introduced parameter that is input by the user. It can be text describing the image content or a sketch outlining the image structure.
The model's output $\hat \epsilon(x_t,c,t)$ approximately equals $x_T-x_0$, which is the direction of the entire diffusion process (the reverse process of denoising).
Next, we analyze the computation occurring in one iteration. At time step $t$, after the model computes an approximation of $x_T-x_0$, we calculate the next $x_{t-1}$:
@@ -91,8 +89,6 @@ After understanding the iterative denoising process, we next consider how to tra
The training process differs from the generation process. If we retain multi-step iterations during training, the gradient would need to backpropagate through multiple steps, bringing catastrophic time and space complexity. To improve computational efficiency, we randomly select a time step $t$ for training.
The following is pseudocode for the training process:
> Obtain data sample $x_0$ and guidance condition $c$ from the dataset
@@ -113,7 +109,7 @@ The following is pseudocode for the training process:
From theory to practice, more details need to be filled in. Modern Diffusion model architectures have matured, with mainstream architectures following the "three-stage" architecture proposed by Latent Diffusion, including data encoder-decoder, guidance condition encoder, and denoising model.
![Image](https://github.com/user-attachments/assets/43855430-6427-4aca-83a0-f684e01438b1)
### Data Encoder-Decoder

View File

@@ -2,6 +2,15 @@
FLUX.2 is an image generation model trained and open-sourced by Black Forest Labs.
## Model Lineage
```mermaid
graph LR;
FLUX.2-Series-->black-forest-labs/FLUX.2-dev;
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-4B;
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-9B;
```
## Installation
Before running model inference or training with this project, please install DiffSynth-Studio first.
@@ -50,16 +59,20 @@ image.save("image.jpg")
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
Special Training Scripts:
* Differential LoRA training: [doc](/docs/zh/Training/Differential_LoRA.md)
* FP8 precision training: [doc](/docs/zh/Training/FP8_Precision.md)
* Two-stage split training: [doc](/docs/zh/Training/Split_Training.md)
* End-to-end direct distillation: [doc](/docs/zh/Training/Direct_Distill.md)
## Model Inference
@@ -135,4 +148,4 @@ FLUX.2 系列模型统一通过 [`examples/flux2/model_training/train.py`](/exam
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
```
We provide a recommended training script for each model; see the table in the "Model Overview" section above. For how to write model training scripts, see [Model Training](/docs/zh/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Details](/docs/Training/).

View File

@@ -0,0 +1,109 @@
# LTX-2
LTX-2 is a family of audio-video generation models developed by Lightricks.
## Installation
Before running model inference or training with this project, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information about installation, see [Installing Dependencies](/docs/zh/Pipeline_Usage/Setup.md).
## Quick Start
Run the following code to quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model and run inference. VRAM management is enabled, and the framework automatically controls how model parameters are loaded based on the remaining VRAM; a minimum of 8 GB VRAM is enough to run it.
```python
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
vram_config = {
"offload_dtype": torch.float8_e5m2,
"offload_device": "cpu",
"onload_dtype": torch.float8_e5m2,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e5m2,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = LTX2AudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
height, width, num_frames = 512, 768, 121
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=43,
height=height,
width=width,
num_frames=num_frames,
tiled=True,
)
write_video_audio_ltx2(
video=video,
audio=audio,
output_path='ltx2_onestage.mp4',
fps=24,
audio_sample_rate=24000,
)
```
## Model Overview
|Model ID|Extra Parameters|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
## Model Inference
Models are loaded via `LTX2AudioVideoPipeline.from_pretrained`; see [Loading Models](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型) for details.
Input parameters for `LTX2AudioVideoPipeline` inference include:
* `prompt`: Prompt describing the content that should appear in the video.
* `negative_prompt`: Negative prompt describing content that should not appear in the video. Default is `""`.
* `cfg_scale`: Classifier-free guidance scale. Default is 3.0.
* `input_images`: List of input images for image-to-video generation (see the sketch after this list).
* `input_images_indexes`: List of frame indexes at which the input images are placed in the video.
* `input_images_strength`: Strength of the input images. Default is 1.0.
* `denoising_strength`: Denoising strength, in the range 0-1. Default is 1.0.
* `seed`: Random seed. Default is `None`, meaning completely random.
* `rand_device`: Computing device for generating the random Gaussian noise matrix. Default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results.
* `height`: Video height; must be a multiple of 32 (one-stage) or 64 (two-stage).
* `width`: Video width; must be a multiple of 32 (one-stage) or 64 (two-stage).
* `num_frames`: Number of video frames. Default is 121; must be a multiple of 8 plus 1.
* `num_inference_steps`: Number of inference steps. Default is 40.
* `tiled`: Whether to enable tiled VAE inference. Default is `True`. When enabled, VRAM usage during VAE encoding/decoding is significantly reduced, at the cost of a small numerical error and slightly longer inference time.
* `tile_size_in_pixels`: Tile size in pixels for VAE encoding/decoding. Default is 512.
* `tile_overlap_in_pixels`: Tile overlap in pixels for VAE encoding/decoding. Default is 128.
* `tile_size_in_frames`: Tile size in frames for VAE encoding/decoding. Default is 128.
* `tile_overlap_in_frames`: Tile overlap in frames for VAE encoding/decoding. Default is 24.
* `use_two_stage_pipeline`: Whether to use the two-stage pipeline. Default is `False`.
* `use_distilled_pipeline`: Whether to use the distilled pipeline. Default is `False`.
* `progress_bar_cmd`: Progress bar. Default is `tqdm.tqdm`. Set it to `lambda x: x` to disable the progress bar.
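A minimal image-to-audio-video sketch is shown below, reusing the `pipe` object from the quick start above; `input.jpg` is a placeholder path, and passing a list of PIL images plus pinning the image to the first frame via `input_images_indexes` is an assumption about the expected input format.
```python
from PIL import Image

# Sketch: image-to-audio-video inference with the I2AV parameters listed above.
# `input.jpg` is a placeholder; passing PIL images in input_images is an assumption.
video, audio = pipe(
    prompt="A girl smiles at the camera and starts speaking.",
    negative_prompt="blurry, out of focus, distorted proportions",
    input_images=[Image.open("input.jpg")],
    input_images_indexes=[0],      # place the reference image at the first frame
    input_images_strength=1.0,
    seed=42,
    height=512, width=768, num_frames=121,
    tiled=True,
)
write_video_audio_ltx2(video=video, audio=audio, output_path="ltx2_i2av.mp4", fps=24, audio_sample_rate=24000)
```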
If VRAM is insufficient, please enable [VRAM Management](/docs/zh/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code; see the table in the "Model Overview" section above.
## Model Training
Training is not yet supported for the LTX-2 series models. We will add support as soon as possible.

View File

@@ -81,8 +81,13 @@ graph LR;
|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|


@@ -52,7 +52,12 @@ image.save("image.jpg")
|Model ID|Inference|Low-VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
Special training scripts:
@@ -75,6 +80,9 @@ image.save("image.jpg")
* `seed`: Random seed. Defaults to `None`, i.e. fully random.
* `rand_device`: Device on which the random Gaussian noise tensor is generated; defaults to `"cpu"`. Setting it to `"cuda"` leads to different results on different GPUs.
* `num_inference_steps`: Number of inference steps; defaults to 8.
* `controlnet_inputs`: Inputs for the ControlNet model.
* `edit_image`: Image(s) to be edited by the editing model; multiple images are supported.
* `positive_only_lora`: LoRA weights applied only to the positive prompt. (See the sketch after this list.)
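The hypothetical sketch below shows how these arguments are passed in a call; it assumes `pipe` has already been constructed as in the inference examples linked in the table above, and the commented-out arguments (including `source_image`) are illustrative placeholders rather than exact signatures.
```python
# Hypothetical sketch: `pipe` is assumed to be built as in the inference examples above.
image = pipe(
    prompt="a cat surfing on a wave, anime style",
    seed=42,                      # fixed seed for reproducible results
    rand_device="cpu",            # keep noise generation identical across GPUs
    num_inference_steps=8,
    # controlnet_inputs=[...],    # optional: ControlNet inputs (structure depends on the model)
    # edit_image=[source_image],  # optional: image(s) to edit
)
image.save("image.jpg")
```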
If VRAM is insufficient, enable [VRAM management](/docs/zh/Pipeline_Usage/VRAM_management.md). The example code includes a recommended low-VRAM configuration for every model; see the table in the "Model Overview" section above.


@@ -13,7 +13,7 @@
AMD provides ROCm-based torch packages, so most models run without code changes; a few models cannot run because they depend on CUDA-specific instructions.
## Ascend NPU
### Inference
When using an Ascend NPU, change `"cuda"` in the code to `"npu"`.
For example, the inference code for Wan2.1-T2V-1.3B:
@@ -22,6 +22,7 @@ AMD 提供了基于 ROCm 的 torch 包,所以大多数模型无需修改代码
import torch
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from diffsynth.core.device.npu_compatible_device import get_device_name
vram_config = {
"offload_dtype": "disk",
@@ -33,7 +34,7 @@ vram_config = {
+ "preparing_device": "npu", + "preparing_device": "npu",
"computation_dtype": torch.bfloat16, "computation_dtype": torch.bfloat16,
- "computation_device": "cuda", - "computation_device": "cuda",
+ "preparing_device": "npu", + "computation_device": "npu",
} }
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
@@ -46,7 +47,7 @@ pipe = WanVideoPipeline.from_pretrained(
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
+ vram_limit=torch.npu.mem_get_info(get_device_name())[1] / (1024 ** 3) - 2,
)
video = pipe(
@@ -56,3 +57,35 @@ video = pipe(
)
save_video(video, "video.mp4", fps=15, quality=5)
```
#### USP (Unified Sequence Parallel)
To use this feature on an NPU, install the additional third-party libraries as follows:
```shell
pip install git+https://github.com/feifeibear/long-context-attention.git
pip install git+https://github.com/xdit-project/xDiT.git
```
### Training
NPU launch scripts have been added for each model family. The sample scripts are stored under `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`.
The NPU training scripts set NPU-specific environment variables that improve performance and enable the options required by particular models.
#### Environment variables
```shell
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
```
`expandable_segments:<value>`: enables the memory pool's expandable-segments feature, i.e. virtual-memory support.
```shell
export CPU_AFFINITY_CONF=1
```
0 or unset: core binding disabled.
1: coarse-grained core binding enabled.
2: fine-grained core binding enabled.
#### Options required for specific models
| Model | Option | Notes |
|-----------|------|-------------------|
| Wan 14B series | --initialize_model_on_cpu | 14B models must be initialized on the CPU |


@@ -30,11 +30,16 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6
* Ascend NPU
1. Install [CANN](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/softwareinst/instg/instg_quick.html?Mode=PmIns&InstallType=local&OS=openEuler&Software=cannToolKit) by following the official documentation.
2. Install from source:
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
# aarch64/ARM
pip install -e .[npu_aarch64] --extra-index-url "https://download.pytorch.org/whl/cpu"
# x86
pip install -e .[npu]
```
When using an Ascend NPU, change `"cuda"` in your Python code to `"npu"`; see [NPU support](/docs/zh/Pipeline_Usage/GPU_support.md#ascend-npu) for details.


@@ -77,7 +77,7 @@ graph LR;
This section describes how to use `DiffSynth-Studio` to train new models, helping researchers explore new modeling techniques.
* [Training a model from scratch](/docs/zh/Research_Tutorial/train_from_scratch.md)
* Inference improvement and optimization techniques [coming soon]
* Designing controllable generation models [coming soon]
* Creating new training paradigms [coming soon]


@@ -0,0 +1,477 @@
# Training a model from scratch
DiffSynth-Studio's training engine supports training base models from scratch. This article walks through training a small text-to-image model with only 0.1B parameters.
## 1. Building the model architecture
### 1.1 The Diffusion model
From UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) to DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206), the mainstream Diffusion architecture has gone through several evolutions. A Diffusion model typically takes the following inputs:
* Image tensor (`latents`): the encoding of the image, produced by the VAE, containing some amount of noise
* Text tensor (`prompt_embeds`): the encoding of the text, produced by the text encoder
* Timestep (`timestep`): a scalar marking the current stage of the diffusion process
The model outputs a tensor with the same shape as the image tensor, representing the predicted denoising direction. For the theory behind Diffusion models, see [Fundamentals of Diffusion models](/docs/zh/Training/Understanding_Diffusion_models.md). In this article we build a DiT model with only 0.1B parameters: `AAADiT`.
<details>
<summary>Model architecture code</summary>
```python
import torch, accelerate
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange, repeat
from transformers import AutoProcessor, AutoTokenizer
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
from diffsynth.models.general_modules import TimestepEmbeddings
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
from diffsynth.models.flux2_vae import Flux2VAE
class AAAPositionalEmbedding(torch.nn.Module):
def __init__(self, height=16, width=16, dim=1024):
super().__init__()
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
def forward(self, image, text):
height, width = image.shape[-2:]
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
emb = torch.concat([image_emb, text_emb], dim=1)
return emb
class AAABlock(torch.nn.Module):
def __init__(self, dim=1024, num_heads=32):
super().__init__()
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.to_q = torch.nn.Linear(dim, dim)
self.to_k = torch.nn.Linear(dim, dim)
self.to_v = torch.nn.Linear(dim, dim)
self.to_out = torch.nn.Linear(dim, dim)
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.ff = torch.nn.Sequential(
torch.nn.Linear(dim, dim*3),
torch.nn.SiLU(),
torch.nn.Linear(dim*3, dim),
)
self.to_gate = torch.nn.Linear(dim, dim * 2)
self.num_heads = num_heads
def attention(self, emb, pos_emb):
emb = self.norm_attn(emb + pos_emb)
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
emb = attention_forward(
q, k, v,
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
dims={"n": self.num_heads},
)
emb = self.to_out(emb)
return emb
def feed_forward(self, emb, pos_emb):
emb = self.norm_mlp(emb + pos_emb)
emb = self.ff(emb)
return emb
def forward(self, emb, pos_emb, t_emb):
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
return emb
class AAADiT(torch.nn.Module):
def __init__(self, dim=1024):
super().__init__()
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
self.timestep_embedder = TimestepEmbeddings(256, dim)
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
self.proj_out = torch.nn.Linear(dim, 128)
def forward(
self,
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
pos_emb = self.pos_embedder(latents, prompt_embeds)
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
text = self.text_embedder(prompt_embeds)
emb = torch.concat([image, text], dim=1)
for block_id, block in enumerate(self.blocks):
emb = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
emb=emb,
pos_emb=pos_emb,
t_emb=t_emb,
)
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
emb = self.proj_out(emb)
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
return emb
```
</details>
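To make the input/output contract concrete, here is a minimal shape check of `AAADiT`. This is a hedged sketch: it assumes the class definition above is in scope and that `attention_forward` runs on CPU tensors, and it uses the 128-channel, stride-16 latent layout adopted later in the pipeline.
```python
import torch

# Dummy inputs for a 256x256 image: 128 latent channels on a 16x16 grid,
# 128 prompt tokens with hidden size 1024, and a single scalar timestep.
dit = AAADiT(dim=1024)
latents = torch.randn(1, 128, 16, 16)
prompt_embeds = torch.randn(1, 128, 1024)
timestep = torch.tensor([500.0])

with torch.no_grad():
    noise_pred = dit(latents, prompt_embeds, timestep)

print(noise_pred.shape)  # expected: torch.Size([1, 128, 16, 16]), same shape as `latents`
```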
### 1.2 Encoder and decoder models
Besides the Diffusion model used for denoising, we need two more models:
* Text encoder: encodes text into tensors. We use the [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) model.
* VAE encoder/decoder: the encoder turns images into tensors, and the decoder turns image tensors back into images. We use the VAE from [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B).
Both architectures are already integrated in DiffSynth-Studio, at [/diffsynth/models/z_image_text_encoder.py](/diffsynth/models/z_image_text_encoder.py) and [/diffsynth/models/flux2_vae.py](/diffsynth/models/flux2_vae.py), so no code changes are needed.
## 2. Building the Pipeline
The document [Building a Pipeline](/docs/zh/Developer_Guide/Building_a_Pipeline.md) explains how to build a model pipeline. For the model in this article we also need a pipeline that connects the text encoder, the Diffusion model, and the VAE.
<details>
<summary>Pipeline code</summary>
```python
class AAAImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("FLUX.2")
self.text_encoder: ZImageTextEncoder = None
self.dit: AAADiT = None
self.vae: Flux2VAE = None
self.tokenizer: AutoProcessor = None
self.in_iteration_models = ("dit",)
self.units = [
AAAUnit_PromptEmbedder(),
AAAUnit_NoiseInitializer(),
AAAUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_aaa
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = None,
vram_limit: float = None,
):
# Initialize pipeline
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
pipe.dit = model_pool.fetch_model("aaa_dit")
pipe.vae = model_pool.fetch_model("flux2_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 30,
# Progress bar
progress_bar_cmd = tqdm,
):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
# Parameters
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"cfg_scale": cfg_scale,
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"])
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class AAAUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds",),
onload_model_names=("text_encoder",)
)
self.hidden_states_layers = (-1,)
def process(self, pipe: AAAImagePipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
text = pipe.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
return {"prompt_embeds": prompt_embeds}
class AAAUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class AAAUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: AAAImagePipeline, input_image, noise):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(image)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": input_latents}
def model_fn_aaa(
dit: AAADiT,
latents=None,
prompt_embeds=None,
timestep=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
model_output = dit(
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
return model_output
```
</details>
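Before training, a quick smoke test of the wiring can be useful. The following is a hypothetical sketch assembled only from calls that appear elsewhere in this tutorial (`from_pretrained`, manual assignment of `pipe.dit`, and the `__call__` arguments); with an untrained `AAADiT` the output will be noise, but it verifies that all components connect.
```python
import torch

# Hypothetical smoke test: build the pipeline, attach an untrained AAADiT, and
# run one small generation to verify that all components are wired correctly.
pipe = AAAImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
pipe.dit = AAADiT().to(dtype=torch.bfloat16, device="cuda")  # untrained weights
image = pipe(prompt="test", num_inference_steps=4, height=256, width=256, seed=0)
image.save("smoke_test.jpg")
```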
## 3. Preparing the dataset
To validate the training quickly, we use the dataset [Pokémon Generation I](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1), republished from the open-source project [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh); it contains the 151 first-generation Pokémon, from Bulbasaur to Mew. To use another dataset, see [Preparing a dataset](/docs/zh/Pipeline_Usage/Model_Training.md#准备数据集) and [`diffsynth.core.data`](/docs/zh/API_Reference/core/data.md).
```shell
modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data
```
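After downloading, it can help to glance at the metadata file. The sketch below is a hypothetical sanity check: it assumes the CSV follows the usual `UnifiedDataset` layout with `image` and `prompt` columns, which is what the training code below reads (`data["image"]`, `data["prompt"]`), and it uses `pandas` purely for inspection.
```python
# Hypothetical sanity check of the downloaded metadata (column names assumed).
import pandas as pd

metadata = pd.read_csv("data/metadata_merged.csv")
print(metadata.columns.tolist())  # expected to include "image" and "prompt"
print(metadata.iloc[0])           # first sample: an image path plus its text prompt
print(len(metadata))              # 151 first-generation Pokémon
```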
## 4. Starting training
The training loop can be implemented quickly with the pipeline. The complete code is available at [/docs/zh/Research_Tutorial/train_from_scratch.py](/docs/zh/Research_Tutorial/train_from_scratch.py); run `python docs/zh/Research_Tutorial/train_from_scratch.py` to start single-GPU training.
To enable multi-GPU parallel training, run `accelerate config` to set the relevant options, then start training with `accelerate launch docs/zh/Research_Tutorial/train_from_scratch.py`.
The training script has no stopping condition; terminate it manually when needed. The model converges after roughly 60k steps, which takes about 10–20 hours on a single GPU.
<details>
<summary>Training code</summary>
```python
class AAATrainingModule(DiffusionTrainingModule):
def __init__(self, device):
super().__init__()
self.pipe = AAAImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device=device,
model_configs=[
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
self.pipe.freeze_except(["dit"])
self.pipe.scheduler.set_timesteps(1000, training=True)
def forward(self, data):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
inputs_shared = {
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
"cfg_scale": 1,
"use_gradient_checkpointing": False,
"use_gradient_checkpointing_offload": False,
}
for unit in self.pipe.units:
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
return loss
if __name__ == "__main__":
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
dataset = UnifiedDataset(
base_path="data/images",
metadata_path="data/metadata_merged.csv",
max_data_items=10000000,
data_file_keys=("image",),
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
)
model = AAATrainingModule(device=accelerator.device)
model_logger = ModelLogger(
"models/AAA/v1",
remove_prefix_in_ckpt="pipe.dit.",
)
launch_training_task(
accelerator, dataset, model, model_logger,
learning_rate=2e-4,
num_workers=4,
save_steps=50000,
num_epochs=999999,
)
```
</details>
## 5. Validating the training results
If you don't want to wait for training to finish, you can directly download [our pretrained model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel).
```shell
modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel
```
Load the model:
```python
from diffsynth import load_model
pipe = AAAImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda")
```
Run inference to generate the three Generation-I starter Pokémon; at this point the generated images closely match the training data.
```python
for seed, prompt in enumerate([
"green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws",
"orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws",
"蓝色,米色,棕色,乌龟,水系,龟壳,大眼睛,短四肢,卷曲尾巴",
]):
image = pipe(
prompt=prompt,
negative_prompt=" ",
num_inference_steps=30,
cfg_scale=10,
seed=seed,
height=256, width=256,
)
image.save(f"image_{seed}.jpg")
```
|![Image](https://github.com/user-attachments/assets/3c620fbf-5d28-4a1a-b887-519d85ac7d1c)|![Image](https://github.com/user-attachments/assets/909efd4c-9e61-4b33-9321-39da0e499b00)|![Image](https://github.com/user-attachments/assets/f3474bcd-b474-4a90-a1ea-579f67e161e3)|
|-|-|-|
Run inference to generate Pokémon with "sharp claws"; different random seeds now produce different images.
```python
for seed, prompt in enumerate([
"sharp claws",
"sharp claws",
"sharp claws",
]):
image = pipe(
prompt=prompt,
negative_prompt=" ",
num_inference_steps=30,
cfg_scale=10,
seed=seed+4,
height=256, width=256,
)
image.save(f"image_sharp_claws_{seed}.jpg")
```
|![Image](https://github.com/user-attachments/assets/94862edd-96ae-4276-a38f-795249f11a13)|![Image](https://github.com/user-attachments/assets/b2291f23-20ba-42de-8bfd-76cb4afc6eea)|![Image](https://github.com/user-attachments/assets/f2aab9a4-85ec-498e-8039-648b1289796e)|
|-|-|-|
We now have a small 0.1B text-to-image model. It can generate the 151 Pokémon, but nothing else. Scale up the data, the model size, and the number of GPUs on top of this, and you can train a much more capable text-to-image model!


@@ -0,0 +1,341 @@
import torch, accelerate
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange, repeat
from transformers import AutoProcessor, AutoTokenizer
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
from diffsynth.models.general_modules import TimestepEmbeddings
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
from diffsynth.models.flux2_vae import Flux2VAE
class AAAPositionalEmbedding(torch.nn.Module):
def __init__(self, height=16, width=16, dim=1024):
super().__init__()
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
def forward(self, image, text):
height, width = image.shape[-2:]
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
emb = torch.concat([image_emb, text_emb], dim=1)
return emb
class AAABlock(torch.nn.Module):
def __init__(self, dim=1024, num_heads=32):
super().__init__()
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.to_q = torch.nn.Linear(dim, dim)
self.to_k = torch.nn.Linear(dim, dim)
self.to_v = torch.nn.Linear(dim, dim)
self.to_out = torch.nn.Linear(dim, dim)
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.ff = torch.nn.Sequential(
torch.nn.Linear(dim, dim*3),
torch.nn.SiLU(),
torch.nn.Linear(dim*3, dim),
)
self.to_gate = torch.nn.Linear(dim, dim * 2)
self.num_heads = num_heads
def attention(self, emb, pos_emb):
emb = self.norm_attn(emb + pos_emb)
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
emb = attention_forward(
q, k, v,
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
dims={"n": self.num_heads},
)
emb = self.to_out(emb)
return emb
def feed_forward(self, emb, pos_emb):
emb = self.norm_mlp(emb + pos_emb)
emb = self.ff(emb)
return emb
def forward(self, emb, pos_emb, t_emb):
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
return emb
class AAADiT(torch.nn.Module):
def __init__(self, dim=1024):
super().__init__()
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
self.timestep_embedder = TimestepEmbeddings(256, dim)
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
self.proj_out = torch.nn.Linear(dim, 128)
def forward(
self,
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
pos_emb = self.pos_embedder(latents, prompt_embeds)
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
text = self.text_embedder(prompt_embeds)
emb = torch.concat([image, text], dim=1)
for block_id, block in enumerate(self.blocks):
emb = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
emb=emb,
pos_emb=pos_emb,
t_emb=t_emb,
)
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
emb = self.proj_out(emb)
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
return emb
class AAAImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("FLUX.2")
self.text_encoder: ZImageTextEncoder = None
self.dit: AAADiT = None
self.vae: Flux2VAE = None
self.tokenizer: AutoProcessor = None
self.in_iteration_models = ("dit",)
self.units = [
AAAUnit_PromptEmbedder(),
AAAUnit_NoiseInitializer(),
AAAUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_aaa
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = None,
vram_limit: float = None,
):
# Initialize pipeline
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
pipe.dit = model_pool.fetch_model("aaa_dit")
pipe.vae = model_pool.fetch_model("flux2_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 30,
# Progress bar
progress_bar_cmd = tqdm,
):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
# Parameters
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"cfg_scale": cfg_scale,
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"])
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class AAAUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds",),
onload_model_names=("text_encoder",)
)
self.hidden_states_layers = (-1,)
def process(self, pipe: AAAImagePipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
text = pipe.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
return {"prompt_embeds": prompt_embeds}
class AAAUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class AAAUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: AAAImagePipeline, input_image, noise):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(image)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": input_latents}
def model_fn_aaa(
dit: AAADiT,
latents=None,
prompt_embeds=None,
timestep=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
model_output = dit(
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
return model_output
class AAATrainingModule(DiffusionTrainingModule):
def __init__(self, device):
super().__init__()
self.pipe = AAAImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device=device,
model_configs=[
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
self.pipe.freeze_except(["dit"])
self.pipe.scheduler.set_timesteps(1000, training=True)
def forward(self, data):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
inputs_shared = {
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
"cfg_scale": 1,
"use_gradient_checkpointing": False,
"use_gradient_checkpointing_offload": False,
}
for unit in self.pipe.units:
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
return loss
if __name__ == "__main__":
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
dataset = UnifiedDataset(
base_path="data/images",
metadata_path="data/metadata_merged.csv",
max_data_items=10000000,
data_file_keys=("image",),
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
)
model = AAATrainingModule(device=accelerator.device)
model_logger = ModelLogger(
"models/AAA/v1",
remove_prefix_in_ckpt="pipe.dit.",
)
launch_training_task(
accelerator, dataset, model, model_logger,
learning_rate=2e-4,
num_workers=4,
save_steps=50000,
num_epochs=999999,
)


@@ -6,7 +6,7 @@
A Diffusion model generates clean image or video content by iterative multi-step denoising. Let us start from the generation process of a data sample $x_0$. Intuitively, in one full denoising pass we start from random Gaussian noise $x_T$ and iteratively obtain $x_{T-1}$, $x_{T-2}$, $x_{T-3}$, $\cdots$, reducing the amount of noise at each step until we reach the noise-free data sample $x_0$.
![Image](https://github.com/user-attachments/assets/6471ae4c-a635-4924-8b36-b0bd4d42043d)
This process is intuitive, but to understand its details we need to answer a few questions:
@@ -28,7 +28,7 @@ Diffusion 模型通过多步迭代式地去噪denoise生成清晰的图像
Then, at any intermediate step, we can directly synthesize the noisy data sample $x_t=(1-\sigma_t)x_0+\sigma_t x_T$.
![Image](https://github.com/user-attachments/assets/e25a2f71-123c-4e18-8b34-3a066af15667)
## How is the iterative denoising computed?
@@ -40,8 +40,6 @@ Diffusion 模型通过多步迭代式地去噪denoise生成清晰的图像
Here the guidance condition $c$ is a newly introduced parameter supplied by the user; it can be text describing the image content, or a line-art sketch outlining the image structure.
The model output $\hat \epsilon(x_t,c,t)$ is approximately equal to $x_T-x_0$, i.e. the direction of the whole diffusion process (the reverse of the denoising process).
Next we analyze the computation in a single iteration: at timestep $t$, once the model has produced its approximation of $x_T-x_0$, we compute the next sample $x_{t-1}$:
@@ -89,8 +87,6 @@ $$
Training differs from generation. If we kept the multi-step iteration during training, gradients would have to backpropagate through every step, which is catastrophic in both time and memory. To make training efficient, we instead randomly pick a single timestep $t$ to train on.
Pseudocode of the training procedure:
> Draw a data sample $x_0$ and a guidance condition $c$ from the dataset
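As a compact sketch of one such training step, the snippet below combines the noising formula and the training target described above. It is illustrative only: `denoiser` stands for any model with the $(x_t, c, t)$ interface discussed earlier, and the tensor shapes, timestep scaling, and optimizer are assumptions rather than the library's actual implementation (which lives in `FlowMatchSFTLoss`).
```python
import torch

def training_step(denoiser, x0, c, optimizer):
    # One flow-matching training step (shapes and timestep scaling are illustrative).
    noise = torch.randn_like(x0)                    # x_T: pure Gaussian noise
    sigma_t = torch.rand(x0.shape[0], 1, 1, 1)      # random noise level per sample
    x_t = (1 - sigma_t) * x0 + sigma_t * noise      # x_t = (1 - sigma_t) x_0 + sigma_t x_T
    target = noise - x0                             # the diffusion direction x_T - x_0
    t = sigma_t.flatten() * 1000                    # timestep fed to the model
    pred = denoiser(x_t, c, t)                      # model predicts the direction
    loss = torch.nn.functional.mse_loss(pred, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```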
@@ -111,7 +107,7 @@ $$
Moving from theory to practice requires many more details. Modern Diffusion architectures are mature: the mainstream design follows the "three-part" architecture proposed by Latent Diffusion, consisting of a data encoder/decoder, a guidance-condition encoder, and a denoising model.
![Image](https://github.com/user-attachments/assets/43855430-6427-4aca-83a0-f684e01438b1)
### Data encoder/decoder


@@ -108,7 +108,14 @@ def test_flux():
run_inference("examples/flux/model_training/validate_lora") run_inference("examples/flux/model_training/validate_lora")
def test_z_image():
run_inference("examples/z_image/model_inference")
run_inference("examples/z_image/model_inference_low_vram")
run_train_multi_GPU("examples/z_image/model_training/full")
run_inference("examples/z_image/model_training/validate_full")
run_train_single_GPU("examples/z_image/model_training/lora")
run_inference("examples/z_image/model_training/validate_lora")
if __name__ == "__main__":
test_z_image()
test_flux()
test_wan()


@@ -0,0 +1,17 @@
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1
accelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_kontext.csv \
--data_file_keys "image,kontext_images" \
--max_pixels 1048576 \
--dataset_repeat 400 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \
--learning_rate 1e-5 \
--num_epochs 1 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.1-Kontext-dev_full" \
--trainable_models "dit" \
--extra_inputs "kontext_images" \
--use_gradient_checkpointing


@@ -0,0 +1,15 @@
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1
accelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 400 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \
--learning_rate 1e-5 \
--num_epochs 1 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.1-dev_full" \
--trainable_models "dit" \
--use_gradient_checkpointing


@@ -0,0 +1,21 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
image.save("image_FLUX.2-klein-4B.jpg")
prompt = "change the color of the clothes to red"
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
image.save("image_edit_FLUX.2-klein-4B.jpg")


@@ -0,0 +1,21 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
)
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
image.save("image_FLUX.2-klein-9B.jpg")
prompt = "change the color of the clothes to red"
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
image.save("image_edit_FLUX.2-klein-9B.jpg")


@@ -0,0 +1,21 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
image.save("image_FLUX.2-klein-base-4B.jpg")
prompt = "change the color of the clothes to red"
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
image.save("image_edit_FLUX.2-klein-base-4B.jpg")


@@ -0,0 +1,21 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
)
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
image.save("image_FLUX.2-klein-base-9B.jpg")
prompt = "change the color of the clothes to red"
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
image.save("image_edit_FLUX.2-klein-base-9B.jpg")


@@ -0,0 +1,32 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
image.save("image_FLUX.2-klein-4B.jpg")
prompt = "change the color of the clothes to red"
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
image.save("image_edit_FLUX.2-klein-4B.jpg")


@@ -0,0 +1,32 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
image.save("image_FLUX.2-klein-9B.jpg")
prompt = "change the color of the clothes to red"
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
image.save("image_edit_FLUX.2-klein-9B.jpg")


@@ -0,0 +1,32 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
image.save("image_FLUX.2-klein-base-4B.jpg")
prompt = "change the color of the clothes to red"
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
image.save("image_edit_FLUX.2-klein-base-4B.jpg")


@@ -0,0 +1,32 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
image.save("image_FLUX.2-klein-base-9B.jpg")
prompt = "change the color of the clothes to red"
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
image.save("image_edit_FLUX.2-klein-base-9B.jpg")


@@ -0,0 +1,30 @@
accelerate launch examples/flux2/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.2-klein-4B_full" \
--trainable_models "dit" \
--use_gradient_checkpointing
# Edit
# accelerate launch examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
# --learning_rate 1e-5 \
# --num_epochs 2 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-4B_full" \
# --trainable_models "dit" \
# --use_gradient_checkpointing


@@ -0,0 +1,31 @@
# This script is tested on 8*A100
accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.2-klein-9B_full" \
--trainable_models "dit" \
--use_gradient_checkpointing
# Edit
# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
# --learning_rate 1e-5 \
# --num_epochs 2 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-9B_full" \
# --trainable_models "dit" \
# --use_gradient_checkpointing

View File

@@ -0,0 +1,30 @@
accelerate launch examples/flux2/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.2-klein-base-4B_full" \
--trainable_models "dit" \
--use_gradient_checkpointing
# Edit
# accelerate launch examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
# --learning_rate 1e-5 \
# --num_epochs 2 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-base-4B_full" \
# --trainable_models "dit" \
# --use_gradient_checkpointing

View File

@@ -0,0 +1,31 @@
# This script is tested on 8*A100
accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.2-klein-base-9B_full" \
--trainable_models "dit" \
--use_gradient_checkpointing
# Edit
# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
# --learning_rate 1e-5 \
# --num_epochs 2 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-base-9B_full" \
# --trainable_models "dit" \
# --use_gradient_checkpointing

View File

@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 1
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@@ -0,0 +1,34 @@
accelerate launch examples/flux2/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.2-klein-4B_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
--lora_rank 32 \
--use_gradient_checkpointing
# Edit
# accelerate launch examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
# --learning_rate 1e-4 \
# --num_epochs 5 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-4B_lora" \
# --lora_base_model "dit" \
# --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
# --lora_rank 32 \
# --use_gradient_checkpointing

View File

@@ -0,0 +1,34 @@
accelerate launch examples/flux2/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.2-klein-9B_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
--lora_rank 32 \
--use_gradient_checkpointing
# Edit
# accelerate launch examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
# --learning_rate 1e-4 \
# --num_epochs 5 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-9B_lora" \
# --lora_base_model "dit" \
# --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
# --lora_rank 32 \
# --use_gradient_checkpointing

View File

@@ -0,0 +1,34 @@
accelerate launch examples/flux2/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.2-klein-base-4B_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
--lora_rank 32 \
--use_gradient_checkpointing
# Edit
# accelerate launch examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
# --learning_rate 1e-4 \
# --num_epochs 5 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-base-4B_lora" \
# --lora_base_model "dit" \
# --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
# --lora_rank 32 \
# --use_gradient_checkpointing

View File

@@ -0,0 +1,34 @@
accelerate launch examples/flux2/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.2-klein-base-9B_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
--lora_rank 32 \
--use_gradient_checkpointing
# Edit
# accelerate launch examples/flux2/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
# --learning_rate 1e-4 \
# --num_epochs 5 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/FLUX.2-klein-base-9B_lora" \
# --lora_base_model "dit" \
# --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
# --lora_rank 32 \
# --use_gradient_checkpointing
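After LoRA training finishes, the checkpoint under ./models/train/FLUX.2-klein-base-9B_lora can be validated much like the full checkpoints below, except that the LoRA weights are attached to the base DiT rather than replacing it. A sketch under the assumption that the pipeline exposes the load_lora helper used in other DiffSynth-Studio LoRA examples, and that checkpoints follow the epoch-N naming seen below; verify the exact call and path against the repository's flux2 validation scripts:

from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch

# Base klein-base-9B pipeline, same model layout as the training script above.
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
)
# Assumption: load_lora attaches the trained LoRA weights to the DiT in place.
pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-base-9B_lora/epoch-4.safetensors")
image = pipe(prompt="a dog", seed=0, num_inference_steps=50, cfg_scale=4, height=768, width=768)
image.save("image_lora.jpg")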

View File

@@ -24,7 +24,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule):
 super().__init__()
 # Load models
 model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
-tokenizer_config = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
+tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"))
 self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
 self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
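The refactor above routes the --tokenizer_path argument through a shared helper instead of an inline conditional. Purely as an illustration of the behavior implied by the replaced expression and by the "model_id:pattern" strings used in the training scripts above, a rough sketch (this is an assumption, not the repository's actual implementation of parse_path_or_model_id):

from diffsynth.pipelines.flux2_image import ModelConfig

# Illustrative only: approximates what the helper on DiffusionTrainingModule appears to do.
def parse_path_or_model_id(path_or_id, default_value):
    if path_or_id is None:
        # No --tokenizer_path given: fall back to the default config.
        return default_value
    if ":" in path_or_id:
        # "model_id:pattern" form, e.g. "black-forest-labs/FLUX.2-klein-9B:tokenizer/"
        model_id, pattern = path_or_id.split(":", 1)
        return ModelConfig(model_id=model_id, origin_file_pattern=pattern)
    # Otherwise treat the argument as a local path.
    return ModelConfig(path_or_id)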

View File

@@ -0,0 +1,20 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
from diffsynth.core import load_state_dict
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("./models/train/FLUX.2-klein-4B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
pipe.dit.load_state_dict(state_dict)
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
image.save("image.jpg")

View File

@@ -0,0 +1,20 @@
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
from diffsynth.core import load_state_dict
import torch
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("./models/train/FLUX.2-klein-9B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
pipe.dit.load_state_dict(state_dict)
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
image.save("image.jpg")

Some files were not shown because too many files have changed in this diff.