Compare commits

...

917 Commits

Author SHA1 Message Date
Artiprocher
98290190ec update z-image-i2L demo 2026-01-27 13:42:48 +08:00
Artiprocher
3f4de2cc7f update z-image-i2L examples 2026-01-27 12:16:48 +08:00
Artiprocher
d12bf71bcc support z-image and z-image-i2L 2026-01-27 10:56:15 +08:00
Zhongjie Duan
ffb7a138f7 Merge pull request #1228 from modelscope/klein-bugfix
change klein image resize to crop
2026-01-22 10:34:17 +08:00
Artiprocher
548304667f change klein image resize to crop 2026-01-22 10:33:29 +08:00
Zhongjie Duan
273143136c Merge pull request #1227 from modelscope/modelscope-service-patch
update to 2.0.3
2026-01-21 20:23:13 +08:00
Artiprocher
030ebe649a update to 2.0.3 2026-01-21 20:22:43 +08:00
Zhongjie Duan
90921d2293 Merge pull request #1226 from modelscope/klein-train-fix
improve flux2 training performance
2026-01-21 15:44:52 +08:00
Artiprocher
b61131c693 improve flux2 training performance 2026-01-21 15:44:15 +08:00
Zhongjie Duan
37fbb3248a Merge pull request #1222 from modelscope/trainer-update
support auto detact lora target modules
2026-01-21 11:06:19 +08:00
Artiprocher
d13f533f42 support auto detact lora target modules 2026-01-21 11:05:05 +08:00
Zhongjie Duan
3743b1307c Merge pull request #1219 from modelscope/klein-edit
support klein edit
2026-01-20 12:59:12 +08:00
Artiprocher
a835df984c support klein edit 2026-01-20 12:58:18 +08:00
Zhongjie Duan
3e4b47e424 Merge pull request #1207 from Feng0w0/cuda_replace
[NPU]:Replace 'cuda' in the project with abstract interfaces
2026-01-20 10:13:04 +08:00
Zhongjie Duan
dd8d902624 Merge branch 'main' into cuda_replace 2026-01-20 10:12:31 +08:00
Zhongjie Duan
a8b340c098 Merge pull request #1191 from Feng0w0/wan_rope
[model][NPU]:Wan model rope use torch.complex64 in NPU
2026-01-20 10:05:22 +08:00
Zhongjie Duan
88497b5c13 Merge pull request #1217 from modelscope/klein-update
support klein base models
2026-01-19 21:14:47 +08:00
Artiprocher
1e90c72d94 support klein base models 2026-01-19 21:11:58 +08:00
Zhongjie Duan
3dd82a738e Merge pull request #1215 from lzws/main
updata learning rate in wan-vace training scripts
2026-01-19 17:48:42 +08:00
Artiprocher
8ad2d9884b update lr in wan-vace training scripts 2026-01-19 17:43:07 +08:00
Artiprocher
70f531b724 update wan-vace training scripts 2026-01-19 17:37:30 +08:00
Zhongjie Duan
37c2868b61 Merge pull request #1214 from modelscope/klein
Support FLUX.2-klein
2026-01-19 17:36:39 +08:00
Artiprocher
a18e6233b5 updata wan-vace training scripts 2026-01-19 17:35:08 +08:00
Artiprocher
2336d5f6b3 update doc 2026-01-19 17:27:32 +08:00
Artiprocher
b6ccb362b9 support flux.2 klein 2026-01-19 16:56:14 +08:00
Artiprocher
ae52d93694 support klein 4b models 2026-01-16 13:09:41 +08:00
feng0w0
ad91d41601 [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-16 10:28:24 +08:00
feng0w0
dce77ec4d1 [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-15 20:35:41 +08:00
feng0w0
5c0b07d939 [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-15 20:34:52 +08:00
feng0w0
19e429d889 Merge remote-tracking branch 'origin/cuda_replace' into cuda_replace 2026-01-15 20:33:21 +08:00
feng0w0
209a350c0f [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-15 20:33:01 +08:00
feng0w0
a3c2744a43 [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-15 20:04:54 +08:00
Zhongjie Duan
55e8346da3 Blog link (#1202)
* update README
2026-01-15 12:31:55 +08:00
Zhongjie Duan
b7979b2633 Merge pull request #1200 from modelscope/flux-compatibility-fix
fix flux compatibility issues
2026-01-14 20:50:18 +08:00
Artiprocher
c90aaa2798 fix flux compatibility issues 2026-01-14 20:49:36 +08:00
Zhongjie Duan
0c617d5d9e Merge pull request #1194 from lzws/main
wan usp bug fix
2026-01-14 16:34:06 +08:00
lzws
fd87b72754 wan usp bug fix 2026-01-14 16:33:02 +08:00
Zhongjie Duan
db75508ba0 Merge pull request #1199 from modelscope/z-image-bugfix
fix RMSNorm precision
2026-01-14 16:32:33 +08:00
Artiprocher
acba342a63 fix RMSNorm precision 2026-01-14 16:29:43 +08:00
feng0w0
d16877e695 [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-13 11:17:51 +08:00
lzws
e99cdcf3b8 wan usp bug fix 2026-01-12 22:08:48 +08:00
Zhongjie Duan
a236a17f17 Merge pull request #1193 from modelscope/qwen-image-layered-control
support qwen-image-layered-control
2026-01-12 17:24:06 +08:00
Artiprocher
03e530dc39 support qwen-image-layered-control 2026-01-12 17:20:01 +08:00
feng0w0
6be244233a [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-12 11:34:41 +08:00
feng0w0
544c391936 [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-12 11:24:11 +08:00
Feng
f4d06ce3fc Merge branch 'modelscope:main' into wan_rope 2026-01-12 11:21:09 +08:00
Zhongjie Duan
ffedb9eb52 Merge pull request #1187 from jiaqixuac/patch-1
Update package inclusion pattern in pyproject.toml
2026-01-12 10:12:20 +08:00
Zhongjie Duan
381067515c Merge pull request #1176 from Feng0w0/z-image-rope
[model][NPU]: Z-image model support NPU
2026-01-12 10:11:22 +08:00
Zhongjie Duan
00f2d1aa5d Merge pull request #1169 from Feng0w0/sample_add
Docs:Supplement NPU training script samples and documentation instruction
2026-01-12 10:08:38 +08:00
Zhongjie Duan
8cc3bece6d Merge pull request #1167 from Feng0w0/install_env
Docs:Supplement NPU environment installation document
2026-01-12 10:07:30 +08:00
Jiaqi Xu
f4bf592064 Update pyproject.toml
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-10 09:32:35 +08:00
Jiaqi Xu
3235393fb5 Update package inclusion pattern in pyproject.toml
Update to install all the sub-packages inside diffsynth. Otherwise, the installed packages only contain __init__.py
2026-01-10 09:28:45 +08:00
feng0w0
3b662da31e [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-09 18:11:40 +08:00
feng0w0
19ce3048c1 [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-09 18:06:41 +08:00
Zhongjie Duan
de0aa946f7 Merge pull request #1184 from modelscope/z-image-omni-base-dev
update package version
2026-01-08 17:27:33 +08:00
Artiprocher
f376202a49 update package version 2026-01-08 17:26:29 +08:00
Zhongjie Duan
a13ecfc46b Merge pull request #1183 from modelscope/z-image-omni-base-dev
fix unused parameters in z-image-omni-base
2026-01-08 17:03:20 +08:00
Artiprocher
10a1853eda fix unused parameters in z-image-omni-base 2026-01-08 17:02:41 +08:00
Zhongjie Duan
0efab85674 Support Z-Image-Omni-Base and its related models
Support Z-Image-Omni-Base and its related models.
2026-01-08 13:43:59 +08:00
Artiprocher
f45a0ffd02 support z-image-omni-base vram management 2026-01-08 13:41:00 +08:00
Artiprocher
8ba528a8f6 bugfix 2026-01-08 13:21:33 +08:00
Artiprocher
dd479e5bff support z-image-omni-base-i2L 2026-01-07 20:36:53 +08:00
Artiprocher
bac39b1cd2 support z-image controlnet 2026-01-07 15:56:53 +08:00
feng0w0
c1c9a4853b [model][NPU]:Z-image model support NPU 2026-01-07 11:42:19 +08:00
feng0w0
3ee5f53a36 [model][NPU]:Z-image model support NPU 2026-01-07 11:31:22 +08:00
Artiprocher
32449a6aa0 support z-image-omni-base training 2026-01-05 20:04:00 +08:00
Zhongjie Duan
a6884f6b3a Merge pull request #1171 from YZBPXX/main
Fix issue where LoRa loads on a device different from Dit
2026-01-05 16:39:02 +08:00
Zhongjie Duan
b078666640 Merge pull request #1173 from modelscope/flux-compatibility-patch
flux compatibility patch
2026-01-05 16:20:25 +08:00
Artiprocher
7604ca1e52 flux compatibility patch 2026-01-05 16:04:20 +08:00
feng0w0
62c3d406d9 Docs:Supplement NPU training script samples and documentation instruction 2026-01-05 15:42:55 +08:00
Artiprocher
5745c9f200 support z-image-omni-base 2026-01-05 14:45:01 +08:00
feng0w0
86829120c2 Docs:Supplement NPU training script samples and documentation instruction 2026-01-05 09:59:11 +08:00
yaozhengbing
60ac96525b Fix issue where LoRa loads on a device different from Dit 2025-12-31 21:31:01 +08:00
feng0w0
07b1f5702f Docs:Supplement NPU training script samples and documentation instruction 2025-12-31 10:01:21 +08:00
feng0w0
507e7e5d36 Docs:Supplement NPU training script samples and documentation instruction 2025-12-30 19:58:47 +08:00
Zhongjie Duan
ab8580f77e Merge pull request #1166 from modelscope/qwen-image-2512
support qwen-image-2512
2025-12-30 16:47:07 +08:00
Artiprocher
6454259853 support qwen-image-2512 2025-12-30 16:43:41 +08:00
feng0w0
9cc1697d4d Docs:Supplement NPU environment installation document 2025-12-30 15:57:13 +08:00
Zhongjie Duan
8f1d10fb43 Merge pull request #1150 from modelscope/qwen-image-layered
support qwen-image-layered
2025-12-20 14:05:38 +08:00
Artiprocher
20e1aaf908 bugfix 2025-12-20 14:00:22 +08:00
Artiprocher
c6722b3f56 support qwen-image-layered 2025-12-19 19:06:37 +08:00
Zhongjie Duan
11315d7a40 Merge pull request #1147 from modelscope/qwen-image-edit-2511
Qwen image edit 2511
2025-12-18 19:23:44 +08:00
Artiprocher
68d97a9844 update doc 2025-12-18 19:22:22 +08:00
Artiprocher
4629d4cf9e support qwen-image-edit-2511 2025-12-18 19:16:52 +08:00
Zhongjie Duan
3cb5cec906 Merge pull request #1143 from modelscope/readme-update
update README
2025-12-17 16:32:29 +08:00
Artiprocher
b7e16b9034 update README 2025-12-17 16:30:41 +08:00
Zhongjie Duan
83d1e7361f Merge pull request #1136 from modelscope/bugfix-device
bugfix
2025-12-16 16:12:05 +08:00
Artiprocher
1547c3f786 bugfix 2025-12-16 16:09:29 +08:00
Zhongjie Duan
bfaaf12bf4 Merge pull request #1129 from modelscope/ascend
Support Ascend NPU
2025-12-15 19:13:40 +08:00
Zhongjie Duan
47545e1aab Merge pull request #1126 from Leoooo333/main
Fixed: Wan S2V Long video severe quality downgrade
2025-12-15 19:09:39 +08:00
Artiprocher
7c6905a432 support ascend npu 2025-12-15 15:50:12 +08:00
Artiprocher
2883bc1b76 support ascend npu 2025-12-15 15:48:42 +08:00
Zhongjie Duan
78d8842ddf Merge pull request #1128 from modelscope/amd_install
update installation instructions for AMD
2025-12-15 14:35:50 +08:00
Artiprocher
5821a664a0 update AMD GPU support 2025-12-15 14:30:13 +08:00
Zhongjie Duan
ab9aa1a087 Merge pull request #1124 from lzws/main
add wan usp example
2025-12-15 12:57:58 +08:00
Junming Chen
a4d34d9f3d Append: set video compress quality as original version. 2025-12-14 20:53:26 +00:00
Junming Chen
127cc9007a Fixed: S2V Long video severe quality downgrade 2025-12-14 20:30:34 +00:00
lzws
e1f5db5f5c add wan usp example 2025-12-12 20:24:27 +08:00
Zhongjie Duan
e316fb717f Merge pull request #1122 from modelscope/flux-lora-revert
revert FluxLoRAConverter due to dependency issues
2025-12-12 17:19:48 +08:00
Artiprocher
64c5139502 revert FluxLoRAConverter due to dependency issues 2025-12-12 17:19:13 +08:00
Mahdi-CV
5da9611a74 Update README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-11 09:57:15 -08:00
Mahdi-CV
733750d01b Update README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-11 09:57:06 -08:00
Mahdi-CV
edc95359d0 Update README_zh.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-11 09:56:48 -08:00
lzws
f2d0241e26 Update Z-Image.md 2025-12-11 16:43:38 +08:00
lzws
7b5d7f4af5 Update Z-Image.md 2025-12-11 16:41:46 +08:00
Mahdi Ghodsi
1fa9a6c60c updated README both Eng and Ch to reflect the AMD installation 2025-12-10 16:14:56 -08:00
Mahdi Ghodsi
51efa128d3 adding amd requirements file 2025-12-10 14:40:38 -08:00
Zhongjie Duan
421c6a5fce Merge pull request #1109 from modelscope/bugfix1
fix typo
2025-12-09 23:30:15 +09:00
Artiprocher
864080d8f2 fix typo 2025-12-09 22:29:50 +08:00
Zhongjie Duan
ba372dd295 Merge pull request #1108 from modelscope/i2L
Qwen-Image-i2L (Image to LoRA)
2025-12-09 23:10:02 +09:00
Artiprocher
1ceb02f673 update README 2025-12-09 22:08:47 +08:00
Artiprocher
30f93161fb support i2L 2025-12-09 22:07:35 +08:00
Zhongjie Duan
3ee3cc3104 Merge pull request #1093 from modelscope/diffsynth-2.0-patch
DiffSynth-Studio 2.0 major update
2025-12-04 16:38:31 +08:00
root
c2218f5c73 DiffSynth-Studio 2.0 major update 2025-12-04 16:34:24 +08:00
root
72af7122b3 DiffSynth-Studio 2.0 major update 2025-12-04 16:33:07 +08:00
Zhongjie Duan
afd101f345 Merge pull request #1058 from modelscope/download
support downloading resource
2025-11-18 10:30:16 +08:00
Artiprocher
1313f4dd63 support downloading resource 2025-11-18 10:29:07 +08:00
Zhongjie Duan
8332ecebb7 Merge pull request #1034 from modelscope/video_as_prompt
Video as prompt
2025-11-04 17:32:50 +08:00
Zhongjie Duan
401d7d74a5 Merge pull request #1025 from krahets/patch-1
Fix sinusoidal_embedding calculation for bf16 precision.
2025-11-04 15:08:11 +08:00
Yudong Jin
b8d7d55568 Fix dtype issue in time embedding calculation 2025-11-01 03:11:03 +08:00
Zhongjie Duan
a30ed9093f Merge pull request #1018 from modelscope/longcat
support LongCat-Video
2025-10-30 13:45:03 +08:00
Artiprocher
b73e713028 support LongCat-Video 2025-10-30 13:38:14 +08:00
yjy415
e0eabaa426 Krea realtime video (#1011)
* krea-realtime-video

* Add Krea real-time video inference and training support

* Delete .gitignore

* update README

* update README

---------

Co-authored-by: Artiprocher <wangye87v5@hotmail.com>
Co-authored-by: Jintao Huang <huangjintao.hjt@alibaba-inc.com>
Co-authored-by: Zhongjie Duan <35051019+Artiprocher@users.noreply.github.com>
2025-10-27 19:09:28 +08:00
Zhongjie Duan
538017177a Merge pull request #1006 from lzws/main
add wan2.2-S2V-14B training
2025-10-22 09:55:21 +08:00
lzws
30292d9411 update wan2.2-S2V training 2025-10-21 19:59:44 +08:00
lzws
b168d7aa8b update wans2v training 2025-10-21 10:39:30 +08:00
lzws
8ea45b0daa update wans2v training 2025-10-21 10:34:48 +08:00
Zhongjie Duan
0a1c172a00 Merge pull request #984 from modelscope/animate-bugfix
bugfix
2025-10-10 15:42:20 +08:00
Artiprocher
77fac2a03f bugfix 2025-10-10 15:41:39 +08:00
Zhongjie Duan
084bc2fc78 Merge pull request #969 from modelscope/bugfix953
fix bug in issue 953
2025-09-30 13:00:15 +08:00
Artiprocher
c63d474b60 fix bug in issue 953 2025-09-30 12:59:44 +08:00
Zhongjie Duan
7540568156 support wan2.2-animate-14b (#968) 2025-09-30 12:57:16 +08:00
Zhongjie Duan
c5d426c254 Merge branch 'main' into wan-animate 2025-09-30 12:56:28 +08:00
Artiprocher
a36f2f6032 support wan2.2-animate-14b 2025-09-30 12:45:56 +08:00
lzws
ed256ef8be fix wan vace bug (#960)
* fix wan vace bug
2025-09-26 13:49:27 +08:00
Zhongjie Duan
15079a6cb8 Merge pull request #944 from baolef/dev
fix: fix the undefined vace typo
2025-09-25 15:58:24 +08:00
Zhongjie Duan
c084d6377b Merge pull request #952 from modelscope/bugfix-vace
Update wan_video_new.py
2025-09-25 15:34:22 +08:00
Zhongjie Duan
e9bc42f233 Update wan_video_new.py 2025-09-25 15:34:09 +08:00
Zhongjie Duan
0d6de58af9 Merge pull request #949 from modelscope/qwen-image-edit-multi
update qwen-image-edit training script
2025-09-25 11:07:38 +08:00
Artiprocher
acbf932974 update qwen-image-edit training script 2025-09-25 11:07:01 +08:00
Baole Fang
9d64ed7042 fix: fix the undefined vace typo 2025-09-24 16:55:47 +08:00
Zhongjie Duan
0b4b337e9a Merge pull request #933 from lzws/main
update wan2.2-VACE-Fun-A14B
2025-09-24 09:56:37 +08:00
Zhongjie Duan
99908d9a1c Merge pull request #940 from mi804/eligen_poster
support eligen-poster
2025-09-23 17:49:37 +08:00
mi804
73ced7a46d support eligen-poster 2025-09-23 17:41:48 +08:00
Zhongjie Duan
32b8b9b51e Merge pull request #910 from ldiex/main
Fix gradient checkpointing in WAN VACE blocks
2025-09-23 12:23:12 +08:00
Zhongjie Duan
f6534a5b63 Merge pull request #909 from huarzone/fix_bug
fix load gif
2025-09-23 12:22:00 +08:00
Zhongjie Duan
034c9b6c60 Qwen-Image-Edit-2509 (#937)
* qwen-image-edit-2509
2025-09-22 20:37:11 +08:00
lzws
76335e0fe5 uodate wan2.2-VACE-Fun 2025-09-22 02:14:20 +08:00
lzws
c0b589d934 add wan2.2-VACE-Fun infereance and trining 2025-09-22 01:57:05 +08:00
Zhongjie Duan
833ba1e1fa update vram management strategy (#929) 2025-09-18 16:53:13 +08:00
Artiprocher
7a5974d964 update vram management strategy 2025-09-18 16:51:53 +08:00
Zhongjie Duan
b0abdaffb4 Qwen image split training Bug Fix (#926)
* bugfix
2025-09-17 20:53:46 +08:00
Zhongjie Duan
e9f29bc402 Merge pull request #921 from modelscope/qwen-image-distill-dmd2-lora
support qwen-image-distill-dmd2-lora
2025-09-16 19:43:59 +08:00
Artiprocher
1a7f482fbd support qwen-image-distill-dmd2-lora 2025-09-16 19:43:07 +08:00
Tianlin Pan
3a0d51d100 Fix gradient checkpointing in WAN VACE blocks 2025-09-14 16:21:46 +08:00
Kared
bffdb901ed fix load gif 2025-09-13 21:01:44 +08:00
Zhongjie Duan
d93e8738cd Merge pull request #902 from xycdx/feature/improve-fastblend
add torch implementation for interpolation
2025-09-11 11:45:55 +08:00
xycdx
7e5ce5d5c9 Update diffsynth/extensions/FastBlend/patch_match.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-09-10 20:48:54 +08:00
xycdx
7aef554d83 add torch implementation for interpolation
- Implement bilinear interpolation kernel using Numba
- Benchmark shows 2x speedup compared to CPU version
- Closes #817
2025-09-10 20:39:35 +08:00
Zhongjie Duan
090074e395 Merge pull request #899 from modelscope/version_update_1.1.8
Update setup.py
2025-09-09 18:43:03 +08:00
Zhongjie Duan
2dcdeefca8 Update setup.py 2025-09-09 18:42:39 +08:00
Zhongjie Duan
452a6ca5cf Merge pull request #898 from modelscope/direct_distill
support direct distill
2025-09-09 16:16:32 +08:00
Artiprocher
d6cf20ef33 support direct distill 2025-09-09 16:12:31 +08:00
Zhongjie Duan
efdd6a59b6 Merge pull request #892 from modelscope/dev2-dzj
refine training framework
2025-09-04 15:53:52 +08:00
Artiprocher
42ec7b08eb bugfix 2025-09-04 15:45:39 +08:00
Artiprocher
d049fb6d1d bugfix 2025-09-04 15:44:37 +08:00
Artiprocher
144365b07d merge data process to training script 2025-09-04 15:18:56 +08:00
Artiprocher
cb8de6be1b move training code to base trainer 2025-09-03 12:03:49 +08:00
Zhongjie Duan
8c13362dcf Merge pull request #884 from modelscope/dev2-dzj
Unified Dataset & Splited Training
2025-09-03 09:50:23 +08:00
Zhongjie Duan
c13fd7e0ee Merge pull request #877 from mi804/wans2v_framepack
support s2v framepack
2025-09-02 16:54:37 +08:00
Artiprocher
958ebf1352 remove testing script 2025-09-02 16:44:36 +08:00
Artiprocher
b6da77e468 qwen-image splited training 2025-09-02 16:44:14 +08:00
Artiprocher
260e32217f unified dataset 2025-09-02 13:14:08 +08:00
mi804
5cee326f92 support s2v framepack 2025-09-01 16:48:46 +08:00
Zhongjie Duan
1d240994e7 Merge pull request #874 from mi804/wans2v_example
Wans2v example
2025-08-29 15:13:28 +08:00
mi804
a0bae07825 add wans2v example 2025-08-29 15:11:30 +08:00
ShunqiangBian
ff71720297 Create Wan2.2-S2V-14B.py
This commit introduces the core inference functionality for the Wan2.2-S2V-14B model.
2025-08-29 14:54:41 +08:00
Zhongjie Duan
dea85643e6 Merge pull request #872 from modelscope/dev2-dzj
remove some requirements & update Qwen-Image Quickstart
2025-08-29 14:22:35 +08:00
Artiprocher
6a46f32afe update Qwen-Image Quickstart 2025-08-29 14:09:49 +08:00
Artiprocher
4641d0f360 remove some requirements 2025-08-29 14:04:58 +08:00
Zhongjie Duan
826bab5962 Merge pull request #859 from krahets/main
Fix batch decoding for Wan-Video-VAE
2025-08-29 12:45:49 +08:00
Zhongjie Duan
5b6d112c15 Merge pull request #843 from wuutiing/main
add read gifs as video support
2025-08-29 12:36:24 +08:00
Zhongjie Duan
febdaf6067 Merge pull request #856 from lzws/main
add wan2.2-fun training scripts
2025-08-29 12:34:55 +08:00
Zhongjie Duan
0a78bb9d38 Merge pull request #864 from modelscope/wans2v
Support Wan-S2V
2025-08-28 10:21:12 +08:00
mi804
9cea10cc69 minor fix 2025-08-28 10:13:52 +08:00
mi804
caa17da5b9 wans2v readme 2025-08-27 20:05:44 +08:00
mi804
fdeb363fa2 wans2v usp 2025-08-27 19:50:33 +08:00
mi804
4147473c81 wans2v refactor 2025-08-27 16:18:22 +08:00
mi804
8a0bd7c377 wans2v lowvram 2025-08-27 13:05:53 +08:00
mi804
b541b9bed2 wans2v inference 2025-08-27 11:51:56 +08:00
Yudong Jin
419d47c195 Remove unnecessary newline in encode method 2025-08-27 02:24:29 +08:00
Yudong Jin
ac2e859960 Fix batch decoding for Wan VAE. 2025-08-27 02:24:00 +08:00
Zhongjie Duan
6663dca015 Merge pull request #857 from modelscope/Artiprocher-patch-1
bugfix
2025-08-26 17:23:32 +08:00
lzws
86e509ad31 update wan2.2-fun training scripts 2025-08-26 17:22:41 +08:00
Zhongjie Duan
8fcfa1dd2d bugfix 2025-08-26 17:22:25 +08:00
lzws
2b7a2548b4 update wan2.2-fun model overview in readme 2025-08-26 17:11:48 +08:00
lzws
f0916e6bae update wan2.2-fun training scripts 2025-08-26 16:37:47 +08:00
lzws
822e80ec2f Merge branch 'modelscope:main' into main 2025-08-26 15:08:43 +08:00
Zhongjie Duan
04e39f7de5 Merge pull request #853 from modelscope/qwen-image-fp8-lora
support qwen-image fp8 lora training
2025-08-25 20:33:36 +08:00
Artiprocher
ce0b948655 support qwen-image fp8 lora training 2025-08-25 20:32:36 +08:00
lzws
c795e35142 add wan2.2-fun-A14B inp, control and control-camera (#839)
* update wan2.2-fun

* update wan2.2-fun

* update wan2.2-fun

* add examples

* update wan2.2-fun

* update wan2.2-fun

* Rename Wan2.2-Fun-A14B-Inp.py to Wan2.2-Fun-A14B-InP.py

---------

Co-authored-by: lzw478614@alibaba-inc.com <lzw478614@alibaba-inc.com>
2025-08-22 14:20:31 +08:00
lzws
f7c01f1367 Merge branch 'modelscope:main' into main 2025-08-22 14:18:36 +08:00
lzws
cb49f0283f Rename Wan2.2-Fun-A14B-Inp.py to Wan2.2-Fun-A14B-InP.py 2025-08-22 14:18:16 +08:00
Zhongjie Duan
6a45815b23 Merge pull request #844 from mi804/blockwisecontrolnet_fix
fix blockwise controlnet training by avoid inplace
2025-08-22 11:47:21 +08:00
mi804
8dae8d7bc8 fix blockwise controlnet training by avoid inplace 2025-08-22 11:28:57 +08:00
twu
f6418004bb as numframe limit is impled in reader, add that 2025-08-22 03:00:35 +00:00
lzw478614@alibaba-inc.com
c4b97cd591 update wan2.2-fun 2025-08-22 09:38:19 +08:00
lzws
b6d1ff01e0 Merge branch 'modelscope:main' into main 2025-08-21 20:53:19 +08:00
lzw478614@alibaba-inc.com
0d81626fe7 update wan2.2-fun 2025-08-21 20:08:49 +08:00
twu
e3f47a799b make it more efficient to locate where to sample the frame 2025-08-21 09:13:45 +00:00
twu
e014cad820 add read gifs as video support 2025-08-21 09:01:48 +00:00
Zhongjie Duan
89bf3ce5cf Merge pull request #841 from modelscope/qwen-image-lora-hotload
support qwen-image lora hotload
2025-08-21 15:14:46 +08:00
Zhongjie Duan
3ebe118f23 Merge pull request #840 from modelscope/qwen-image-incontext
Qwen image incontext
2025-08-21 15:11:42 +08:00
Artiprocher
7f719cefe6 refine code 2025-08-21 14:25:17 +08:00
lzw478614@alibaba-inc.com
46bd05b54d add examples 2025-08-21 13:41:07 +08:00
Artiprocher
613dafbd09 rename model 2025-08-21 13:35:47 +08:00
lzw478614@alibaba-inc.com
952933eeb1 update wan2.2-fun 2025-08-21 13:34:09 +08:00
lzw478614@alibaba-inc.com
c0172e70b1 update wan2.2-fun 2025-08-21 12:59:41 +08:00
Artiprocher
6ab426e641 support qwen-image lora hotload 2025-08-21 10:12:52 +08:00
mi804
d0467a7e8d fix controlnet annotator 2025-08-20 23:28:40 +08:00
mi804
36838a05ee minor fix 2025-08-20 22:50:18 +08:00
mi804
5e6f9f89f1 support eligenv2 and context_control 2025-08-20 22:48:34 +08:00
lzw478614@alibaba-inc.com
2dad9a319c update wan2.2-fun 2025-08-20 20:17:41 +08:00
Zhongjie Duan
9ec0652339 Merge pull request #829 from mi804/qwen-image-edit-autoresize
support edit_image_auto_resize
2025-08-20 13:40:02 +08:00
mi804
7e348083ae minor fix 2025-08-20 12:42:11 +08:00
mi804
29b12b2f4e support edit_image_auto_resize 2025-08-20 12:36:26 +08:00
Zhongjie Duan
b3f57ed920 Merge pull request #826 from mi804/qwen-image-edit-lowvram
fix qwen-image-edit-lowvram
2025-08-20 11:39:56 +08:00
mi804
c9fea729d8 fix qwen-image-edit-lowvram 2025-08-20 10:31:43 +08:00
Hong Zhang
9d0683df25 Merge pull request #824 from mi804/low_res_fix
support qwen-image-edit lowres fix
2025-08-20 10:24:11 +08:00
mi804
838b8109b1 support qwen-image-edit lowres fix 2025-08-19 20:15:36 +08:00
Zhongjie Duan
3a9621f6da Merge pull request #815 from mi804/lora_checkpoint
fix bug
2025-08-19 12:43:04 +08:00
mi804
fff2c89360 fix bug 2025-08-19 12:38:33 +08:00
Zhongjie Duan
ce61bef2b0 Merge pull request #814 from mi804/qwen-image-edit
Qwen image edit
2025-08-19 09:33:39 +08:00
mi804
123f6dbadb update lora and full train 2025-08-18 19:09:19 +08:00
Hong Zhang
f9ce261a0e Merge branch 'main' into qwen-image-edit 2025-08-18 18:56:26 +08:00
mi804
d93de98a21 fix qwen_rope 2025-08-18 17:31:18 +08:00
mi804
ad1da43476 fix validate full 2025-08-18 16:17:40 +08:00
mi804
398b1dbd7a fix inference 2025-08-18 16:10:01 +08:00
mi804
9f6922bba9 support qwen-image-edit 2025-08-18 16:07:45 +08:00
Zhongjie Duan
f11a91e610 Merge pull request #813 from modelscope/qwen-image-inpaint
Qwen image inpaint
2025-08-18 15:26:06 +08:00
Artiprocher
7ed09bb78d add inpaint mask in qwen-image 2025-08-18 15:16:38 +08:00
mi804
ac931856d5 minor fix 2025-08-16 17:24:37 +08:00
mi804
2d09318236 support qwen-image inpaint controlnet 2025-08-16 17:12:29 +08:00
Zhongjie Duan
7dc49bd036 Merge pull request #806 from mi804/wan2.2_boundary
fix training boundary for wan2.2 A14B
2025-08-15 18:43:37 +08:00
Zhongjie Duan
4d16bdf853 Merge pull request #807 from modelscope/qwen-image-blockwise-controlnet-train
support qwen-image blockwise controlnet training
2025-08-15 18:42:29 +08:00
Artiprocher
01a1f48f70 support qwen-image blockwise controlnet training 2025-08-15 18:41:01 +08:00
mi804
6a9d875d65 fix training boundary for wan2.2 A14B 2025-08-15 17:54:52 +08:00
Zhongjie Duan
f1c96d31b4 Merge pull request #804 from mi804/qwen-image-dataset
qwen-image-dataset
2025-08-15 14:39:44 +08:00
mi804
aafcca8d77 add announcements 2025-08-15 14:38:03 +08:00
mi804
bf369cad4d qwen-image-dataset 2025-08-15 14:28:55 +08:00
Zhongjie Duan
024fdad76d Merge pull request #801 from modelscope/qwen-image-lowvram
add low vram examples
2025-08-15 11:34:24 +08:00
Artiprocher
e1c2eda5f5 add low vram examples 2025-08-15 11:31:57 +08:00
Zhongjie Duan
0b574cc0c2 Merge pull request #794 from mi804/training_optimize
lora_checkpoint & weight_decay
2025-08-14 14:20:03 +08:00
mi804
3212c83398 minor fix 2025-08-14 13:59:04 +08:00
mi804
49f9a11eb3 lora_checkpoint & weight_decay & qwen_image_controlnet_train 2025-08-14 13:50:04 +08:00
Zhongjie Duan
fa36739f01 Merge pull request #791 from mi804/qwen-image-longprompt
fix long prompt for qwen-image
2025-08-14 09:59:42 +08:00
Zhongjie Duan
42e9764b60 Merge pull request #790 from mi804/qwen-image-blockwise-controlnet
support qwen-image blockwise-controlnet depth
2025-08-13 20:35:10 +08:00
mi804
f7f5c07570 fix long prompt for qwen-image 2025-08-13 17:23:00 +08:00
mi804
ec1a936624 update date 2025-08-13 13:38:19 +08:00
mi804
6e6136586c support controlnet depth 2025-08-13 13:36:26 +08:00
Zhongjie Duan
34766863f8 Merge pull request #787 from modelscope/qwen-image-controlnet-update-1
support qwen-image controlnet
2025-08-12 20:37:05 +08:00
Artiprocher
1d76d5e828 support qwen-image controlnet 2025-08-12 17:17:08 +08:00
Zhongjie Duan
250540a398 Merge pull request #780 from modelscope/qwen-image-distill-lora
Qwen image distill lora
2025-08-11 15:05:19 +08:00
Artiprocher
46f3c38c37 Qwen-Image-Distill-LoRA 2025-08-11 15:04:21 +08:00
Artiprocher
9a8982efb1 Qwen-Image-Distill-LoRA 2025-08-11 15:01:21 +08:00
Zhongjie Duan
3c815cce4b Merge pull request #779 from modelscope/qwen-image-forward-fix
qwen-image dit original forward fix
2025-08-11 14:42:02 +08:00
Artiprocher
39d199c8bb qwen-image dit original forward fix 2025-08-11 14:41:32 +08:00
Zhongjie Duan
f5506d1e13 Merge pull request #769 from modelscope/qwen-image-lora-format
remove lora format alignment
2025-08-08 19:06:03 +08:00
Artiprocher
166a8734fe remove lora format alignment 2025-08-08 19:05:06 +08:00
Zhongjie Duan
b2273ec568 Merge pull request #768 from modelscope/lora-fix
lora-fix
2025-08-08 18:55:57 +08:00
Artiprocher
89c4e3bdb6 lora-fix 2025-08-08 18:55:13 +08:00
Zhongjie Duan
051ebf3439 fix wan2.2 5B usp (#763) 2025-08-08 16:26:04 +08:00
mi804
7cfadc2ca8 fix wan2.2 5B usp 2025-08-07 23:06:52 +08:00
Zhongjie Duan
32cf5d32ce Qwen-Image FP8 (#761)
* support qwen-image-fp8

* refine README

* bugfix

* bugfix
2025-08-07 16:56:02 +08:00
Zhongjie Duan
4f7c3b6a1e Merge pull request #755 from mi804/qwen-image-eligen
Qwen-Image-EliGen
2025-08-07 14:04:44 +08:00
mi804
57128dc89f update readme for qwen-image-eligen 2025-08-07 13:42:47 +08:00
Zhongjie Duan
d20680baae Merge pull request #756 from mi804/flux-eligen
fix flux-eligen bug
2025-08-06 20:09:00 +08:00
mi804
970403f78e fix flux-eligen bug 2025-08-06 20:07:21 +08:00
mi804
bee2a969e5 minor fix readme and path 2025-08-06 17:48:44 +08:00
mi804
2803ffcb38 minor fix 2025-08-06 17:39:00 +08:00
mi804
d3224e1fdc update qwen-image-eligen readme 2025-08-06 17:36:28 +08:00
mi804
3c2f85606f update model 2025-08-06 17:23:05 +08:00
mi804
1f25ad416b Merge branch 'main' into qwen-image-eligen 2025-08-06 15:57:13 +08:00
Zhongjie Duan
d0b9b25db7 Merge pull request #749 from mi804/training_args
support num_workers,save_steps,find_unused_parameters
2025-08-06 15:54:04 +08:00
mi804
ef09db69cd refactor model_logger 2025-08-06 15:47:35 +08:00
Zhongjie Duan
84ede171fd Merge pull request #752 from modelscope/qwen-image-lora-fromat
remove default in qwen-image lora
2025-08-06 15:42:03 +08:00
Artiprocher
6f4e38276e remove default in qwen-image lora 2025-08-06 15:41:22 +08:00
mi804
a3b67436a6 eligen ui 2025-08-06 15:04:38 +08:00
Zhongjie Duan
829ca3414b fmt fixes in wan_video_dit.py
fmt fixes in wan_video_dit.py
2025-08-06 14:39:25 +08:00
mi804
3915bc3ee6 minor fix 2025-08-06 10:58:53 +08:00
mi804
4299c999b5 restore readme 2025-08-06 10:56:46 +08:00
mi804
6bae70eee0 support num_workers,save_steps,find_unused_parameters 2025-08-06 10:52:59 +08:00
mi804
6452edb738 qwen_image eligen 2025-08-05 20:41:03 +08:00
Zhongjie Duan
bc739c78cd Merge pull request #746 from modelscope/qwen-image-distill
Qwen image distill
2025-08-05 19:21:37 +08:00
Artiprocher
2feaeb1a64 update readme 2025-08-05 19:20:37 +08:00
Artiprocher
09360cf4f5 qwen-image-distill 2025-08-05 19:18:43 +08:00
Yudong Jin
26461c1963 Update diffsynth/models/wan_video_dit.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-08-04 23:52:48 +08:00
Yudong Jin
0412fc7232 fmt fixes in wan_video_dit.py 2025-08-04 23:40:18 +08:00
Zhongjie Duan
8d2f6ad32e Merge pull request #735 from modelscope/qwen-image
qwen-image
2025-08-04 20:40:32 +08:00
Artiprocher
1625894694 bugfix 2025-08-04 20:35:44 +08:00
Artiprocher
c35f2d8bda qwen-image 2025-08-04 20:24:13 +08:00
Zhongjie Duan
a8ee7ec9ef Merge pull request #725 from mi804/imagedataset_jsonl
support jsonl dataset
2025-08-04 14:39:01 +08:00
Zhongjie Duan
46d390cf8a Merge pull request #727 from mi804/flux.1_kera_dev
support flux.1-kera-dev
2025-08-01 17:26:32 +08:00
mi804
6b8e3880ff fix lowvram inference 2025-08-01 17:25:50 +08:00
mi804
c1c3be2420 fix readmezh 2025-08-01 17:21:48 +08:00
mi804
b2554db100 fix krea typo 2025-08-01 17:13:45 +08:00
mi804
b63f81c6e3 support flux.1-kera-dev 2025-08-01 11:26:39 +08:00
mi804
cb2caa3a36 support jsonl 2025-07-31 16:24:58 +08:00
Zhongjie Duan
f0ea049faa Merge pull request #720 from mi804/wanvideo_seq_usp
Wanvideo seq usp
2025-07-31 10:04:57 +08:00
mi804
0954e8a017 fix vace usp 2025-07-30 19:40:08 +08:00
mi804
e4178e2501 fix usp dit_forward 2025-07-30 19:21:21 +08:00
mi804
0b860abf1b support arbitrary seq len 2025-07-30 19:07:16 +08:00
mi804
8c558b3526 fix modelconfig 2025-07-30 18:44:17 +08:00
mi804
aef982a53c Merge branch 'main' into wanvideo_seq_usp 2025-07-30 16:44:44 +08:00
Zhongjie Duan
db124fa6bc Merge pull request #715 from modelscope/nexusgen-eligen
NexusGen and EliGen
2025-07-29 20:28:07 +08:00
Artiprocher
2ed3860085 refine code 2025-07-29 20:10:08 +08:00
Artiprocher
87ab7d020b refine code 2025-07-29 20:02:34 +08:00
Artiprocher
03c8fd5e61 refine code 2025-07-29 18:49:18 +08:00
Artiprocher
9c51623fc2 refine code 2025-07-29 18:47:16 +08:00
Zhongjie Duan
8ec545d70c Merge pull request #713 from modelscope/bugfix3
update README
2025-07-29 17:06:28 +08:00
Artiprocher
79fa8607dc update README 2025-07-29 17:05:41 +08:00
mi804
7df48fc2b5 remove debug out 2025-07-29 13:33:14 +08:00
mi804
8ef91b3672 support training for eligen and nexusgen 2025-07-29 13:28:42 +08:00
Zhongjie Duan
2860470b4e Merge pull request #709 from modelscope/bugfix2
Bugfix2
2025-07-29 11:17:18 +08:00
Artiprocher
c125728ce0 bug fix 2025-07-29 11:16:50 +08:00
Zhongjie Duan
63eaa9e7ea Merge pull request #708 from modelscope/bug-fix
bug fix
2025-07-29 10:17:33 +08:00
Artiprocher
158567ca20 bug fix 2025-07-29 10:16:40 +08:00
Zhongjie Duan
de4e2703ca Merge pull request #706 from modelscope/wan2.2-patch
Wan2.2
2025-07-28 19:52:30 +08:00
Artiprocher
9e683bfe25 fix typo 2025-07-28 18:30:04 +08:00
Artiprocher
0befa05014 Merge branch 'wan2.2-patch' of https://github.com/modelscope/DiffSynth-Studio into wan2.2-patch 2025-07-28 18:27:20 +08:00
Artiprocher
283f35447a refine readme 2025-07-28 18:25:43 +08:00
Zhongjie Duan
c35414a652 Merge pull request #705 from modelscope/wan2.2
fix wan2.2 vae
2025-07-28 17:04:40 +08:00
Artiprocher
68aafab09e update readme 2025-07-28 17:02:30 +08:00
mi804
29663b25a6 fix wan2.2 vae 2025-07-28 16:49:28 +08:00
mi804
2861ec4d9f tmp commit for nexus-gen edit 2025-07-28 16:18:38 +08:00
Artiprocher
729c512c66 bugfix 2025-07-28 15:18:47 +08:00
Zhongjie Duan
2af3a6f6a2 Merge pull request #704 from modelscope/wan2.2
Wan2.2
2025-07-28 15:06:01 +08:00
mi804
05dba91f79 fix wan2.2 5B 2025-07-28 13:38:01 +08:00
mi804
b8f05bb342 tmp commit 2025-07-28 11:09:33 +08:00
Artiprocher
5f68727ad3 refine code 2025-07-28 11:00:54 +08:00
mi804
bba44173d2 minor fix 2025-07-25 17:24:42 +08:00
mi804
9015d08927 support wan2.2 A14B I2V&T2V 2025-07-25 17:09:53 +08:00
Zhongjie Duan
1dfa32f0ae Merge pull request #702 from modelscope/lora-rearrange
Lora rearrange
2025-07-24 19:12:09 +08:00
Artiprocher
c98e31fee3 update README 2025-07-24 19:10:06 +08:00
Artiprocher
f3d2470e84 update README 2025-07-24 19:02:08 +08:00
Artiprocher
4ad6bd4e23 rearrange lora loading modules 2025-07-24 18:56:25 +08:00
mi804
3aed244c6f update variable 2025-07-23 11:20:06 +08:00
Zhongjie Duan
783c435d88 Merge pull request #701 from modelscope/readme-refine
update readme
2025-07-23 11:14:25 +08:00
Artiprocher
cd1ba7281b update readme 2025-07-23 11:13:38 +08:00
Zhongjie Duan
970ff12ff5 Merge pull request #700 from modelscope/readme-refine
Readme refine
2025-07-22 20:48:47 +08:00
Artiprocher
2827b60330 update readme 2025-07-22 20:48:19 +08:00
Artiprocher
b3df7e5e21 update readme 2025-07-22 20:43:58 +08:00
Artiprocher
c18b5a0c71 update readme 2025-07-22 20:31:44 +08:00
Artiprocher
b9f7d08219 update readme 2025-07-22 20:30:34 +08:00
Artiprocher
11ea986e67 update readme 2025-07-22 20:28:29 +08:00
Artiprocher
b06066f25b update readme 2025-07-22 20:26:41 +08:00
Artiprocher
0b3400bca3 update readme 2025-07-22 20:22:48 +08:00
Artiprocher
0d509241c0 update readme 2025-07-22 20:20:56 +08:00
Artiprocher
ebeda32215 update readme 2025-07-22 20:02:21 +08:00
Artiprocher
ff95c56884 refine readme 2025-07-22 13:22:47 +08:00
Zhongjie Duan
2871535f3b Merge pull request #699 from modelscope/AttrCtrl
Support AttriCtrl
2025-07-21 19:18:18 +08:00
Artiprocher
e3c5d2540b support value controller training 2025-07-21 19:16:30 +08:00
Artiprocher
22705a44b4 update value controller 2025-07-21 16:30:06 +08:00
Zhongjie Duan
43a8d9768c Merge pull request #697 from mi804/nexus-genv2
add nexus-gen news
2025-07-21 15:09:05 +08:00
mi804
dbee3a1ae0 add nexus-gen news 2025-07-21 15:07:13 +08:00
mi804
f1f00c4255 support wan2.2 5B I2V 2025-07-21 14:47:58 +08:00
ziyannchen
c05b1a2fd0 fix a bug in sliding window inference 2025-07-20 11:13:20 +00:00
mi804
55951590f5 support wan2.2 5B T2V 2025-07-20 18:13:50 +08:00
Zhongjie Duan
1384de0353 Support LoRA encoder (#695)
* lora_encoder
2025-07-19 20:44:03 +08:00
ziyannchen
05c6b49b90 fix a bug in sliding_window inference 2025-07-16 10:30:33 +00:00
Zhongjie Duan
d19fcc8c04 Merge pull request #688 from modelscope/flux_vram_management
flux series vram management
2025-07-15 20:12:08 +08:00
Artiprocher
af6b1d4246 flux series vram management 2025-07-15 20:11:02 +08:00
Zhongjie Duan
cbd10fb27d Merge pull request #684 from modelscope/value_controller
support flux value controller
2025-07-15 10:11:08 +08:00
Zhongjie Duan
836fa5c957 Merge pull request #685 from lzws/main
update flux lora convert state dict
2025-07-14 14:58:07 +08:00
lzw478614@alibaba-inc.com
dc066aca2d update flux lora convert state dict 2025-07-14 14:08:22 +08:00
Zhongjie Duan
44f6ffbf56 Merge pull request #673 from lzws/main
support other lora format
2025-07-14 13:51:47 +08:00
Artiprocher
0a24d0819f support flux value controller 2025-07-14 13:37:55 +08:00
lzw478614@alibaba-inc.com
f0106cd48c support other lora forma 2025-07-09 14:01:49 +08:00
lzw478614@alibaba-inc.com
dee4075380 support other lora format 2025-07-09 13:59:43 +08:00
Zhongjie Duan
a692389df0 Merge pull request #670 from modelscope/flux-any-training
support flux any training
2025-07-08 21:45:02 +08:00
Artiprocher
629e9be4ce support flux any training 2025-07-08 19:55:27 +08:00
Yingda Chen
3a3d9010b8 Update README.md 2025-07-08 17:24:39 +08:00
Yingda Chen
a25334b352 Add files via upload 2025-07-08 17:15:21 +08:00
handoku
00279a8375 fea : enable wan video usp for arbitrary seq len 2025-07-08 16:43:43 +08:00
Zhongjie Duan
89397c755a Merge pull request #667 from modelscope/lora_merge
Lora merge
2025-07-07 13:30:34 +08:00
lzws
77676b5cea Update FLUX.1-dev-LoRAFusion.py 2025-07-07 10:54:49 +08:00
Zhongjie Duan
0f4b08daa3 Merge pull request #661 from longredzhong/main
fix wan vace load mask video
2025-07-04 11:14:38 +08:00
longredzhong
63b2c51e11 fix wan vace load mask video 2025-07-04 10:22:34 +08:00
Artiprocher
8a9dbbd3ba support lora fusion 2025-07-03 18:49:46 +08:00
Zhongjie Duan
22d28665fe Merge pull request #657 from modelscope/dev-dzj
support json dataset
2025-07-02 20:08:13 +08:00
Artiprocher
1363a0559f support json dataset 2025-07-02 20:07:16 +08:00
lzw478614@alibaba-inc.com
9cb887015b lora hotload and merge 2025-07-02 13:32:24 +08:00
Zhongjie Duan
789dade026 Merge pull request #655 from modelscope/dev-dzj
refine wan readme
2025-07-02 11:37:18 +08:00
Artiprocher
9bb51fe879 refine wan readme 2025-07-02 11:36:41 +08:00
Zhongjie Duan
d9c812818d Merge pull request #653 from mi804/main
fix step1xedit
2025-07-01 17:16:41 +08:00
mi804
c8e9a96196 fix step1xedit 2025-07-01 17:12:53 +08:00
Zhongjie Duan
6143af4654 Merge pull request #651 from mi804/infiniteyou_controlnet_replace
infiniteyou_controlnet outof pipeline
2025-07-01 13:39:47 +08:00
Zhongjie Duan
9458e382b0 Merge pull request #652 from modelscope/flux-refactor
refine readme
2025-07-01 11:34:00 +08:00
Artiprocher
4f2d9226cf refine readme 2025-07-01 11:33:04 +08:00
mi804
f688a469b1 infiniteyou_controlnet outof pipeline 2025-07-01 11:10:46 +08:00
Zhongjie Duan
c8ea3b3356 Merge pull request #649 from modelscope/flux-refactor
refine readme
2025-06-30 11:46:16 +08:00
Artiprocher
6e9472b470 refine readme 2025-06-30 11:45:40 +08:00
Zhongjie Duan
a5c03c5272 Merge pull request #648 from modelscope/flux-refactor
refine readme
2025-06-30 11:44:47 +08:00
Artiprocher
8068ac2592 refine readme 2025-06-30 11:43:59 +08:00
Zhongjie Duan
5f80e7ac5e Merge pull request #647 from modelscope/flux-refactor
kontext training
2025-06-30 11:09:22 +08:00
Artiprocher
157e0be49d kontext training 2025-06-30 11:00:10 +08:00
Zhongjie Duan
3dbe271aab Merge pull request #646 from modelscope/flux-refactor
Flux refactor
2025-06-29 18:04:05 +08:00
Artiprocher
44e2eecdf1 flux-kontext 2025-06-29 15:59:04 +08:00
Artiprocher
8c226e83a6 flux-kontext 2025-06-29 15:51:45 +08:00
Artiprocher
009f26bb40 kontext 2025-06-27 18:38:40 +08:00
Artiprocher
fcf2fbc07f flux-refactor 2025-06-27 10:20:11 +08:00
Artiprocher
b603acd36a refine examples 2025-06-25 13:38:21 +08:00
Artiprocher
6c8bb6438b infiniteyou 2025-06-25 10:33:11 +08:00
Artiprocher
8072d3839d refine examples 2025-06-24 19:17:54 +08:00
Artiprocher
c8ad643374 refine examples 2025-06-24 19:17:43 +08:00
Zhongjie Duan
31f9df5e62 Merge pull request #567 from emmanuel-ferdman/main
Migrate to modern Python Logger API
2025-06-24 15:32:14 +08:00
Zhongjie Duan
e2f415524a Merge pull request #587 from ernestchu/patch-1
Fix typo
2025-06-24 15:23:19 +08:00
Zhongjie Duan
3eb7e7530e Merge pull request #632 from lzws/flux-refactor
step1x, teacache, flex refactor
2025-06-24 15:19:54 +08:00
Zhongjie Duan
916aa54595 Merge branch 'flux-refactor' into flux-refactor 2025-06-24 15:19:42 +08:00
Zhongjie Duan
6ddbd43f7b Merge pull request #634 from modelscope/bugfix
fix videodataset to load images
2025-06-24 11:42:14 +08:00
Artiprocher
a37a83ecc3 fix videodataset to load images 2025-06-24 11:38:43 +08:00
Zhongjie Duan
f2a0d0c85f Merge pull request #633 from modelscope/bugfix
fix i2v resolution
2025-06-24 10:59:31 +08:00
Artiprocher
93194f44e8 fix i2v resolution 2025-06-24 10:56:52 +08:00
Artiprocher
c4e5033532 flux controlnet 2025-06-23 21:01:53 +08:00
lzw478614@alibaba-inc.com
cc6cd26733 step1x, teacache, flex refactor 2025-06-23 17:06:00 +08:00
Zhongjie Duan
1113d305d1 Merge pull request #626 from mi804/flux-refactor
Flux refactor
2025-06-23 10:20:40 +08:00
mi804
6d5f8b7423 flux_eligen_refactor 2025-06-20 16:53:41 +08:00
mi804
1b3c204d20 flux_ipadapter_refactor 2025-06-20 14:49:09 +08:00
Artiprocher
1788d50f0a flux-refactor 2025-06-19 15:04:30 +08:00
Artiprocher
e7a21dbf0b flux-refactor 2025-06-19 14:53:11 +08:00
Zhongjie Duan
3b3e1e4d44 Merge pull request #623 from modelscope/usp
Usp
2025-06-19 10:15:39 +08:00
Artiprocher
24426e3a32 update README_zh 2025-06-19 10:06:55 +08:00
Artiprocher
31369bab15 update import 2025-06-19 10:04:24 +08:00
mi804
551721658b fix bug for usp with refimage 2025-06-16 19:38:45 +08:00
mi804
46f052375f fix vace usp 2025-06-16 18:54:29 +08:00
Zhongjie Duan
c2d35a2157 update wan training (#614)
update wan training
2025-06-16 15:48:35 +08:00
mi804
4c052e42bc fix usp download 2025-06-16 15:43:39 +08:00
Zhongjie Duan
a88613555d Merge pull request #612 from Yunnglin/update/eval_news
update readme for eval
2025-06-16 14:06:52 +08:00
Zhongjie Duan
c164519ef1 vram management support torch<2.6.0 (#613)
support torch<2.6.0
2025-06-16 13:08:29 +08:00
Yunnglin
afff5ffb21 update readme 2025-06-16 11:08:53 +08:00
Yunnglin
a8481fd5e1 update readme 2025-06-16 11:00:53 +08:00
Zhongjie Duan
8584e50309 Merge pull request #611 from modelscope/refactor
fix model id
2025-06-16 10:58:14 +08:00
Artiprocher
9f3e02f167 fix model id 2025-06-16 10:57:33 +08:00
Zhongjie Duan
7ad9b9aecc Merge pull request #609 from modelscope/refactor
refine readme
2025-06-13 14:14:14 +08:00
Artiprocher
b6a111d3a2 refine readme 2025-06-13 14:13:38 +08:00
Zhongjie Duan
bd6f2695a9 Merge pull request #608 from modelscope/refactor
Refactor
2025-06-13 14:02:49 +08:00
Artiprocher
6eecc9d442 refine readme 2025-06-13 14:02:20 +08:00
Artiprocher
35269783d7 refine readme 2025-06-13 14:00:58 +08:00
Zhongjie Duan
9534a78167 Merge pull request #607 from modelscope/refactor
wan-refactor
2025-06-13 13:49:00 +08:00
Artiprocher
830b1b7202 wan-refactor 2025-06-13 13:46:17 +08:00
Zhongjie Duan
436a91e0c9 Merge pull request #602 from modelscope/revert-601-wan-refactor
Revert "Wan refactor"
2025-06-11 17:30:06 +08:00
Zhongjie Duan
40760ab88b Revert "Wan refactor" 2025-06-11 17:29:27 +08:00
CD22104
8badd63a2d Merge pull request #601 from CD22104/wan-refactor
Wan refactor
2025-06-11 17:26:58 +08:00
CD22104
b1afff1728 camera 2025-06-11 17:24:09 +08:00
Artiprocher
6e977e1181 refine wan doc 2025-06-06 15:19:09 +08:00
Artiprocher
62f6ca2b8a new wan trainer 2025-06-06 14:58:41 +08:00
Ernie Chu
4e00c109e3 Fix typo
Change
Only `num_frames % 4 != 1` is acceptable
to
Only `num_frames % 4 == 1` is acceptable
2025-05-27 21:20:38 -04:00
Artiprocher
8f10a9c353 training script 2025-05-19 19:02:52 +08:00
Emmanuel Ferdman
a3a35acc7e Migrate to modern Python Logger API
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-05-12 14:09:26 -07:00
Artiprocher
675eefa07e training framework 2025-05-12 17:48:28 +08:00
Artiprocher
dbef6122e9 ... 2025-05-05 23:23:06 +08:00
Artiprocher
d150bcf622 ... 2025-05-05 13:01:45 +08:00
Artiprocher
451aab0116 refactor 2025-05-04 15:42:11 +08:00
Artiprocher
3edf3583b1 wan-fun-v1.1 reference control 2025-04-30 11:38:17 +08:00
Zhongjie Duan
ef2a7abad4 Step1x vram (#556)
* support step1x vram management
2025-04-28 10:13:20 +08:00
Zhongjie Duan
32f630ff5f Merge pull request #555 from modelscope/step1x
support step1x
2025-04-27 20:40:43 +08:00
Artiprocher
109a0a0d49 support step1x 2025-04-27 20:37:43 +08:00
Zhongjie Duan
4f01b37a2a Merge pull request #553 from modelscope/flex
Flex
2025-04-25 12:24:18 +08:00
Artiprocher
cc6306136c flex full support 2025-04-25 12:23:29 +08:00
Artiprocher
419ace37f3 flex full support 2025-04-25 11:32:13 +08:00
Artiprocher
ccf24c363f flex control 2025-04-24 19:18:54 +08:00
Artiprocher
b7a1ac6671 flex t2i 2025-04-24 14:51:40 +08:00
Zhongjie Duan
e54c0a8468 Merge pull request #548 from CD22104/main
liblib-controlnet
2025-04-22 14:54:16 +08:00
xuyixuan.xyx
5f4cb32255 liblib-controlnet 2025-04-22 13:45:49 +08:00
Zhongjie Duan
7b6cf39618 Merge pull request #544 from modelscope/Artiprocher-patch-1
Update train_wan_t2v.py
2025-04-17 15:39:44 +08:00
Zhongjie Duan
bf81de0c88 Update train_wan_t2v.py 2025-04-17 15:37:30 +08:00
Zhongjie Duan
b36cad6929 Merge pull request #543 from modelscope/wan-flf2v
bugfix
2025-04-17 15:24:36 +08:00
Zhongjie Duan
b161bd6dfd bugfix 2025-04-17 15:23:46 +08:00
Zhongjie Duan
538cfcbb77 Merge pull request #541 from modelscope/wan-flf2v
Wan flf2v
2025-04-17 14:51:08 +08:00
Artiprocher
a4105d2c0e support wan-flf2v 2025-04-17 14:48:55 +08:00
Artiprocher
553b341f5f support wan-flf2v 2025-04-17 14:47:55 +08:00
Zhongjie Duan
e9e24b8cf1 Merge pull request #537 from CD22104/main
issue523
2025-04-16 15:53:39 +08:00
CD22104
1b693d0028 issue523 2025-04-16 15:49:52 +08:00
Zhongjie Duan
a4c3c07229 Merge pull request #536 from modelscope/wan-vace-quant
support vace quant
2025-04-16 10:43:14 +08:00
Artiprocher
6b24748c80 support vace quant 2025-04-16 10:29:21 +08:00
Zhongjie Duan
8f2f8646eb Merge pull request #526 from mohui37/main
Update train_wan_t2v.py
2025-04-16 09:55:19 +08:00
Zhongjie Duan
e3ac438f5a Merge pull request #533 from modelscope/wan-vace
vace
2025-04-15 18:47:36 +08:00
Artiprocher
b731628112 vace 2025-04-15 17:52:25 +08:00
mohui37
0dc56d9dcc Update train_wan_t2v.py
在应用itv的管道处理数据时有bug,提交修复
2025-04-11 17:05:40 +08:00
Zhongjie Duan
b925b402e2 Merge pull request #522 from modelscope/Artiprocher-patch-1
Update README.md
2025-04-10 11:42:32 +08:00
Zhongjie Duan
61d9653536 Update README.md 2025-04-10 11:42:18 +08:00
Zhongjie Duan
53f01e72e6 Update setup.py 2025-04-09 15:38:17 +08:00
Zhongjie Duan
55e5e373dd Update publish.yaml 2025-04-09 15:37:46 +08:00
Zhongjie Duan
4a0921ada1 Update requirements.txt 2025-04-09 15:37:16 +08:00
Zhongjie Duan
5129d3dc52 Update setup.py 2025-04-09 15:34:02 +08:00
Zhongjie Duan
ee9bab80f2 Update requirements.txt 2025-04-09 15:33:21 +08:00
Zhongjie Duan
cd8884c9ef Update setup.py 2025-04-09 15:27:36 +08:00
Zhongjie Duan
46744362de Update requirements.txt 2025-04-09 15:26:13 +08:00
Zhongjie Duan
0f0cdc3afc Update setup.py 2025-04-09 15:15:18 +08:00
Zhongjie Duan
a33c63af87 Merge pull request #518 from modelscope/wan-fun
Wan fun
2025-04-08 19:25:12 +08:00
Artiprocher
3cc9764bc9 support more wan models 2025-04-08 19:22:53 +08:00
Artiprocher
f6c6e3c640 support more wan models 2025-04-08 17:19:54 +08:00
Artiprocher
60a9db706e support more wan models 2025-04-08 17:07:10 +08:00
lzw478614@alibaba-inc.com
a98700feb2 support wan-fun-inp generating 2025-04-06 22:55:42 +08:00
lzw478614@alibaba-inc.com
5418ca781e support load wan2.1-fun-inp-1.3B and 14B model 2025-04-03 16:37:59 +08:00
Zhongjie Duan
71eee780fb Merge pull request #511 from modelscope/version-update
Update setup.py
2025-04-02 16:35:01 +08:00
Zhongjie Duan
4864453e0a Update setup.py 2025-04-02 16:34:50 +08:00
Zhongjie Duan
c5a32f76c2 Merge pull request #509 from modelscope/wan-lora-converter
Update lora.py
2025-04-02 13:08:48 +08:00
Zhongjie Duan
c4ed3d3e4b Update lora.py 2025-04-02 13:08:16 +08:00
Zhongjie Duan
803ddcccc7 Merge pull request #505 from modelscope/infinityou
Infinityou
2025-03-31 20:21:10 +08:00
Artiprocher
4cd51fecf2 refine infinityou 2025-03-31 20:19:32 +08:00
Zhongjie Duan
3b0211a547 Merge pull request #499 from calmhawk/hotfix/tc_bug_with_usp
Fix TeaCache bug and optimize memory usage of WAN with USP feature
2025-03-31 16:24:03 +08:00
mi804
e88328d152 support infiniteyou 2025-03-31 14:29:15 +08:00
calmhawk
52896fa8dd Fix TeaCache bug with usp support integration and optimize memory usage by clearing attn cache 2025-03-30 01:13:34 +08:00
Zhongjie Duan
c7035ad911 Merge pull request #493 from modelscope/lzws-patch-1
Update wan_video.py
2025-03-26 19:48:33 +08:00
lzws
070811e517 Update wan_video.py
prompter.encode_prompt use pipe's deivce
2025-03-26 17:51:13 +08:00
Zhongjie Duan
7e010d88a5 Merge pull request #485 from modelscope/usp
support Unified Sequence Parallel
2025-03-25 19:28:42 +08:00
Artiprocher
4e43d4d461 fix usp dependency 2025-03-25 19:26:24 +08:00
Zhongjie Duan
d7efe7e539 Merge pull request #482 from modelscope/Artiprocher-patch-1
Update README.md
2025-03-25 16:44:48 +08:00
Zhongjie Duan
633f789c47 Update README.md 2025-03-25 16:44:05 +08:00
Zhongjie Duan
88607f404e Merge pull request #480 from mi804/wanx_tensor_parallel
update tensor parallel
2025-03-25 15:33:15 +08:00
mi804
6d405b669c update tensor parallel 2025-03-25 12:38:17 +08:00
ByteDance
d0fed6ba72 add usp for wanx 2025-03-25 11:51:37 +08:00
ByteDance
64eaa0d76a Merge branch 'usp' into xdit 2025-03-25 11:45:49 +08:00
Zhongjie Duan
3dc28f428f Merge pull request #465 from CD22104/main
cd0319-ImportError-libX11.so.6
2025-03-19 14:14:01 +08:00
xuyixuan.xyx
3c8a3fe2e1 cd0319 2025-03-19 14:00:42 +08:00
Zhongjie Duan
e28c246bcc Merge pull request #457 from modelscope/wan-tp
support wan tensor parallel (preview)
2025-03-17 19:53:17 +08:00
Artiprocher
04d03500ff support wan tensor parallel (preview) 2025-03-17 19:39:45 +08:00
Jinzhe Pan
54081bdcbb Merge pull request #1 from Eigensystem/fjr
fix some bugs
2025-03-17 17:07:07 +08:00
feifeibear
d8b250607a polish code 2025-03-17 09:04:51 +00:00
feifeibear
1e58e6ef82 fix some bugs 2025-03-17 09:00:52 +00:00
Jinzhe Pan
42cb7d96bb feat: sp for wan 2025-03-17 08:31:45 +00:00
Zhongjie Duan
39890f023f Merge pull request #448 from modelscope/wan-teacache
support teacache in wan
2025-03-14 18:21:20 +08:00
Artiprocher
e425753f79 support teacache in wan 2025-03-14 17:45:52 +08:00
Zhongjie Duan
ca40074d72 Merge pull request #447 from modelscope/lora
Lora
2025-03-14 15:34:22 +08:00
Artiprocher
1fd3d67379 improve lora loading efficiency 2025-03-14 15:15:37 +08:00
Artiprocher
3acd9c73be improve lora loading efficiency 2025-03-14 15:05:54 +08:00
Zhongjie Duan
32422b49ee Merge pull request #436 from mi804/hunyuanvideo_i2v
support hunyuanvideo-i2v
2025-03-13 19:38:11 +08:00
Furkan Gözükara
5c4d3185fb Merge branch 'modelscope:main' into main 2025-03-13 14:22:34 +03:00
Zhongjie Duan
762bcbee58 Merge pull request #444 from modelscope/wan-itv-train
Wan itv train
2025-03-13 15:40:51 +08:00
Zhongjie Duan
6b411ada16 Merge branch 'main' into wan-itv-train 2025-03-13 15:24:59 +08:00
Artiprocher
a25bd74d8b support wan i2v training 2025-03-13 15:14:10 +08:00
Furkan Gözükara
fb5fc09bad Made much much faster than before
enable debug to see every message
2025-03-13 02:30:42 +03:00
Furkan Gözükara
3fdba19e02 Fixes high RAM usage Wan 2.1
Fixes high RAM usage Wan 2.1
2025-03-12 15:49:57 +03:00
mi804
4bec2983a9 support hunyuanvideo_i2v 2025-03-11 16:20:09 +08:00
Zhongjie Duan
03ea27893f Merge pull request #431 from modelscope/wan-update
Wan update
2025-03-10 18:26:32 +08:00
Artiprocher
718b45f2af bugfix 2025-03-10 18:25:23 +08:00
Zhongjie Duan
63a79eeb2a Merge pull request #426 from Zeyi-Lin/main
Modify the swanlab `logdir` location
2025-03-10 17:59:17 +08:00
Artiprocher
e757013a14 vram optimization 2025-03-10 17:47:14 +08:00
Artiprocher
a05f647633 vram optimization 2025-03-10 17:11:11 +08:00
ZeYi Lin
7604be0301 output_path join swanlog 2025-03-08 13:57:08 +08:00
mi804
945b43492e load hunyuani2v model 2025-03-07 17:43:30 +08:00
Artiprocher
b548d7caf2 refactor wan dit 2025-03-07 16:35:26 +08:00
Zhongjie Duan
6e316fd825 Merge pull request #421 from modelscope/wan-update
support diffusers format wan and other lora
2025-03-06 17:41:36 +08:00
Artiprocher
84fb61aaaf support diffusers format wan and other lora 2025-03-06 17:40:21 +08:00
Zhongjie Duan
50a9946b57 Merge pull request #419 from modelscope/wan-update
wan image encoder to fp16
2025-03-06 16:28:55 +08:00
Artiprocher
384d1a8198 wan image encoder to fp16 2025-03-06 16:28:23 +08:00
Zhongjie Duan
a58c193d0c Merge pull request #412 from boopage/patch-1
Update train_wan_t2v.py - include .jpeg for image detection
2025-03-06 12:46:43 +08:00
boopage
34a5ef8c15 Update train_wan_t2v.py
Included .jpeg extension for image type detection, preventing an error trying to the read image as a video format
2025-03-05 11:13:11 +01:00
Zhongjie Duan
41e3e4e157 Merge pull request #410 from mi804/dreambooth_lora
support dreambooth lora
2025-03-05 11:48:00 +08:00
mi804
e576d71908 support dreambooth lora 2025-03-05 11:20:10 +08:00
Zhongjie Duan
906aadbf1b Merge pull request #404 from modelscope/wan-examples-update
update wan examples
2025-03-04 21:54:33 +08:00
Artiprocher
bf0bf2d5ba update wan examples 2025-03-04 21:54:04 +08:00
Zhongjie Duan
fe0fff1399 Merge pull request #401 from modelscope/flux-diffusers
support load flux from diffusers
2025-03-04 20:52:07 +08:00
Artiprocher
50fceb84d2 support load flux from diffusers 2025-03-04 20:38:25 +08:00
Zhongjie Duan
100da41034 Merge pull request #400 from mi804/eligen
update eligen model from huggingface
2025-03-04 20:11:18 +08:00
mi804
c382237833 update eligen from huggingface 2025-03-04 20:04:24 +08:00
Zhongjie Duan
98ac191750 Merge pull request #398 from modelscope/reduce_dependency
reduce dependency
2025-03-04 16:22:29 +08:00
Artiprocher
2f73dbe7a3 reduce dependency 2025-03-04 15:21:00 +08:00
Zhongjie Duan
a66203a391 Update setup.py 2025-03-04 10:08:16 +08:00
Zhongjie Duan
fab61f614b Merge pull request #394 from modelscope/wan-train-update
fix swanlab after test
2025-03-03 19:00:48 +08:00
Artiprocher
6b67a11ad6 fix swanlab after test 2025-03-03 18:59:34 +08:00
Zhongjie Duan
91f77d268c Merge pull request #393 from modelscope/wan-train-update
support resume training
2025-03-03 18:45:17 +08:00
Artiprocher
eb4d5187d8 support resume training 2025-03-03 18:31:31 +08:00
Zhongjie Duan
ee4b02247c Merge pull request #392 from modelscope/sage_attention
Sage attention
2025-03-03 14:28:36 +08:00
Artiprocher
da8e1fe7e4 support sage attention 2025-03-03 14:19:16 +08:00
Zhongjie Duan
3db824c281 Merge pull request #390 from YunhongLu-ZJU/main
revised image quality metric
2025-03-03 13:36:34 +08:00
YunhongLu-ZJU
df2ecafd3f revised 2025-03-03 12:30:26 +08:00
Zhongjie Duan
217652d28e Merge pull request #389 from modelscope/requirements
Requirements
2025-03-03 11:25:31 +08:00
Artiprocher
f64c766dcd update install guide in README 2025-03-03 11:24:48 +08:00
Artiprocher
076fd85556 update install guide in README 2025-03-03 11:10:51 +08:00
Zhongjie Duan
c7912ed827 Merge pull request #388 from modelscope/preference_model
Preference model
2025-03-02 19:56:00 +08:00
Artiprocher
e63f9d6993 update preference models 2025-03-02 19:52:27 +08:00
Raffaele Mancuso
d80ef3a677 Sentencepiece requires cmake 2025-03-02 10:58:42 +01:00
philipy1219
852c3d831f support sageattn 2025-03-02 15:09:21 +08:00
Zhongjie Duan
ceb92ee7aa Merge pull request #378 from modelscope/wan-video-params
update wan input params
2025-02-28 19:52:20 +08:00
Artiprocher
3a75026176 update wan input params 2025-02-28 19:43:18 +08:00
Zhongjie Duan
6a92b08244 Merge pull request #375 from modelscope/swanlab-dev
del swanlab because of bad cases
2025-02-28 16:16:56 +08:00
Zhongjie Duan
38bc785ea9 Merge branch 'main' into swanlab-dev 2025-02-28 16:16:15 +08:00
Artiprocher
a466fdca8f del swanlab 2025-02-28 16:13:06 +08:00
Zhongjie Duan
f9f49e3c78 Merge pull request #374 from modelscope/wan-tokenizer-bugfix
align wan tokenizer to official
2025-02-28 16:05:36 +08:00
Artiprocher
61a30673c2 align wan tokenizer to official 2025-02-28 15:50:07 +08:00
Yingda Chen
a48822ec00 Merge pull request #372 from Zeyi-Lin/main
fix: text-to-image swanlab_logger
2025-02-28 14:38:36 +08:00
ZeYi Lin
b6c3d2b74a fix: logger 2025-02-28 12:51:58 +08:00
Zhongjie Duan
5006c2176c Merge pull request #371 from modelscope/wan-video-readme
Update README.md
2025-02-28 10:10:03 +08:00
Zhongjie Duan
d3d3556ff6 Update README.md 2025-02-28 10:09:48 +08:00
Zhongjie Duan
6fa8dbe077 Merge pull request #366 from modelscope/swanlab
Swanlab
2025-02-27 19:32:23 +08:00
Artiprocher
a57749ef60 update swanlab log 2025-02-27 19:30:53 +08:00
Artiprocher
b5c1d33e58 update swanlab log 2025-02-27 19:21:51 +08:00
Zhongjie Duan
34a9f82865 Merge pull request #365 from modelscope/wan-train-dev
update wanx lora examples
2025-02-27 19:07:10 +08:00
Artiprocher
18dc6cb962 update wanx lora examples 2025-02-27 19:06:24 +08:00
wang96
490d420d82 fix bugs 2025-02-27 15:26:39 +08:00
wang96
0aca943a39 Merge remote-tracking branch 'upstream/main' 2025-02-27 15:23:55 +08:00
Zhongjie Duan
c760208614 Merge pull request #360 from modelscope/wan-train-dev
support wan image training
2025-02-27 12:58:32 +08:00
Artiprocher
fad7aea58a support wan image training 2025-02-27 12:56:55 +08:00
Zhongjie Duan
b42eb1444c Merge pull request #357 from modelscope/bugfix
bugfix
2025-02-27 11:06:24 +08:00
Zhongjie Duan
25a247dd3f bugfix 2025-02-27 11:06:10 +08:00
Zhongjie Duan
7792017a02 Update README.md 2025-02-27 10:52:47 +08:00
Zhongjie Duan
0219e8d2f3 Update README.md 2025-02-26 22:53:07 +08:00
Zhongjie Duan
1d309a14a3 Merge pull request #352 from modelscope/bugfix
Fix Wan VAE device
2025-02-26 20:03:53 +08:00
Zhongjie Duan
7df73ceaaf Fix Wan VAE device 2025-02-26 20:03:26 +08:00
wang96
0dbb3d333f feat: support I2V training 2025-02-26 19:50:59 +08:00
ZeYi Lin
1419bec53d feat: add swanlab logger 2025-02-26 17:12:54 +08:00
Zhongjie Duan
cf12723c89 Merge pull request #347 from co63oc/fix1
Fix typos
2025-02-26 15:50:36 +08:00
co63oc
4268f5466b Fix 2025-02-26 14:18:36 +08:00
Zhongjie Duan
b9f5a00d98 Merge pull request #345 from ghunkins/dev/ghunkins/allow-for-py39
🐍 Remove Python 3.10 Type Hint
2025-02-26 11:42:19 +08:00
Zhongjie Duan
7d44dc99fb support wan full training
support wan full train
2025-02-26 11:38:51 +08:00
Artiprocher
b20de1b44d support wan full train 2025-02-26 11:34:04 +08:00
Gregory D. Hunkins
366ee0f542 remove py310 type hint 2025-02-25 22:29:53 -05:00
Artiprocher
bed770248b update examples 2025-02-26 10:25:36 +08:00
Kohaku-Blueleaf
020560d2b5 Fix num_frames in i2v (#339)
* Fix num_frames in i2v

* Remove print in flash_attention
2025-02-26 10:05:51 +08:00
Zhongjie Duan
af7d305f00 Wan video (#338) 2025-02-25 19:00:43 +08:00
Zhongjie Duan
427232cbc0 Merge pull request #328 from modelscope/stepvideo
Stepvideo low VRAM support!
2025-02-18 18:01:40 +08:00
Zhongjie Duan
2899283c01 Update stepvideo examples 2025-02-18 18:00:08 +08:00
Artiprocher
9cff769fbd optimize stepvideo vae 2025-02-18 17:28:05 +08:00
Zhongjie Duan
23e33273f1 Merge pull request #327 from modelscope/stepvideo
support stepvideo quantized
2025-02-17 19:44:41 +08:00
Artiprocher
f191353cf4 support stepvideo quantized 2025-02-17 19:43:47 +08:00
Zhongjie Duan
66a094fc84 Merge pull request #326 from modelscope/stepvideo
support stepvideo
2025-02-17 17:35:26 +08:00
Artiprocher
3681adc5ac support stepvideo 2025-02-17 17:32:25 +08:00
YunhongLu-ZJU
4449faaa01 Merge branch 'modelscope:main' into main 2025-02-17 14:45:13 +08:00
YunhongLu-ZJU
991ba162bd add new quality metric 2025-02-17 14:42:20 +08:00
YunhongLu-ZJU
77d0f4d297 add image quality metric 2025-02-14 14:02:17 +08:00
YunhongLu-ZJU
a834371d50 add quality metric 2025-02-14 13:59:56 +08:00
hongzhang.hz
acda7d891a image quality metric 2025-02-14 12:39:06 +08:00
Zhongjie Duan
7434ec8fcd Merge pull request #324 from modelscope/vram_management
support vram management in flux
2025-02-14 10:54:55 +08:00
Artiprocher
0699212665 support vram management in flux 2025-02-13 15:11:39 +08:00
Zhongjie Duan
f47de78b59 Merge pull request #323 from mi804/eligen
update eligen dataset
2025-02-12 19:14:02 +08:00
mi804
5fdc8039ec update eligen dataset 2025-02-11 13:53:51 +08:00
Zhongjie Duan
46d4616e23 Update setup.py 2025-02-06 20:12:01 +08:00
Zhongjie Duan
2e597335be Merge pull request #320 from mi804/eligen
update eligen ui and readme
2025-01-24 16:40:45 +08:00
mi804
d346300162 update eligen ui and readme 2025-01-24 11:26:48 +08:00
Zhongjie Duan
1df7387f1b Merge pull request #318 from modelscope/hunyuanvideo-seed
fix rand device
2025-01-15 20:07:51 +08:00
Artiprocher
75d62a02d1 fix rand device 2025-01-15 19:30:38 +08:00
Zhongjie Duan
9db26879df Merge pull request #317 from mi804/eligen
update eligen logo_transfer
2025-01-14 17:49:03 +08:00
mi804
7beac7972e update eligen logo_transfer 2025-01-14 17:47:39 +08:00
Zhongjie Duan
72cac18d3e Merge pull request #316 from modelscope/teacache-hunyuanvideo
support teacache-hunyuanvideo
2025-01-14 14:48:04 +08:00
Artiprocher
9f8112ec34 support teacache-hunyuanvideo 2025-01-14 14:46:35 +08:00
Zhongjie Duan
d9fad821b2 Merge pull request #314 from modelscope/teacache
support teacache
2025-01-13 15:59:01 +08:00
Artiprocher
c0889c2564 support teacache 2025-01-13 15:56:33 +08:00
Zhongjie Duan
913591c13e Merge pull request #313 from modelscope/Artiprocher-patch-2
Update model_config.py
2025-01-12 11:15:18 +08:00
Zhongjie Duan
aaf13d6e4a Update model_config.py 2025-01-12 11:14:57 +08:00
Zhongjie Duan
90c07fec61 Merge pull request #312 from modelscope/HunyuanVideo-fp8
Update model_config.py
2025-01-11 20:41:20 +08:00
Zhongjie Duan
cc6c3c0807 Update model_config.py 2025-01-11 20:40:53 +08:00
Zhongjie Duan
ce2476ab9b Merge pull request #311 from mi804/eligen
update eligen readme's visualization
2025-01-09 16:26:54 +08:00
mi804
9e70c49317 update eligen readme 2025-01-09 16:22:39 +08:00
Zhongjie Duan
bf1c99645b Merge pull request #308 from mi804/eligen
fix bug for enable_eligen_on_negative
2025-01-09 15:52:03 +08:00
mi804
c2478ff284 update eligen examples and readme 2025-01-09 15:47:23 +08:00
mi804
a60bf3cd5f fix bug for enable_eligen_on_negative 2025-01-08 19:04:33 +08:00
Hong Zhang
34231907d0 Merge pull request #304 from modelscope/eligen-entity-transfer
add entity transfer example
2025-01-03 15:10:59 +08:00
Artiprocher
840dab58cd add entity transfer example 2025-01-03 14:40:37 +08:00
Zhongjie Duan
d5ceca0663 Merge pull request #303 from modelscope/eligen
Eligen
2025-01-03 10:47:26 +08:00
mi804
8cf3422688 update eligen ui 2025-01-03 10:37:34 +08:00
Artiprocher
6f743fc4b6 refine code 2025-01-02 19:54:09 +08:00
Zhongjie Duan
991b133bff Merge pull request #302 from modelscope/cache_latents
Update text_to_image.py
2025-01-02 14:23:33 +08:00
Zhongjie Duan
3b010043de Update text_to_image.py 2025-01-02 14:23:02 +08:00
Zhongjie Duan
088ea29e6e Merge pull request #301 from modelscope/Artiprocher-patch-1
Update model_config.py
2025-01-02 10:54:46 +08:00
Zhongjie Duan
b8b135ff73 Update model_config.py 2025-01-02 10:54:22 +08:00
mi804
2872fdaf48 update video of entity control 2024-12-31 18:09:29 +08:00
mi804
9853f83454 update readme video 2024-12-31 18:02:49 +08:00
mi804
fd6e661203 update readme 2024-12-31 17:50:20 +08:00
mi804
c087f68d74 update readme 2024-12-31 17:08:44 +08:00
mi804
b6620f3dde update_example entity control 2024-12-31 14:04:28 +08:00
Zhongjie Duan
3228c3e085 Support MERJIC's new model (#298)
* Update flux_dit.py
* Update model_config.py
2024-12-28 21:21:25 +08:00
Zhongjie Duan
6cc5fd6d1e Merge pull request #297 from modelscope/dev
Dev
2024-12-26 10:21:50 +08:00
Artiprocher
4f6d5e7074 hunyuanvideo step_processor 2024-12-26 10:20:59 +08:00
Artiprocher
6a999e1127 hunyuanvideo step_processor 2024-12-26 10:13:46 +08:00
mi804
e3d89cec0c temp commit for entity control 2024-12-25 17:19:31 +08:00
Zhongjie Duan
1b6e96a820 Merge pull request #296 from modelscope/dev
update hunyuanvideo examples
2024-12-24 10:48:11 +08:00
Artiprocher
e38ccf4c2f update hunyuanvideo examples 2024-12-24 10:47:26 +08:00
Zhongjie Duan
010c801081 Update hunyuanvideo_v2v_6G.py 2024-12-23 20:57:58 +08:00
Zhongjie Duan
edc9272e55 Merge pull request #295 from modelscope/dev
support hunyuanvideo v2v
2024-12-23 20:56:04 +08:00
Artiprocher
405ca6be33 support hunyuanvideo v2v 2024-12-23 20:43:47 +08:00
Zhongjie Duan
c06ea2271a Merge pull request #293 from modelscope/dev
hunyuanvideo quantization
2024-12-19 16:20:35 +08:00
Artiprocher
0692e8b1e1 hunyuanvideo quantization 2024-12-19 16:20:11 +08:00
Zhongjie Duan
aa23356420 Merge pull request #292 from modelscope/dev
hunyuanvideo examples
2024-12-19 13:29:51 +08:00
Zhongjie Duan
00a610e5ad Merge branch 'main' into dev 2024-12-19 13:29:40 +08:00
Artiprocher
2e39dcc0d3 hunyuanvideo examples 2024-12-19 13:28:44 +08:00
Zhongjie Duan
03d3a26f6f Merge pull request #291 from modelscope/dev
hunyuanvideo examples
2024-12-19 13:20:18 +08:00
Artiprocher
309fa9cf51 hunyuanvideo examples 2024-12-19 13:19:39 +08:00
Zhongjie Duan
65aab8adea Merge pull request #290 from modelscope/dev
Dev
2024-12-19 13:16:55 +08:00
Artiprocher
3d48b287a3 hunyuanvideo examples 2024-12-19 13:15:06 +08:00
Zhongjie Duan
29cebf0bec Update artaug_flux.py 2024-12-18 20:43:53 +08:00
Zhongjie Duan
95a0f0bedc Update README.md 2024-12-18 20:42:50 +08:00
Zhongjie Duan
77e0617861 Merge pull request #289 from modelscope/artaug
Artaug
2024-12-18 20:40:13 +08:00
Artiprocher
469a0405a1 ArtAug 2024-12-18 20:32:23 +08:00
Zhongjie Duan
46f191ffe7 Merge pull request #288 from mi804/hunyuanvideo
Hunyuanvideo
2024-12-18 19:40:23 +08:00
Artiprocher
ec7ac20def hunyuanvideo text encoder offload 2024-12-18 19:35:04 +08:00
mi804
3f410b0b77 hunyuanvideo_vae_encoder 2024-12-18 19:03:04 +08:00
mi804
8e06cac0df vae_encoder_weightsloading 2024-12-18 17:37:46 +08:00
Artiprocher
e5099f4e74 hunyuanvideo 2024-12-18 16:43:06 +08:00
Zhongjie Duan
447adef472 Merge pull request #287 from modelscope/dev-dzj
hunyuanvideo pipeline
2024-12-18 11:47:44 +08:00
Zhongjie Duan
a849b05e5a Merge branch 'dev' into dev-dzj 2024-12-18 11:47:34 +08:00
Artiprocher
b048f1b1de hunyuanvideo pipeline 2024-12-18 11:42:43 +08:00
Zhongjie Duan
f7848f9560 Merge pull request #286 from mi804/hunyuanvideo
hunyuanvideo_vae_decoder
2024-12-18 11:35:06 +08:00
mi804
236b56d285 hunyuanvideo_vae_decoder_model 2024-12-18 11:31:33 +08:00
Zhongjie Duan
42a717054a Merge branch 'dev' into hunyuanvideo 2024-12-18 11:21:33 +08:00
mi804
263166768e hunyuanvideo_vae_decoder 2024-12-18 11:14:57 +08:00
Zhongjie Duan
7a45b7efa7 Merge pull request #284 from modelscope/dev-dzj
hunyuanvideo dit
2024-12-17 14:50:21 +08:00
Zhongjie Duan
54ed532e3e Merge branch 'dev' into dev-dzj 2024-12-17 14:49:46 +08:00
Artiprocher
05e2028c5d hunyuanvideo dit 2024-12-17 14:45:23 +08:00
Zhongjie Duan
79249063b8 Merge pull request #283 from mi804/hunyuanvideo
hunyuanvideo text encoder
2024-12-17 14:42:46 +08:00
Zhongjie Duan
31ebec7a72 Merge pull request #282 from modelscope/lora-patch-2
support resume from opensource format
2024-12-16 12:26:37 +08:00
Artiprocher
919d399fdb support resume from opensource format 2024-12-16 12:25:05 +08:00
Zhongjie Duan
32a7a1487d Merge pull request #281 from modelscope/lora-patch
support resume training
2024-12-16 11:10:32 +08:00
Artiprocher
8c2671ce40 support resume training 2024-12-16 11:08:14 +08:00
root
5d1005a7c8 hunyuanvideo text encoder 2024-12-11 18:52:42 +08:00
Artiprocher
b84f906964 support artaug 2024-12-03 15:30:01 +08:00
Zhongjie Duan
7c0520d029 Merge pull request #277 from modelscope/sd35-lora
support sd35-lora
2024-11-29 12:35:32 +08:00
Artiprocher
9d09121fbc support sd35-lora 2024-11-29 11:45:40 +08:00
Zhongjie Duan
7f2a5424d4 Merge pull request #276 from modelscope/Artiprocher-patch-2
Update flux_ipadapter example
2024-11-28 10:44:29 +08:00
Zhongjie Duan
00830f0ecd Update flux_ipadapter.py 2024-11-28 10:44:07 +08:00
Zhongjie Duan
fd7737af7d Merge pull request #275 from mi804/flux_ipadapter
Flux ipadapter
2024-11-28 10:43:06 +08:00
root
f2130c4c25 minor 2024-11-26 19:08:41 +08:00
root
4f40683fd8 support flux ipadapter 2024-11-26 18:08:50 +08:00
Zhongjie Duan
5fc9e53eec Merge pull request #272 from modelscope/fix_kolors_pad
fix_kolors_pad
2024-11-21 14:50:21 +08:00
tc2000731
27e3cea285 fix_kolors_pad 2024-11-21 11:39:28 +08:00
Zhongjie Duan
ee770fa68f Merge pull request #271 from modelscope/sd35-series
Sd35 series
2024-11-20 09:54:41 +08:00
Artiprocher
9cb4aa16eb fix cogvideo height width checker 2024-11-20 09:51:31 +08:00
Zhongjie Duan
92d990629f Merge pull request #269 from modelscope/fix_image_resize
fix_image_resize
2024-11-18 19:24:57 +08:00
tc2000731
ba58f1bc0b fix_image_resize 2024-11-18 18:34:21 +08:00
Artiprocher
02fcfd530f support sd3.5 medium and large-turbo 2024-11-15 14:20:39 +08:00
Zhongjie Duan
095e8a3de8 Merge pull request #265 from modelscope/dev
support height width checker
2024-11-13 12:39:56 +08:00
Artiprocher
e17ad83fb5 support height width checker 2024-11-13 12:39:09 +08:00
Zhongjie Duan
e7c41151ec Merge pull request #264 from modelscope/dev
Dev
2024-11-13 09:53:49 +08:00
Artiprocher
7f4ba62d4f support size checker 2024-11-12 19:41:09 +08:00
Artiprocher
71b17a3a53 update mask blur 2024-11-12 19:20:17 +08:00
Artiprocher
d46b8b8fd7 bux fix 2024-11-12 10:17:01 +08:00
Artiprocher
a671070a28 bug fix 2024-11-11 21:01:38 +08:00
Zhongjie Duan
4600d5351b Update model_config.py 2024-11-11 19:26:30 +08:00
Zhongjie Duan
75bba5b8e5 Merge pull request #263 from modelscope/super-alignment
support mask blur
2024-11-11 19:24:30 +08:00
Artiprocher
8d1d1536d3 support mask blur 2024-11-11 18:59:55 +08:00
Zhongjie Duan
a7050a185b Merge pull request #262 from modelscope/sd3.5
Sd3.5
2024-11-11 18:47:49 +08:00
Zhongjie Duan
d345541c2d Merge pull request #261 from modelscope/omnigen
support omnigen
2024-11-11 18:47:09 +08:00
Artiprocher
bd028e4c66 support omnigen 2024-11-11 18:39:40 +08:00
Zhongjie Duan
d6f4fb67cc Merge pull request #260 from mi804/sd3.5
update default t5_sequence_length to 77
2024-11-11 16:39:31 +08:00
mi804
4378b540cf update t5_sequence_length 2024-11-11 16:28:17 +08:00
Artiprocher
39ddb7c3e3 support sd3.5 2024-11-06 19:57:01 +08:00
Zhongjie Duan
344cbd3286 Merge pull request #258 from modelscope/Artiprocher-patch-2
Update README.md
2024-11-05 19:09:04 +08:00
Zhongjie Duan
d4ba173b53 Update README.md 2024-11-05 19:08:52 +08:00
Zhongjie Duan
c56ce656b2 Merge pull request #252 from modelscope/Flux_ControlNet_Quantization
add Flux_ControlNet_Quantization
2024-11-01 14:51:10 +08:00
tc2000731
9377214518 update controlnet_frames, downloads 2024-10-31 17:38:57 +08:00
tc2000731
900a1c095f add Flux_ControlNet_Quantization 2024-10-29 17:29:24 +08:00
Zhongjie Duan
7e97a96840 Merge pull request #249 from modelscope/newpush
update noise generate
2024-10-25 16:43:37 +08:00
Zhongjie Duan
69f272d7ba Merge pull request #251 from modelscope/flux-examples
Flux examples
2024-10-25 16:35:47 +08:00
Artiprocher
a653554bd9 update examples 2024-10-25 16:30:35 +08:00
Artiprocher
6a25006544 update examples 2024-10-25 16:27:19 +08:00
Qianyi Zhao
8cfe4820f6 Update sd_video.py 2024-10-25 03:23:01 -05:00
Qianyi Zhao
c8021d4224 Update svd_video.py 2024-10-25 01:44:09 -05:00
Zhongjie Duan
3a64cc27b5 Merge pull request #250 from modelscope/flux-controlnet
Flux controlnet
2024-10-25 10:58:37 +08:00
Zhongjie Duan
2edc485ec1 Update requirements.txt 2024-10-25 00:16:11 +08:00
Artiprocher
a6d6553cee bug fix 2024-10-24 17:36:22 +08:00
Artiprocher
45feef9413 update model config 2024-10-24 16:10:15 +08:00
Artiprocher
105fe3961c update examples 2024-10-24 15:42:46 +08:00
Qianyi Zhao
d381c7b186 Update svd_video.py 2024-10-23 03:27:59 -05:00
Zhongjie Duan
5e8334c0bf Merge pull request #248 from modelscope/Artiprocher-patch-1
Update requirements.txt
2024-10-23 16:03:35 +08:00
Zhongjie Duan
2ea8a16afb Update requirements.txt 2024-10-23 16:03:21 +08:00
Artiprocher
aa054db1c7 bug fix 2024-10-23 14:24:41 +08:00
Artiprocher
07d70a6a56 support flux-controlnet 2024-10-22 18:52:24 +08:00
Qing112
747572e62c update noise generate 2024-10-21 15:09:21 +08:00
Zhongjie Duan
72ed76e89e Merge pull request #243 from modelscope/flux-lora
support preset lora
2024-10-21 14:04:44 +08:00
Artiprocher
a403cb04f3 support preset lora 2024-10-21 14:03:58 +08:00
Zhongjie Duan
ed71184854 Merge pull request #242 from modelscope/accelerate_load_model
accelerate load model
2024-10-21 10:00:09 +08:00
tc2000731
dfbf43e463 accelerate load model 2024-10-18 15:29:50 +08:00
Zhongjie Duan
7d7d72dcfe Merge pull request #239 from modelscope/flux-lora-update
Flux lora update
2024-10-14 19:12:33 +08:00
Artiprocher
540c036988 add alpha to lora converter 2024-10-14 18:57:54 +08:00
Artiprocher
58f89ceec9 update examples 2024-10-14 17:51:12 +08:00
Artiprocher
4e3a184199 update flux training 2024-10-14 10:00:32 +08:00
Zhongjie Duan
22e4ae99e8 Flux lora update (#237)
* update flux lora

---------

Co-authored-by: tc2000731 <tc2000731@163.com>
2024-10-11 18:41:24 +08:00
Zhongjie Duan
75ab786afc Merge pull request #234 from modelscope/doc-patch
Patch
2024-10-10 19:17:00 +08:00
Artiprocher
e5c72ba1f2 update examples 2024-10-10 18:26:37 +08:00
Artiprocher
66873d7d64 update examples 2024-10-10 18:23:43 +08:00
Artiprocher
a0d1d5bcea update examples 2024-10-10 17:25:55 +08:00
Artiprocher
fa0fa95bb6 update flux pipeline 2024-10-10 17:05:04 +08:00
Artiprocher
41ea2f811a update ESRGAN 2024-10-08 18:23:39 +08:00
Artiprocher
ec352cfce2 update model loader 2024-10-08 16:46:44 +08:00
Zhongjie Duan
aade874241 Merge pull request #232 from modelscope/Artiprocher-patch-1
Update README.md
2024-10-08 13:37:12 +08:00
Zhongjie Duan
c01eb653d7 Update README.md 2024-10-08 13:36:56 +08:00
Zhongjie Duan
892f80c265 Merge pull request #230 from modelscope/Artiprocher-dev
support ExVideo-CogVideoX-LoRA-129f-v1
2024-09-30 17:42:49 +08:00
Artiprocher
2e487a2c55 support ExVideo-CogVideoX-LoRA-129f-v1 2024-09-30 17:33:15 +08:00
Zhongjie Duan
a34e3ba338 Merge pull request #229 from modelscope/flux-enhance
support t5 sequence length
2024-09-30 15:33:51 +08:00
Artiprocher
c414f4cb12 support t5 sequence length 2024-09-30 14:45:30 +08:00
Zhongjie Duan
d91c603875 Flux fp8 lora training (#221)
* flux fp8 lora training

---------

Co-authored-by: tc2000731 <tc2000731@163.com>
2024-09-24 11:12:32 +08:00
Zhongjie Duan
7f899dcfca Merge pull request #216 from modelscope/Artiprocher-bugfix
bug fix
2024-09-19 12:27:22 +08:00
Artiprocher
5f12fd4346 bug fix 2024-09-19 12:26:46 +08:00
Zhongjie Duan
a7197f846b Merge pull request #215 from modelscope/flux-fp8
Support FLUX fp8
2024-09-19 10:36:16 +08:00
Artiprocher
ac81fa7a9f update examples 2024-09-19 10:33:30 +08:00
Artiprocher
091df1f1e7 support flux-fp8 2024-09-19 10:32:16 +08:00
tc2000731
a9fbfa108f float8_flux 2024-09-18 16:10:59 +08:00
Zhongjie Duan
44a8bf4143 Merge pull request #210 from modelscope/opensource-alignment
staticmethod
2024-09-14 17:18:19 +08:00
Artiprocher
3da8aa257b staticmethod 2024-09-14 17:16:59 +08:00
Zhongjie Duan
884dd749a0 Merge pull request #209 from modelscope/Artiprocher-patch-1
Update model_config.py
2024-09-14 11:42:30 +08:00
Zhongjie Duan
c697591d6e Update model_config.py 2024-09-14 11:41:47 +08:00
Zhongjie Duan
0b706e03e7 Merge pull request #208 from Qing112/main
update model_config and downloader
2024-09-14 11:40:42 +08:00
Qing112
447e75cd06 update model_config and downloader 2024-09-14 11:35:01 +08:00
Zhongjie Duan
7f76c8809c Merge pull request #207 from modelscope/flux-schnell
support flux-schnell
2024-09-14 11:17:59 +08:00
Artiprocher
cde1f81df6 support flux-schnell 2024-09-14 11:16:03 +08:00
Zhongjie Duan
c21ed1e478 Flux lora (#205) 2024-09-12 16:49:30 +08:00
Zhongjie Duan
a8cb4a21d1 align flux lora format (#204) 2024-09-12 16:01:27 +08:00
Zhongjie Duan
0b9e673fa2 Merge pull request #199 from modelscope/examples
update examples
2024-09-10 17:45:44 +08:00
Artiprocher
d242af8e22 update examples 2024-09-10 17:36:35 +08:00
Hong Zhang
76bd931d79 refine system_prompt for QwenPrompt (#198) 2024-09-10 15:15:23 +08:00
ZhouTianchen
995f3374f1 update omost (#190)
* update omost
2024-09-09 17:39:46 +08:00
Zhongjie Duan
1887885274 Merge pull request #197 from mi804/cpuoffload
add cpuoffload support for image pipelines
2024-09-09 14:48:26 +08:00
mi804
ce43cf412d add cpuoffload support for image pipelines 2024-09-09 13:50:52 +08:00
Zhongjie Duan
d1712f0594 Merge pull request #194 from modelscope/flux-lora
support flux training
2024-09-06 19:15:42 +08:00
Artiprocher
416b73b8c0 support flux training 2024-09-06 10:37:28 +08:00
Zhongjie Duan
4654aa0cab Merge pull request #188 from modelscope/qwen
support Qwen prompt refine
2024-09-04 17:22:56 +08:00
Zhongjie Duan
6f9d8f465a Merge branch 'main' into qwen 2024-09-04 17:22:38 +08:00
Artiprocher
e5e55345dc support qwen prompt refiner 2024-09-04 17:12:01 +08:00
Zhongjie Duan
8d6eb6d41a Merge pull request #187 from modelscope/omost
support Omost LLM
2024-09-04 12:52:23 +08:00
Zhongjie Duan
1118e67cec Merge branch 'main' into omost 2024-09-04 12:52:03 +08:00
Artiprocher
d70cd04b15 fix bugs 2024-09-04 12:48:32 +08:00
Zhongjie Duan
3d1db23224 Merge pull request #186 from modelscope/flux-lora
support flux lora inference
2024-09-04 09:47:08 +08:00
Artiprocher
a488810693 support flux lora inference 2024-09-04 09:39:39 +08:00
tc2000731
0b066d3cb4 add omost.py + omost_flux_example 2024-09-03 19:40:40 +08:00
Zhongjie Duan
d154bee18a support CogVideoX-5B (#184)
* support cogvideo

* update examples
2024-09-03 11:37:54 +08:00
Yudi
3a8694b642 add qwen prompt refiner 2024-08-27 17:28:32 +08:00
Zhongjie Duan
fe485b3fa1 Merge pull request #176 from modelscope/Artiprocher-dev
remove packages from requirements.txt
2024-08-26 15:02:59 +08:00
Artiprocher
e70eaa6a31 remove packages from requirements.txt 2024-08-26 15:01:35 +08:00
Zhongjie Duan
27ef67306d Merge pull request #175 from modelscope/Artiprocher-dev
model cache
2024-08-26 13:57:48 +08:00
Artiprocher
547aca3db2 model cache 2024-08-26 13:57:03 +08:00
Zhongjie Duan
5f7360e2ce Merge pull request #171 from modelscope/Artiprocher-dev
update README
2024-08-23 16:47:13 +08:00
Artiprocher
23f9675218 update README 2024-08-23 16:46:26 +08:00
Zhongjie Duan
ef1e82076c Merge pull request #170 from modelscope/Artiprocher-dev
update model config
2024-08-23 14:18:15 +08:00
Artiprocher
65d4588cc7 update model config 2024-08-23 14:17:10 +08:00
Zhongjie Duan
0488f90c8f Merge pull request #169 from modelscope/Artiprocher-dev
fix bug
2024-08-23 09:28:46 +08:00
Artiprocher
03d91f6618 fix bug 2024-08-23 09:28:10 +08:00
Zhongjie Duan
ae5e4b67dc Merge pull request #166 from modelscope/Artiprocher-dev
update examples
2024-08-22 11:48:50 +08:00
Artiprocher
a6c6e33d88 update examples 2024-08-22 11:41:48 +08:00
Zhongjie Duan
79d9bf7109 Merge pull request #165 from modelscope/Artiprocher-dev
update UI
2024-08-22 10:45:23 +08:00
Artiprocher
66e1b382cd update examples 2024-08-22 10:37:30 +08:00
Artiprocher
66f1ff43e9 update examples 2024-08-22 10:35:58 +08:00
Artiprocher
d6d14859e3 update UI 2024-08-21 16:57:56 +08:00
Zhongjie Duan
4478bb9bbe Merge pull request #164 from modelscope/Artiprocher-dev
FLUX highres-fix
2024-08-20 13:40:23 +08:00
Artiprocher
a6aaf9da2a support flux UI 2024-08-19 14:24:23 +08:00
Artiprocher
aa908ae0c2 support flux highresfix 2024-08-19 13:35:40 +08:00
Artiprocher
778a2d8f84 support flux highresfix 2024-08-19 13:35:27 +08:00
Zhongjie Duan
508baabf9a Merge pull request #160 from modelscope/Artiprocher-dev
support FLUX
2024-08-17 17:52:59 +08:00
Artiprocher
80aa4d8e19 update examples 2024-08-17 17:51:31 +08:00
Artiprocher
99e11112a7 support FLUX 2024-08-16 20:04:10 +08:00
Zhongjie Duan
1116e6dbc7 Merge pull request #155 from Qing112/main
add Flux text encoder
2024-08-14 11:28:14 +08:00
Qianyi Zhao
d1ac96c1ab add flux_text_encoder.py 2024-08-13 22:26:10 -05:00
Qianyi Zhao
abe88c899e add Flux text encoder 2024-08-14 10:46:52 +08:00
Zhongjie Duan
b1709fcbdb Merge pull request #145 from modelscope/Artiprocher-dev
chatglm quantize
2024-08-02 15:09:41 +08:00
Artiprocher
ec877bf490 chatglm quantize 2024-08-02 14:46:29 +08:00
Zhongjie Duan
a8f1812acf Merge pull request #144 from modelscope/Artiprocher-dev
UI update
2024-08-02 13:49:48 +08:00
Artiprocher
6877b460c4 fix bugs 2024-08-02 13:47:07 +08:00
Artiprocher
f189f9f1be update UI 2024-08-02 10:31:25 +08:00
Artiprocher
6f79fd6d77 support sdxl controlnet union 2024-08-01 10:01:39 +08:00
Zhongjie Duan
60d7bb52d6 Update README.md 2024-07-30 10:42:43 +08:00
Yingda Chen
65a2a0643a add badges 2024-07-30 10:32:03 +08:00
Zhongjie Duan
bc5f151dfa Update setup.py 2024-07-29 20:22:01 +08:00
Zhongjie Duan
5cd6ed0096 Update publish.yaml 2024-07-29 20:12:37 +08:00
Zhongjie Duan
be84b35bfd Update publish.yaml 2024-07-29 19:36:28 +08:00
Zhongjie Duan
d9fc30ffd0 Create publish.yaml 2024-07-29 19:27:14 +08:00
Zhongjie Duan
8f59d00d9e Merge pull request #135 from modelscope/Artiprocher-setup
update setup.py
2024-07-29 19:16:59 +08:00
Artiprocher
3d8ff39aed update setup.py 2024-07-29 19:10:03 +08:00
Zhongjie Duan
b5c194df43 Merge pull request #134 from modelscope/Artiprocher-webui
support kolors in webui
2024-07-29 16:25:25 +08:00
Artiprocher
8680f92b60 support kolors in webui 2024-07-29 16:24:13 +08:00
Zhongjie Duan
05c97bc755 Merge pull request #133 from modelscope/Artiprocher-doc
add general options to lora readme
2024-07-29 14:45:25 +08:00
Artiprocher
db88d60750 add general options to lora readme 2024-07-29 14:44:29 +08:00
Zhongjie Duan
40c6da8075 Merge pull request #132 from modelscope/Artiprocher-rebuild
rebuild base modules
2024-07-29 12:14:26 +08:00
Artiprocher
3981b8084f redirect Kolors 2024-07-29 10:22:47 +08:00
Zhongjie Duan
9dfb7c1c37 Merge pull request #128 from Yuan-ManX/Kolors-1
support Kolors
2024-07-28 17:09:52 +08:00
Artiprocher
9ed54c188e fix bugs 2024-07-26 17:51:03 +08:00
Yuan-Man
6a47a346b1 support Kolors 2024-07-26 16:43:52 +08:00
Artiprocher
e3f8a576cf rebuild base modules 2024-07-26 12:15:40 +08:00
Yingda Chen
0aff733a92 add github trending badge 2024-07-26 11:32:23 +08:00
Zhongjie Duan
9471bff8a4 Merge pull request #107 from modelscope/Artiprocher-dev
reduce VRAM requirements in Kolors LoRA
2024-07-12 17:42:22 +08:00
Artiprocher
3f8eea4687 update downloader 2024-07-12 17:39:26 +08:00
Artiprocher
b1b2d50c0d reduce VRAM requirements in Kolors LoRA 2024-07-12 17:30:19 +08:00
Zhongjie Duan
9c6607f78d support kolors! (#106) 2024-07-11 21:43:45 +08:00
Zhongjie Duan
2a4709e572 Merge pull request #102 from modelscope/Artiprocher-ExVideo
Add ExVideo Demo link
2024-07-10 16:49:03 +08:00
Artiprocher
04f3fce3b0 add ExVideo demo link 2024-07-10 16:45:18 +08:00
Artiprocher
be9c3524a5 add ExVideo demo link 2024-07-10 16:44:32 +08:00
Zhongjie Duan
c3d899dd48 Merge pull request #101 from modelscope/Artiprocher-sd3-lora
Support SD3 LoRA
2024-07-10 13:42:54 +08:00
Artiprocher
6e03ee2a75 update examples 2024-07-10 13:41:11 +08:00
Artiprocher
979a8814f1 support SD3 LoRA 2024-07-10 10:07:02 +08:00
Zhongjie Duan
8be4fad330 Merge pull request #94 from modelscope/Artiprocher-sd3
support SD3
2024-07-05 16:39:59 +08:00
Artiprocher
8113f95278 update README 2024-07-05 16:38:10 +08:00
Artiprocher
9ca6c646df update SD3 examples 2024-07-05 16:35:41 +08:00
Artiprocher
466b37994e SD3 UI 2024-07-05 14:28:24 +08:00
Artiprocher
518c6d6ac3 support SD3 textual inversion 2024-07-05 13:36:54 +08:00
Artiprocher
9920b8d975 support SD3 2024-07-04 16:08:39 +08:00
Artiprocher
237daa2048 Merge pull request #87 from Lupino/main
pass device to processors Annotator
2024-07-04 10:34:40 +08:00
Lupino
e9af28e6a3 pass device to processors Annotator 2024-07-01 17:37:25 +08:00
Artiprocher
996515c7ca Merge pull request #73 from modelscope/tamannaaaaa-my-branch
Improve the script file
2024-06-28 11:21:13 +08:00
Artiprocher
c2ccc39e3c update script file based on tamannaaaaa 2024-06-28 11:16:42 +08:00
Artiprocher
ad24b93431 Merge branch 'my-branch' of https://github.com/tamannaaaaa/DiffSynth-Studio into tamannaaaaa-my-branch 2024-06-28 11:00:53 +08:00
Artiprocher
bd5fc32d79 Merge pull request #72 from modelscope/dev
add downloaders and update examples
2024-06-28 10:04:21 +08:00
Artiprocher
03cefe8f58 update examples 2024-06-28 09:49:52 +08:00
tamannaaaaa
64339f7089 Improved the script file 2024-06-27 18:23:44 +05:30
Artiprocher
0b1704976a update examples and downloaders 2024-06-27 19:43:50 +08:00
wenmeng zhou
0af60b9c73 Update README.md 2024-06-27 16:50:05 +08:00
Artiprocher
280f0eacc0 Merge pull request #65 from modelscope/wenmengzhou-patch-1
Update README.md
2024-06-27 16:32:54 +08:00
wenmeng zhou
03cba5e59e Update README.md 2024-06-27 15:56:51 +08:00
Artiprocher
fa0ea0e1a4 Update README.md 2024-06-25 17:03:52 +08:00
Artiprocher
40d24b8907 Merge pull request #48 from modelscope/package
simplify installation
2024-06-25 15:56:49 +08:00
Artiprocher
1bf02f439f update setup.py 2024-06-25 15:53:35 +08:00
Artiprocher
0489c62550 update setup.py 2024-06-25 15:43:27 +08:00
Artiprocher
ad98602da3 Merge pull request #47 from eltociear/patch-1
docs: update README.md
2024-06-25 14:57:34 +08:00
Ikko Eltociear Ashimine
fb12ac316a docs: update README.md
transfered -> transferred
2024-06-25 11:39:59 +09:00
Artiprocher
e9ec2f2706 add downloader 2024-06-24 16:45:35 +08:00
Artiprocher
00f294454b Merge pull request #43 from modelscope/ExVideo
fix compatibility issues in sd_video_pipeline
2024-06-21 16:25:48 +08:00
Artiprocher
0465d940c7 Merge pull request #42 from modelscope/ExVideo
update ExVideo doc
2024-06-21 12:59:29 +08:00
Artiprocher
2c549598d0 Merge pull request #41 from modelscope/ExVideo
update ExVideo doc
2024-06-21 12:48:50 +08:00
Artiprocher
7d33082d70 Merge pull request #40 from modelscope/ExVideo
ExVideo training
2024-06-21 11:43:48 +08:00
766 changed files with 53809 additions and 273566 deletions

BIN
.github/workflows/logo.gif vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 146 KiB

29
.github/workflows/publish.yaml vendored Normal file
View File

@@ -0,0 +1,29 @@
name: release
on:
push:
tags:
- 'v**'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-publish
cancel-in-progress: true
jobs:
build-n-publish:
runs-on: ubuntu-20.04
#if: startsWith(github.event.ref, 'refs/tags')
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: '3.10'
- name: Install wheel
run: pip install wheel==0.44.0 && pip install -r requirements.txt
- name: Build DiffSynth
run: python -m build
- name: Publish package to PyPI
run: |
pip install twine
twine upload dist/* --skip-existing -u __token__ -p ${{ secrets.PYPI_API_TOKEN }}

175
.gitignore vendored Normal file
View File

@@ -0,0 +1,175 @@
/data
/models
/scripts
/diffusers
*.pkl
*.safetensors
*.pth
*.ckpt
*.pt
*.bin
*.DS_Store
*.msc
*.mv
log*.txt
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

View File

@@ -1,15 +0,0 @@
# Set web page format
import streamlit as st
st.set_page_config(layout="wide")
# Diasble virtual VRAM on windows system
import torch
torch.cuda.set_per_process_memory_fraction(0.999, 0)
st.markdown("""
# DiffSynth Studio
[Source Code](https://github.com/Artiprocher/DiffSynth-Studio)
Welcome to DiffSynth Studio.
""")

846
README.md
View File

@@ -1,125 +1,789 @@
# DiffSynth Studio
# DiffSynth-Studio
<a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>
[![PyPI](https://img.shields.io/pypi/v/DiffSynth)](https://pypi.org/project/DiffSynth/)
[![license](https://img.shields.io/github/license/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
[![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)
[![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
[![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
[切换到中文版](./README_zh.md)
## Introduction
DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
Welcome to the magical world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by the [ModelScope Community](https://www.modelscope.cn/). We hope to foster technological innovation through framework construction, aggregate the power of the open-source community, and explore the boundaries of generative model technology!
## Roadmap
DiffSynth currently includes two open-source projects:
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technical exploration, targeting academia, and providing cutting-edge model capability support.
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment, targeting industry, and providing higher computational performance and more stable features.
* Aug 29, 2023. We propose DiffSynth, a video synthesis framework.
* [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
* The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
* The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
* Oct 1, 2023. We release an early version of this project, namely FastSDXL. A try for building a diffusion engine.
* The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
* FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
* The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
* The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
* A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
* Since OLSS requires additional training, we don't implement it in this project.
* Nov 15, 2023. We propose FastBlend, a powerful video deflickering algorithm.
* The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
* Demo videos are shown on Bilibili, including three tasks.
* [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
* [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
* [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
* The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
* An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
* Dec 8, 2023. We decide to develop a new Project, aiming to release the potential of diffusion models, especially in video synthesis. The development of this project is started.
* Jan 29, 2024. We propose Diffutoon, a fantastic solution for toon shading.
* [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/).
* The source codes are released in this project.
* The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
* June 13, 2024. DiffSynth Studio is transfered to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
* June 21, 2024. We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
* [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/).
* Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
* Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
* Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
* Until now, DiffSynth Studio has supported the following models:
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
* [ControlNet](https://github.com/lllyasviel/ControlNet)
* [AnimateDiff](https://github.com/guoyww/animatediff/)
* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
* [ESRGAN](https://github.com/xinntao/ESRGAN)
* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) and [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) are the core engines of the ModelScope AIGC zone. Welcome to experience our carefully crafted productized features:
* ModelScope AIGC Zone (for Chinese users): https://modelscope.cn/aigc/home
* ModelScope Civision (for global users): https://modelscope.ai/civision/home
> DiffSynth-Studio Documentation: [中文版](/docs/zh/README.md)、[English version](/docs/en/README.md)
We believe that a well-developed open-source code framework can lower the threshold for technical exploration. We have achieved many [interesting technologies](#innovative-achievements) based on this codebase. Perhaps you also have many wild ideas, and with DiffSynth-Studio, you can quickly realize these ideas. For this reason, we have prepared detailed documentation for developers. We hope that through these documents, developers can understand the principles of Diffusion models, and we look forward to expanding the boundaries of technology together with you.
## Update History
> DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
- **January 27, 2026**: [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) is released, and our [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) model is released concurrently. You can use it in [ModelScope Studios](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L). For details, see the [documentation](/docs/zh/Model_Details/Z-Image.md).
- **January 19, 2026**: Added support for [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.
- **January 12, 2026**: We trained and open-sourced a text-guided image layer separation model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a textual description, the model isolates the image layer corresponding to the described content. For more details, please refer to our blog post ([Chinese version](https://modelscope.cn/learn/4938), [English version](https://huggingface.co/blog/kelseye/qwen-image-layered-control)).
- **December 24, 2025**: Based on Qwen-Image-Edit-2511, we trained an In-Context Editing LoRA model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)). This model takes three images as input (Image A, Image B, and Image C), and automatically analyzes the transformation from Image A to Image B, then applies the same transformation to Image C to generate Image D. For more details, please refer to our blog post ([Chinese version](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g), [English version](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)).
- **December 9, 2025** We release a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in terms of generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research. For more details, please refer to our [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l).
- **December 4, 2025** DiffSynth-Studio 2.0 released! Many new features online
- [Documentation](/docs/en/README.md) online: Our documentation is still continuously being optimized and updated
- [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md) module upgraded, supporting layer-level disk offload, releasing both memory and VRAM simultaneously
- New model support
- Z-Image Turbo: [Model](https://www.modelscope.ai/models/Tongyi-MAI/Z-Image-Turbo), [Documentation](/docs/en/Model_Details/Z-Image.md), [Code](/examples/z_image/)
- FLUX.2-dev: [Model](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev), [Documentation](/docs/en/Model_Details/FLUX2.md), [Code](/examples/flux2/)
- Training framework upgrade
- [Split Training](/docs/zh/Training/Split_Training.md): Supports automatically splitting the training process into two stages: data processing and training (even for training ControlNet or any other model). Computations that do not require gradient backpropagation, such as text encoding and VAE encoding, are performed during the data processing stage, while other computations are handled during the training stage. Faster speed, less VRAM requirement.
- [Differential LoRA Training](/docs/zh/Training/Differential_LoRA.md): This is a training technique we used in [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), now available for LoRA training of any model.
- [FP8 Training](/docs/zh/Training/FP8_Precision.md): FP8 can be applied to any non-training model during training, i.e., models with gradients turned off or gradients that only affect LoRA weights.
<details>
<summary>More</summary>
- **November 4, 2025** Supported the [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) model, which is trained based on Wan 2.1 and supports generating corresponding actions based on reference videos.
- **October 30, 2025** Supported the [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which supports text-to-video, image-to-video, and video continuation. This model uses the Wan framework for inference and training in this project.
- **October 27, 2025** Supported the [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) model, adding another member to the Wan model ecosystem.
- **September 23, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) released! This model was jointly developed and open-sourced by us and Taobao Experience Design Team. Built upon Qwen-Image, the model is specifically designed for e-commerce poster scenarios, supporting precise partition layout control. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py).
- **September 9, 2025** Our training framework supports various training modes. Currently adapted for Qwen-Image, in addition to the standard SFT training mode, Direct Distill is now supported. Please refer to [our sample code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will continue to improve it to support more comprehensive model training functions.
- **August 28, 2025** We support Wan2.2-S2V, an audio-driven cinematic video generation model. See [./examples/wanvideo/](./examples/wanvideo/).
- **August 21, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) released! Compared to the V1 version, the training dataset has been changed to [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), so the generated images better conform to Qwen-Image's own image distribution and style. Please refer to [our sample code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py).
- **August 21, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) structural control LoRA model, adopting the In Context technical route, supporting multiple categories of structural control conditions, including canny, depth, lineart, softedge, normal, and openpose. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py).
- **August 20, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) model, improving the editing effect of Qwen-Image-Edit on low-resolution image inputs. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)
- **August 19, 2025** 🔥 Qwen-Image-Edit open-sourced, welcome a new member to the image editing model family!
- **August 18, 2025** We trained and open-sourced the Qwen-Image inpainting ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).
- **August 15, 2025** We open-sourced the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) dataset. This is an image dataset generated using the Qwen-Image model, containing 160,000 `1024 x 1024` images. It includes general, English text rendering, and Chinese text rendering subsets. We provide annotations for image descriptions, entities, and structural control images for each image. Developers can use this dataset to train Qwen-Image models' ControlNet and EliGen models. We aim to promote technological development through open-sourcing!
- **August 13, 2025** We trained and open-sourced the Qwen-Image ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py).
- **August 12, 2025** We trained and open-sourced the Qwen-Image ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py).
- **August 11, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) for Qwen-Image, following the same training process as [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), but the model structure has been modified to LoRA, thus being better compatible with other open-source ecosystem models.
- **August 7, 2025** We open-sourced the entity control LoRA model [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) for Qwen-Image. Qwen-Image-EliGen can achieve entity-level controlled text-to-image generation. Technical details can be found in [the paper](https://arxiv.org/abs/2501.01097). Training dataset: [EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet).
- **August 5, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) for Qwen-Image, achieving approximately 5x acceleration.
- **August 4, 2025** 🔥 Qwen-Image open-sourced, welcome a new member to the image generation model family!
- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) open-sourced, a text-to-image model focused on aesthetic photography. We provided comprehensive support in a timely manner, including low VRAM layer-by-layer offload, LoRA training, and full training. For more details, please refer to [./examples/flux/](./examples/flux/).
- **July 28, 2025** Wan 2.2 open-sourced. We provided comprehensive support in a timely manner, including low VRAM layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, and full training. For more details, please refer to [./examples/wanvideo/](./examples/wanvideo/).
- **July 11, 2025** We propose Nexus-Gen, a unified framework that combines the language reasoning capabilities of Large Language Models (LLMs) with the image generation capabilities of diffusion models. This framework supports seamless image understanding, generation, and editing tasks.
- Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
- GitHub Repository: https://github.com/modelscope/Nexus-Gen
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
- Training Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
- Online Experience: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
- **June 15, 2025** ModelScope's official evaluation framework [EvalScope](https://github.com/modelscope/evalscope) now supports text-to-image generation evaluation. Please refer to the [best practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide to try it out.
- **March 25, 2025** Our new open-source project [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) is now open-sourced! Focused on stable model deployment, targeting industry, providing better engineering support, higher computational performance, and more stable features.
- **March 31, 2025** We support InfiniteYou, a face feature preservation method for FLUX. More details can be found in [./examples/InfiniteYou/](./examples/InfiniteYou/).
- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of Tencent's open-source HunyuanVideo. More details can be found in [./examples/HunyuanVideo/](./examples/HunyuanVideo/).
- **February 25, 2025** We support Wan-Video, a series of state-of-the-art video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
- **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! Advanced video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).
- **December 31, 2024** We propose EliGen, a new framework for entity-level controlled text-to-image generation, supplemented with an inpainting fusion pipeline, extending its capabilities to image inpainting tasks. EliGen can seamlessly integrate existing community models such as IP-Adapter and In-Context LoRA, enhancing their versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
- Online Experience: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
- Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
- **December 19, 2024** We implemented advanced VRAM management for HunyuanVideo, enabling video generation with resolutions of 129x720x1280 on 24GB VRAM or 129x512x384 on just 6GB VRAM. More details can be found in [./examples/HunyuanVideo/](./examples/HunyuanVideo/).
- **December 18, 2024** We propose ArtAug, a method to improve text-to-image models through synthesis-understanding interaction. We trained an ArtAug enhancement module for FLUX.1-dev in LoRA format. This model incorporates the aesthetic understanding of Qwen2-VL-72B into FLUX.1-dev, thereby improving the quality of generated images.
- Paper: https://arxiv.org/abs/2412.12888
- Example: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
- Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (coming soon)
- **October 25, 2024** We provide extensive FLUX ControlNet support. This project supports many different ControlNet models and can be freely combined, even if their structures are different. Additionally, ControlNet models are compatible with high-resolution optimization and partition control technologies, enabling very powerful controllable image generation. See [`./examples/ControlNet/`](./examples/ControlNet/).
- **October 8, 2024** We released extended LoRAs based on CogVideoX-5B and ExVideo. You can download this model from [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) or [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1).
- **August 22, 2024** This project now supports CogVideoX-5B. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including:
- Text-to-video
- Video editing
- Self super-resolution
- Video interpolation
- **August 22, 2024** We implemented an interesting brush feature that supports all text-to-image models. Now you can create stunning images with the assistance of AI using the brush!
- Use it in our [WebUI](#usage-in-webui).
- **August 21, 2024** DiffSynth-Studio now supports FLUX.
- Enable CFG and high-resolution inpainting to improve visual quality. See [here](/examples/image_synthesis/README.md)
- LoRA, ControlNet, and other addon models will be released soon.
- **June 21, 2024** We propose ExVideo, a post-training fine-tuning technique aimed at enhancing the capabilities of video generation models. We extended Stable Video Diffusion to achieve long video generation of up to 128 frames.
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
- Source code has been released in this repository. See [`examples/ExVideo`](./examples/ExVideo/).
- Model has been released at [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
- Technical report has been released at [arXiv](https://arxiv.org/abs/2406.14130).
- You can try ExVideo in this [demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!
- **June 13, 2024** DiffSynth Studio has migrated to ModelScope. The development team has also transitioned from "me" to "us". Of course, I will still participate in subsequent development and maintenance work.
- **January 29, 2024** We propose Diffutoon, an excellent cartoon coloring solution.
- [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
- Source code has been released in this project.
- Technical report (IJCAI 2024) has been released at [arXiv](https://arxiv.org/abs/2401.16224).
- **December 8, 2023** We decided to initiate a new project aimed at unleashing the potential of diffusion models, especially in video synthesis. The development work of this project officially began.
- **November 15, 2023** We propose FastBlend, a powerful video deflickering algorithm.
- sd-webui extension has been released at [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
- Demonstration videos have been showcased on Bilibili, including three tasks:
- [Video Deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
- [Video Interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
- [Image-Driven Video Rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
- Technical report has been released at [arXiv](https://arxiv.org/abs/2311.09265).
- Unofficial ComfyUI extensions developed by other users have been released at [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
- **October 1, 2023** We released an early version of the project named FastSDXL. This was an initial attempt to build a diffusion engine.
- Source code has been released at [GitHub](https://github.com/Artiprocher/FastSDXL).
- FastSDXL includes a trainable OLSS scheduler to improve efficiency.
- The original repository of OLSS is located [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
- Technical report (CIKM 2023) has been released at [arXiv](https://arxiv.org/abs/2305.14677).
- Demonstration video has been released at [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
- Since OLSS requires additional training, we did not implement it in this project.
- **August 29, 2023** We propose DiffSynth, a video synthesis framework.
- [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
- Source code has been released at [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
- Technical report (ECML PKDD 2024) has been released at [arXiv](https://arxiv.org/abs/2308.03463).
</details>
## Installation
Create Python environment:
Install from source (recommended):
```
conda env create -f environment.yml
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
We find that sometimes `conda` cannot install `cupy` correctly, please install it manually. See [this document](https://docs.cupy.dev/en/stable/install.html) for more details.
For more installation methods and instructions for non-NVIDIA GPUs, please refer to the [Installation Guide](/docs/en/Pipeline_Usage/Setup.md).
Enter the Python environment:
</details>
```
conda activate DiffSynthStudio
```
## Basic Framework
## Usage (in Python code)
DiffSynth-Studio redesigns the inference and training pipelines for mainstream Diffusion models (including FLUX, Wan, etc.), enabling efficient memory management and flexible model training.
The Python examples are in [`examples`](./examples/). We provide an overview here.
<details>
<summary>Environment Variable Configuration</summary>
### Long Video Synthesis
> Before running model inference or training, you can configure settings such as the model download source via [environment variables](/docs/en/Pipeline_Usage/Environment_Variables.md).
>
> By default, this project downloads models from ModelScope. For users outside China, you can configure the system to download models from the ModelScope international site as follows:
>
> ```python
> import os
> os.environ["MODELSCOPE_DOMAIN"] = "www.modelscope.ai"
> ```
>
> To download models from other sources, please modify the environment variable [DIFFSYNTH_DOWNLOAD_SOURCE](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source).
We trained an extended video synthesis model, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
</details>
### Image Synthesis
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/)
![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d)
|512*512|1024*1024|2048*2048|4096*4096|
|-|-|-|-|
|![512](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/55f679e9-7445-4605-9315-302e93d11370)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/9087a73c-9164-4c58-b2a0-effc694143fb)|![4096](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edee9e71-fc39-4d1c-9ca9-fa52002c67ac)|
#### Z-Image: [/docs/en/Model_Details/Z-Image.md](/docs/en/Model_Details/Z-Image.md)
|1024*1024|2048*2048|
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) model for inference. FP8 quantization significantly degrades image quality, so we do not recommend enabling any quantization for the Z-Image Turbo model. CPU offloading is recommended, and the model can run with as little as 8 GB of GPU memory.
```python
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
image.save("image.jpg")
```
</details>
<details>
<summary>Examples</summary>
Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image/)
|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
</details>
#### FLUX.2: [/docs/en/Model_Details/FLUX2.md](/docs/en/Model_Details/FLUX2.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) model for inference. VRAM management is enabled, and the framework automatically loads model parameters based on available GPU memory. The model can run with as little as 10 GB of VRAM.
```python
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene."
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
image.save("image.jpg")
```
</details>
<details>
<summary>Examples</summary>
Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
</details>
#### Qwen-Image: [/docs/en/Model_Details/Qwen-Image.md](/docs/en/Model_Details/Qwen-Image.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8 GB of VRAM.
```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")
```
</details>
<details>
<summary>Model Lineage</summary>
```mermaid
graph LR;
Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;
Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;
Qwen/Qwen-Image-->EliGen-Series;
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;
DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;
Qwen/Qwen-Image-->Distill-Series;
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;
Qwen/Qwen-Image-->ControlNet-Series;
ControlNet-Series-->Blockwise-ControlNet-Series;
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;
ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;
Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;
```
</details>
<details>
<summary>Examples</summary>
Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/qwen_image/)
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
</details>
#### FLUX.1: [/docs/en/Model_Details/FLUX.md](/docs/en/Model_Details/FLUX.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8 GB of VRAM.
```python
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
vram_config = {
"offload_dtype": torch.float8_e4m3fn,
"offload_device": "cpu",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config),
],
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1,
)
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
image = pipe(prompt=prompt, seed=0)
image.save("image.jpg")
```
</details>
<details>
<summary>Model Lineage</summary>
```mermaid
graph LR;
FLUX.1-Series-->black-forest-labs/FLUX.1-dev;
FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;
FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;
black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;
FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;
FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;
FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;
black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;
black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;
black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;
black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;
Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;
Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;
```
</details>
<details>
<summary>Examples</summary>
Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/)
| Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)|
|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)|
</details>
### Video Synthesis
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
#### Wan: [/docs/en/Model_Details/Wan.md](/docs/en/Model_Details/Wan.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8 GB of VRAM.
```python
import torch
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
)
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
)
save_video(video, "video.mp4", fps=15, quality=5)
```
</details>
<details>
<summary>Model Lineage</summary>
```mermaid
graph LR;
Wan-Series-->Wan2.1-Series;
Wan-Series-->Wan2.2-Series;
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;
Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;
Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;
iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;
Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;
Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;
Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;
Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;
Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;
Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;
Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;
Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;
Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;
```
</details>
<details>
<summary>Examples</summary>
Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
| Model ID | Extra Args | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
</details>
## Innovative Achievements
DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.
<details>
<summary>AttriCtrl: Attribute Intensity Control for Image Generation Models</summary>
- Paper: [AttriCtrl: Fine-Grained Control of Aesthetic Attribute Intensity in Diffusion Models](https://arxiv.org/abs/2508.02151)
- Sample Code: [/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py](/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev)
|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
|-|-|-|-|-|
|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.5.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.7.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.9.jpg)|
</details>
<details>
<summary>AutoLoRA: Automated LoRA Retrieval and Fusion</summary>
- Paper: [AutoLoRA: Automatic LoRA Retrieval and Fine-Grained Gated Fusion for Text-to-Image Generation](https://arxiv.org/abs/2508.02107)
- Sample Code: [/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)
||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
|-|-|-|-|-|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_0.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_3_3.jpg)|
</details>
<details>
<summary>Nexus-Gen: Unified Architecture for Image Understanding, Generation, and Editing</summary>
- Detailed Page: https://github.com/modelscope/Nexus-Gen
- Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
- Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
- Online Experience: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
![](https://github.com/modelscope/Nexus-Gen/raw/main/assets/illustrations/gen_edit.jpg)
</details>
<details>
<summary>ArtAug: Aesthetic Enhancement for Image Generation Models</summary>
- Detailed Page: [./examples/ArtAug/](./examples/ArtAug/)
- Paper: [ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
- Online Experience: [ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
|-|-|
|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)|
|![image_1_base](https://github.com/user-attachments/assets/e1d5c505-b423-45fe-be01-25c2758f5417)|![image_1_enhance](https://github.com/user-attachments/assets/335908e3-d0bd-41c2-9d99-d10528a2d719)|
### Toon Shading
</details>
Render realistic videos in a flatten style and enable video editing features. [`examples/Diffutoon`](./examples/Diffutoon/)
<details>
<summary>EliGen: Precise Image Partition Control</summary>
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
- Sample Code: [/examples/flux/model_inference/FLUX.1-dev-EliGen.py](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
- Online Experience: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
- Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|Entity Control Region|Generated Image|
|-|-|
|![eligen_example_2_mask_0](https://github.com/user-attachments/assets/1c6d9445-5022-4d91-ad2e-dc05321883d1)|![eligen_example_2_0](https://github.com/user-attachments/assets/86739945-cb07-4a49-b3b3-3bb65c90d14f)|
</details>
<details>
<summary>ExVideo: Extended Training for Video Generation Models</summary>
- Project Page: [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
- Paper: [ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
- Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/ExVideo)
- Model: [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
</details>
<details>
<summary>Diffutoon: High-Resolution Anime-Style Video Rendering</summary>
- Project Page: [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
- Paper: [Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
- Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/Diffutoon)
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c
</details>
### Video Stylization
Video stylization without video models. [`examples/diffsynth`](./examples/diffsynth/)
<details>
<summary>DiffSynth: The Original Version of This Project</summary>
- Project Page: [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
- Paper: [DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
- Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/diffsynth)
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
### Chinese Models
Use Hunyuan-DiT to generate images with Chinese prompts. We also support LoRA fine-tuning of this model. [`examples/hunyuan_dit`](./examples/hunyuan_dit/)
Prompt: 少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感
|1024x1024|2048x2048 (highres-fix)|
|-|-|
|![image_1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/2b6528cf-a229-46e9-b7dd-4a9475b07308)|![image_2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/11d264ec-966b-45c9-9804-74b60428b866)|
Prompt: 一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉
|Without LoRA|With LoRA|
|-|-|
|![image_without_lora](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/1aa21de5-a992-4b66-b14f-caa44e08876e)|![image_with_lora](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/83a0a41a-691f-4610-8e7b-d8e17c50a282)|
## Usage (in WebUI)
```
python -m streamlit run DiffSynth_Studio.py
```
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954
</details>

791
README_zh.md Normal file
View File

@@ -0,0 +1,791 @@
# DiffSynth-Studio
<a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>
[![PyPI](https://img.shields.io/pypi/v/DiffSynth)](https://pypi.org/project/DiffSynth/)
[![license](https://img.shields.io/github/license/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
[![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)
[![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
[![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
[Switch to English](./README.md)
## 简介
欢迎来到 Diffusion 模型的魔法世界DiffSynth-Studio 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望以框架建设孵化技术创新,凝聚开源社区的力量,探索生成式模型技术的边界!
DiffSynth 目前包括两个开源项目:
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): 聚焦于激进的技术探索,面向学术界,提供更前沿的模型能力支持。
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): 聚焦于稳定的模型部署,面向工业界,提供更高的计算性能与更稳定的功能。
[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) 与 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 是魔搭社区 AIGC 专区的核心引擎,欢迎体验我们精心打造的产品化功能:
* 魔搭社区 AIGC 专区 (面向中国用户): https://modelscope.cn/aigc/home
* ModelScope Civision (for global users): https://modelscope.ai/civision/home
> DiffSynth-Studio 文档:[中文版](/docs/zh/README.md)、[English version](/docs/en/README.md)
我们相信,一个完善的开源代码框架能够降低技术探索的门槛,我们基于这个代码库搞出了不少[有意思的技术](#创新成果)。或许你也有许多天马行空的构想,借助 DiffSynth-Studio你可以快速实现这些想法。为此我们为开发者准备了详细的文档我们希望通过这些文档帮助开发者理解 Diffusion 模型的原理,更期待与你一同拓展技术的边界。
## 更新历史
> DiffSynth-Studio 经历了大版本更新,部分旧功能已停止维护,如需使用旧版功能,请切换到大版本更新前的[最后一个历史版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3)。
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责因此新功能的开发进展会比较缓慢issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
- **2026年1月27日** [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) 发布,我们的 [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) 模型同步发布,在[魔搭创空间](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L)可直接体验,详见[文档](/docs/zh/Model_Details/Z-Image.md)。
- **2026年1月19日** 新增对 [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 和 [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/FLUX2.md)和[示例代码](/examples/flux2/)现已可用。
- **2026年1月12日** 我们训练并开源了一个文本引导的图层拆分模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)),这一模型输入一张图与一段文本描述,模型会将图像中与文本描述相关的图层拆分出来。更多细节请阅读我们的 blog[中文版](https://modelscope.cn/learn/4938)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-layered-control))。
- **2025年12月24日** 我们基于 Qwen-Image-Edit-2511 训练了一个 In-Context Editing LoRA 模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)这个模型可以输入三张图图A、图B、图C模型会自行分析图A到图B的变化并将这样的变化应用到图C生成图D。更多细节请阅读我们的 blog[中文版](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora))。
- **2025年12月9日** 我们基于 DiffSynth-Studio 2.0 训练了一个疯狂的模型:[Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)Image to LoRA。这一模型以图像为输入以 LoRA 为输出。尽管这个版本的模型在泛化能力、细节保持能力等方面还有很大改进空间,我们将这些模型开源,以启发更多创新性的研究工作。更多细节,请参考我们的 [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l)。
- **2025年12月4日** DiffSynth-Studio 2.0 发布!众多新功能上线
- [文档](/docs/zh/README.md)上线:我们的文档还在持续优化更新中
- [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)模块升级,支持 Layer 级别的 Disk Offload同时释放内存与显存
- 新模型支持
- Z-Image Turbo: [模型](https://www.modelscope.ai/models/Tongyi-MAI/Z-Image-Turbo)、[文档](/docs/zh/Model_Details/Z-Image.md)、[代码](/examples/z_image/)
- FLUX.2-dev: [模型](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)、[文档](/docs/zh/Model_Details/FLUX2.md)、[代码](/examples/flux2/)
- 训练框架升级
- [拆分训练](/docs/zh/Training/Split_Training.md):支持自动化地将训练过程拆分为数据处理和训练两阶段(即使训练的是 ControlNet 或其他任意模型在数据处理阶段进行文本编码、VAE 编码等不需要梯度回传的计算,在训练阶段处理其他计算。速度更快,显存需求更少。
- [差分 LoRA 训练](/docs/zh/Training/Differential_LoRA.md):这是我们曾在 [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) 中使用的训练技术,目前已可用于任意模型的 LoRA 训练。
- [FP8 训练](/docs/zh/Training/FP8_Precision.md)FP8 在训练中支持应用到任意非训练模型,即梯度关闭或者梯度仅影响 LoRA 权重的模型。
<details>
<summary>更多</summary>
- **2025年11月4日** 支持了 [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) 模型,该模型基于 Wan 2.1 训练,支持根据参考视频生成相应的动作。
- **2025年10月30日** 支持了 [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) 模型,该模型支持文生视频、图生视频、视频续写。这个模型在本项目中沿用 Wan 的框架进行推理和训练。
- **2025年10月27日** 支持了 [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) 模型Wan 模型生态再添一员。
- **2025年9月23日** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) 发布!本模型由我们与淘天体验设计团队联合研发并开源。模型基于 Qwen-Image 构建,专为电商海报场景设计,支持精确的分区布局控制。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)。
- **2025年9月9日** 我们的训练框架支持了多种训练模式,目前已适配 Qwen-Image除标准 SFT 训练模式外,已支持 Direct Distill请参考[我们的示例代码](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)。这项功能是实验性的,我们将会继续完善已支持更全面的模型训练功能。
- **2025年8月28日** 我们支持了Wan2.2-S2V一个音频驱动的电影级视频生成模型。请参见[./examples/wanvideo/](./examples/wanvideo/)。
- **2025年8月21日** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) 发布!相比于 V1 版本,训练数据集变为 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset),因此,生成的图像更符合 Qwen-Image 本身的图像分布和风格。 请参考[我们的示例代码](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)。
- **2025年8月21日** 我们开源了 [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) 结构控制 LoRA 模型,采用 In Context 的技术路线,支持多种类别的结构控制条件,包括 canny, depth, lineart, softedge, normal, openpose。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)。
- **2025年8月20日** 我们开源了 [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) 模型,提升了 Qwen-Image-Edit 对低分辨率图像输入的编辑效果。请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)
- **2025年8月19日** 🔥 Qwen-Image-Edit 开源,欢迎图像编辑模型新成员!
- **2025年8月18日** 我们训练并开源了 Qwen-Image 的图像重绘 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)。
- **2025年8月15日** 我们开源了 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) 数据集。这是一个使用 Qwen-Image 模型生成的图像数据集,共包含 160,000 张`1024 x 1024`图像。它包括通用、英文文本渲染和中文文本渲染子集。我们为每张图像提供了图像描述、实体和结构控制图像的标注。开发者可以使用这个数据集来训练 Qwen-Image 模型的 ControlNet 和 EliGen 等模型,我们旨在通过开源推动技术发展!
- **2025年8月13日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)。
- **2025年8月12日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)。
- **2025年8月11日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA),沿用了与 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) 相同的训练流程,但模型结构修改为了 LoRA因此能够更好地与其他开源生态模型兼容。
- **2025年8月7日** 我们开源了 Qwen-Image 的实体控制 LoRA 模型 [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)。Qwen-Image-EliGen 能够实现实体级可控的文生图。技术细节请参见[论文](https://arxiv.org/abs/2501.01097)。训练数据集:[EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)。
- **2025年8月5日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full),实现了约 5 倍加速。
- **2025年8月4日** 🔥 Qwen-Image 开源,欢迎图像生成模型家族新成员!
- **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源,这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持,包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。
- **2025年7月28日** Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。
- **2025年7月11日** 我们提出 Nexus-Gen一个将大语言模型LLM的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。
- 论文: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
- Github 仓库: https://github.com/modelscope/Nexus-Gen
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
- 训练数据集: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
- 在线体验: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
- **2025年6月15日** ModelScope 官方评测框架 [EvalScope](https://github.com/modelscope/evalscope) 现已支持文生图生成评测。请参考[最佳实践](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html)指南进行尝试。
- **2025年3月25日** 我们的新开源项目 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 现已开源!专注于稳定的模型部署,面向工业界,提供更好的工程支持、更高的计算性能和更稳定的功能。
- **2025年3月31日** 我们支持 InfiniteYou一种用于 FLUX 的人脸特征保留方法。更多细节请参考 [./examples/InfiniteYou/](./examples/InfiniteYou/)。
- **2025年3月13日** 我们支持 HunyuanVideo-I2V即腾讯开源的 HunyuanVideo 的图像到视频生成版本。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
- **2025年2月25日** 我们支持 Wan-Video这是阿里巴巴开源的一系列最先进的视频合成模型。详见 [./examples/wanvideo/](./examples/wanvideo/)。
- **2025年2月17日** 我们支持 [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)!先进的视频合成模型!详见 [./examples/stepvideo](./examples/stepvideo/)。
- **2024年12月31日** 我们提出 EliGen一种用于精确实体级别控制的文本到图像生成的新框架并辅以修复融合管道将其能力扩展到图像修复任务。EliGen 可以无缝集成现有的社区模型,如 IP-Adapter 和 In-Context LoRA提升其通用性。更多详情请见 [./examples/EntityControl](./examples/EntityControl/)。
- 论文: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
- 在线体验: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
- 训练数据集: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
- **2024年12月19日** 我们为 HunyuanVideo 实现了高级显存管理,使得在 24GB 显存下可以生成分辨率为 129x720x1280 的视频,或在仅 6GB 显存下生成分辨率为 129x512x384 的视频。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
- **2024年12月18日** 我们提出 ArtAug一种通过合成-理解交互来改进文生图模型的方法。我们以 LoRA 格式为 FLUX.1-dev 训练了一个 ArtAug 增强模块。该模型将 Qwen2-VL-72B 的美学理解融入 FLUX.1-dev从而提升了生成图像的质量。
- 论文: https://arxiv.org/abs/2412.12888
- 示例: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
- 演示: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (即将上线)
- **2024年10月25日** 我们提供了广泛的 FLUX ControlNet 支持。该项目支持许多不同的 ControlNet 模型并且可以自由组合即使它们的结构不同。此外ControlNet 模型兼容高分辨率优化和分区控制技术,能够实现非常强大的可控图像生成。详见 [`./examples/ControlNet/`](./examples/ControlNet/)。
- **2024年10月8日** 我们发布了基于 CogVideoX-5B 和 ExVideo 的扩展 LoRA。您可以从 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 或 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 下载此模型。
- **2024年8月22日** 本项目现已支持 CogVideoX-5B。详见 [此处](/examples/video_synthesis/)。我们为这个文生视频模型提供了几个有趣的功能,包括:
- 文本到视频
- 视频编辑
- 自我超分
- 视频插帧
- **2024年8月22日** 我们实现了一个有趣的画笔功能,支持所有文生图模型。现在,您可以在 AI 的辅助下使用画笔创作惊艳的图像了!
- 在我们的 [WebUI](#usage-in-webui) 中使用它。
- **2024年8月21日** DiffSynth-Studio 现已支持 FLUX。
- 启用 CFG 和高分辨率修复以提升视觉质量。详见 [此处](/examples/image_synthesis/README.md)
- LoRA、ControlNet 和其他附加模型将很快推出。
- **2024年6月21日** 我们提出 ExVideo一种旨在增强视频生成模型能力的后训练微调技术。我们将 Stable Video Diffusion 进行了扩展,实现了长达 128 帧的长视频生成。
- [项目页面](https://ecnu-cilab.github.io/ExVideoProjectPage/)
- 源代码已在此仓库中发布。详见 [`examples/ExVideo`](./examples/ExVideo/)。
- 模型已发布于 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) 和 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)。
- 技术报告已发布于 [arXiv](https://arxiv.org/abs/2406.14130)。
- 您可以在此 [演示](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1) 中试用 ExVideo
- **2024年6月13日** DiffSynth Studio 已迁移至 ModelScope。开发团队也从“我”转变为“我们”。当然我仍会参与后续的开发和维护工作。
- **2024年1月29日** 我们提出 Diffutoon这是一个出色的卡通着色解决方案。
- [项目页面](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
- 源代码已在此项目中发布。
- 技术报告IJCAI 2024已发布于 [arXiv](https://arxiv.org/abs/2401.16224)。
- **2023年12月8日** 我们决定启动一个新项目,旨在释放扩散模型的潜力,尤其是在视频合成方面。该项目的开发工作正式开始。
- **2023年11月15日** 我们提出 FastBlend一种强大的视频去闪烁算法。
- sd-webui 扩展已发布于 [GitHub](https://github.com/Artiprocher/sd-webui-fastblend)。
- 演示视频已在 Bilibili 上展示,包含三个任务:
- [视频去闪烁](https://www.bilibili.com/video/BV1d94y1W7PE)
- [视频插帧](https://www.bilibili.com/video/BV1Lw411m71p)
- [图像驱动的视频渲染](https://www.bilibili.com/video/BV1RB4y1Z7LF)
- 技术报告已发布于 [arXiv](https://arxiv.org/abs/2311.09265)。
- 其他用户开发的非官方 ComfyUI 扩展已发布于 [GitHub](https://github.com/AInseven/ComfyUI-fastblend)。
- **2023年10月1日** 我们发布了该项目的早期版本,名为 FastSDXL。这是构建一个扩散引擎的初步尝试。
- 源代码已发布于 [GitHub](https://github.com/Artiprocher/FastSDXL)。
- FastSDXL 包含一个可训练的 OLSS 调度器,以提高效率。
- OLSS 的原始仓库位于 [此处](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler)。
- 技术报告CIKM 2023已发布于 [arXiv](https://arxiv.org/abs/2305.14677)。
- 演示视频已发布于 [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj)。
- 由于 OLSS 需要额外训练,我们未在本项目中实现它。
- **2023年8月29日** 我们提出 DiffSynth一个视频合成框架。
- [项目页面](https://ecnu-cilab.github.io/DiffSynth.github.io/)。
- 源代码已发布在 [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth)。
- 技术报告ECML PKDD 2024已发布于 [arXiv](https://arxiv.org/abs/2308.03463)。
</details>
## 安装
从源码安装(推荐):
```
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
更多安装方式,以及非 NVIDIA GPU 的安装,请参考[安装文档](/docs/zh/Pipeline_Usage/Setup.md)。
</details>
## 基础框架
DiffSynth-Studio 为主流 Diffusion 模型(包括 FLUX、Wan 等)重新设计了推理和训练流水线,能够实现高效的显存管理、灵活的模型训练。
<details>
<summary>环境变量配置</summary>
> 在进行模型推理和训练前,可通过[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md)配置模型下载源等。
>
> 本项目默认从魔搭社区下载模型。对于非中国区域的用户,可以通过以下配置从魔搭社区的国际站下载模型:
>
> ```python
> import os
> os.environ["MODELSCOPE_DOMAIN"] = "www.modelscope.ai"
> ```
>
> 如需从其他站点下载,请修改[环境变量 DIFFSYNTH_DOWNLOAD_SOURCE](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source)。
</details>
### 图像生成模型
![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d)
#### Z-Image[/docs/zh/Model_Details/Z-Image.md](/docs/zh/Model_Details/Z-Image.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) 模型并进行推理。FP8 精度量化会导致明显的图像质量劣化,因此不建议在 Z-Image Turbo 模型上开启任何量化,仅建议开启 CPU Offload最低 8G 显存即可运行。
```python
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
image.save("image.jpg")
```
</details>
<details>
<summary>示例代码</summary>
Z-Image 的示例代码位于:[/examples/z_image/](/examples/z_image/)
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
</details>
#### FLUX.2: [/docs/zh/Model_Details/FLUX2.md](/docs/zh/Model_Details/FLUX2.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 10G 显存即可运行。
```python
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene."
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
image.save("image.jpg")
```
</details>
<details>
<summary>示例代码</summary>
FLUX.2 的示例代码位于:[/examples/flux2/](/examples/flux2/)
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
</details>
#### Qwen-Image: [/docs/zh/Model_Details/Qwen-Image.md](/docs/zh/Model_Details/Qwen-Image.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")
```
</details>
<details>
<summary>模型血缘</summary>
```mermaid
graph LR;
Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;
Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;
Qwen/Qwen-Image-->EliGen-Series;
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;
DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;
Qwen/Qwen-Image-->Distill-Series;
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;
Qwen/Qwen-Image-->ControlNet-Series;
ControlNet-Series-->Blockwise-ControlNet-Series;
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;
ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;
Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;
```
</details>
<details>
<summary>示例代码</summary>
Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/)
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
</details>
#### FLUX.1: [/docs/zh/Model_Details/FLUX.md](/docs/zh/Model_Details/FLUX.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
```python
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
vram_config = {
"offload_dtype": torch.float8_e4m3fn,
"offload_device": "cpu",
"onload_dtype": torch.float8_e4m3fn,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config),
],
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1,
)
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
image = pipe(prompt=prompt, seed=0)
image.save("image.jpg")
```
</details>
<details>
<summary>模型血缘</summary>
```mermaid
graph LR;
FLUX.1-Series-->black-forest-labs/FLUX.1-dev;
FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;
FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;
black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;
FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;
FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;
FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;
black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;
black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;
black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;
black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;
Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;
Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;
```
</details>
<details>
<summary>示例代码</summary>
FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/)
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)|
|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)|
</details>
### 视频生成模型
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
#### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
```python
import torch
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
)
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
)
save_video(video, "video.mp4", fps=15, quality=5)
```
</details>
<details>
<summary>模型血缘</summary>
```mermaid
graph LR;
Wan-Series-->Wan2.1-Series;
Wan-Series-->Wan2.2-Series;
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;
Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;
Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;
iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;
Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;
Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;
Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;
Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;
Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;
Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;
Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;
Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;
Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;
```
</details>
<details>
<summary>示例代码</summary>
Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
</details>
## 创新成果
DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。
<details>
<summary>AttriCtrl: 图像生成模型的属性强度控制</summary>
- 论文:[AttriCtrl: Fine-Grained Control of Aesthetic Attribute Intensity in Diffusion Models
](https://arxiv.org/abs/2508.02151)
- 代码样例:[/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py](/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py)
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev)
|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
|-|-|-|-|-|
|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.5.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.7.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.9.jpg)|
</details>
<details>
<summary>AutoLoRA: 自动化的 LoRA 检索和融合</summary>
- 论文:[AutoLoRA: Automatic LoRA Retrieval and Fine-Grained Gated Fusion for Text-to-Image Generation
](https://arxiv.org/abs/2508.02107)
- 代码样例:[/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)
||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
|-|-|-|-|-|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_0.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_3_3.jpg)|
</details>
<details>
<summary>Nexus-Gen: 统一架构的图像理解、生成、编辑</summary>
- 详细页面https://github.com/modelscope/Nexus-Gen
- 论文:[Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
- 数据集:[ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
- 在线体验:[ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
![](https://github.com/modelscope/Nexus-Gen/raw/main/assets/illustrations/gen_edit.jpg)
</details>
<details>
<summary>ArtAug: 图像生成模型的美学提升</summary>
- 详细页面:[./examples/ArtAug/](./examples/ArtAug/)
- 论文:[ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
- 在线体验:[ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
|-|-|
|![image_1_base](https://github.com/user-attachments/assets/e1d5c505-b423-45fe-be01-25c2758f5417)|![image_1_enhance](https://github.com/user-attachments/assets/335908e3-d0bd-41c2-9d99-d10528a2d719)|
</details>
<details>
<summary>EliGen: 精准的图像分区控制</summary>
- 论文:[EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
- 代码样例:[/examples/flux/model_inference/FLUX.1-dev-EliGen.py](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
- 在线体验:[ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
- 数据集:[EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|实体控制区域|生成图像|
|-|-|
|![eligen_example_2_mask_0](https://github.com/user-attachments/assets/1c6d9445-5022-4d91-ad2e-dc05321883d1)|![eligen_example_2_0](https://github.com/user-attachments/assets/86739945-cb07-4a49-b3b3-3bb65c90d14f)|
</details>
<details>
<summary>ExVideo: 视频生成模型的扩展训练</summary>
- 项目页面:[Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
- 论文:[ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
- 代码样例:请前往[旧版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/ExVideo)查看
- 模型:[ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
</details>
<details>
<summary>Diffutoon: 高分辨率动漫风格视频渲染</summary>
- 项目页面:[Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
- 论文:[Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
- 代码样例:请前往[旧版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/Diffutoon)查看
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
</details>
<details>
<summary>DiffSynth: 本项目的初代版本</summary>
- 项目页面:[Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
- 论文:[DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
- 代码样例:请前往[旧版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/diffsynth)查看
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
</details>

View File

@@ -1,7 +0,0 @@
{
"cls_token": "[CLS]",
"mask_token": "[MASK]",
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"unk_token": "[UNK]"
}

View File

@@ -1,16 +0,0 @@
{
"cls_token": "[CLS]",
"do_basic_tokenize": true,
"do_lower_case": true,
"mask_token": "[MASK]",
"name_or_path": "hfl/chinese-roberta-wwm-ext",
"never_split": null,
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"special_tokens_map_file": "/home/chenweifeng/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext/snapshots/5c58d0b8ec1d9014354d691c538661bf00bfdb44/special_tokens_map.json",
"strip_accents": null,
"tokenize_chinese_chars": true,
"tokenizer_class": "BertTokenizer",
"unk_token": "[UNK]",
"model_max_length": 77
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,28 +0,0 @@
{
"_name_or_path": "/home/patrick/t5/mt5-xl",
"architectures": [
"MT5ForConditionalGeneration"
],
"d_ff": 5120,
"d_kv": 64,
"d_model": 2048,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"feed_forward_proj": "gated-gelu",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"layer_norm_epsilon": 1e-06,
"model_type": "mt5",
"num_decoder_layers": 24,
"num_heads": 32,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"tokenizer_class": "T5Tokenizer",
"transformers_version": "4.10.0.dev0",
"use_cache": true,
"vocab_size": 250112
}

View File

@@ -1 +0,0 @@
{"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>"}

View File

@@ -1 +0,0 @@
{"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "extra_ids": 0, "additional_special_tokens": null, "special_tokens_map_file": "", "tokenizer_file": null, "name_or_path": "google/mt5-small", "model_max_length": 256, "legacy": true}

File diff suppressed because it is too large Load Diff

View File

@@ -1,24 +0,0 @@
{
"bos_token": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"pad_token": "<|endoftext|>",
"unk_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}

View File

@@ -1,34 +0,0 @@
{
"add_prefix_space": false,
"bos_token": {
"__type": "AddedToken",
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"do_lower_case": true,
"eos_token": {
"__type": "AddedToken",
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"errors": "replace",
"model_max_length": 77,
"name_or_path": "openai/clip-vit-large-patch14",
"pad_token": "<|endoftext|>",
"special_tokens_map_file": "./special_tokens_map.json",
"tokenizer_class": "CLIPTokenizer",
"unk_token": {
"__type": "AddedToken",
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,24 +0,0 @@
{
"bos_token": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"pad_token": "!",
"unk_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}

View File

@@ -1,38 +0,0 @@
{
"add_prefix_space": false,
"added_tokens_decoder": {
"0": {
"content": "!",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"49406": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
},
"49407": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
}
},
"bos_token": "<|startoftext|>",
"clean_up_tokenization_spaces": true,
"do_lower_case": true,
"eos_token": "<|endoftext|>",
"errors": "replace",
"model_max_length": 77,
"pad_token": "!",
"tokenizer_class": "CLIPTokenizer",
"unk_token": "<|endoftext|>"
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1 @@
from .data import *
from .models import *
from .prompts import *
from .schedulers import *
from .pipelines import *
from .controlnets import *
from .core import *

View File

@@ -0,0 +1,2 @@
from .model_configs import MODEL_CONFIGS
from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS

View File

@@ -0,0 +1,594 @@
qwen_image_series = [
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors")
"model_hash": "0319a1cb19835fb510907dd3367c95ff",
"model_name": "qwen_image_dit",
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "8004730443f55db63092006dd9f7110e",
"model_name": "qwen_image_text_encoder",
"model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "ed4ea5824d55ec3107b09815e318123a",
"model_name": "qwen_image_vae",
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors")
"model_hash": "073bce9cf969e317e5662cd570c3e79c",
"model_name": "qwen_image_blockwise_controlnet",
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors")
"model_hash": "a9e54e480a628f0b956a688a81c33bab",
"model_name": "qwen_image_blockwise_controlnet",
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
"extra_kwargs": {"additional_in_dim": 4},
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
"model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8",
"model_name": "siglip2_image_encoder",
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors")
"model_hash": "5722b5c873720009de96422993b15682",
"model_name": "dinov3_image_encoder",
"model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
},
{
# Example:
"model_hash": "a166c33455cdbd89c0888a3645ca5c0f",
"model_name": "qwen_image_image2lora_coarse",
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
},
{
# Example:
"model_hash": "a5476e691767a4da6d3a6634a10f7408",
"model_name": "qwen_image_image2lora_fine",
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
"extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64}
},
{
# Example:
"model_hash": "0aad514690602ecaff932c701cb4b0bb",
"model_name": "qwen_image_image2lora_style",
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
"extra_kwargs": {"compress_dim": 64, "use_residual": False}
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "8dc8cda05de16c73afa755e2c1ce2839",
"model_name": "qwen_image_dit",
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
"extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True}
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "44b39ddc499e027cfb24f7878d7416b9",
"model_name": "qwen_image_vae",
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
"extra_kwargs": {"image_channels": 4}
},
]
wan_series = [
{
# Example: ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors")
"model_hash": "5ec04e02b42d2580483ad69f4e76346a",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth")
"model_hash": "9c8818c2cbea55eca56c7b447df170da",
"model_name": "wan_video_text_encoder",
"model_class": "diffsynth.models.wan_video_text_encoder.WanTextEncoder",
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth")
"model_hash": "ccc42284ea13e1ad04693284c7a09be6",
"model_name": "wan_video_vae",
"model_class": "diffsynth.models.wan_video_vae.WanVideoVAE",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
},
{
# Example: ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors")
"model_hash": "8b27900f680d7251ce44e2dc8ae1ffef",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel",
},
{
# Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "5f90e66a0672219f12d9a626c8c21f61",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTFromDiffusers"
},
{
# Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "5f90e66a0672219f12d9a626c8c21f61",
"model_name": "wan_video_vap",
"model_class": "diffsynth.models.wan_video_mot.MotWanModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_mot.WanVideoMotStateDictConverter"
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
"model_hash": "5941c53e207d62f20f9025686193c40b",
"model_name": "wan_video_image_encoder",
"model_class": "diffsynth.models.wan_video_image_encoder.WanImageEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_image_encoder.WanImageEncoderStateDictConverter"
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors")
"model_hash": "dbd5ec76bbf977983f972c151d545389",
"model_name": "wan_video_motion_controller",
"model_class": "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel",
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "9269f8db9040a9d860eaca435be61814",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "3ef3b1f8e1dab83d5b71fd7b617f859f",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_image_pos_emb': True}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "349723183fc063b2bfc10bb2835cf677",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "6d6ccde6845b95ad9114ab993d917893",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "efa44cddf936c70abd0ea28b6cbe946c",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "6bfcfb3b342cb286ce886889d519a77e",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "ac6a5aa74f4a0aab6f64eb9a72f19901",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "70ddad9d3a133785da5ea371aae09504",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': True}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "b61c605c2adbd23124d152ed28e049ae",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "26bde73488a92e64cc20b0a7485b9e5b",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True}
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "aafcfd9672c3a2456dc46e1cb6e52c70",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "a61453409b67cd3246cf0c3bebad47ba",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "a61453409b67cd3246cf0c3bebad47ba",
"model_name": "wan_video_vace",
"model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "7a513e1f257a861512b1afd387a8ecd9",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "7a513e1f257a861512b1afd387a8ecd9",
"model_name": "wan_video_vace",
"model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
"extra_kwargs": {'vace_layers': (0, 5, 10, 15, 20, 25, 30, 35), 'vace_in_dim': 96, 'patch_size': (1, 2, 2), 'has_image_input': False, 'dim': 5120, 'num_heads': 40, 'ffn_dim': 13824, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter"
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
"model_name": "wan_video_animate_adapter",
"model_class": "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_animate_adapter.WanAnimateAdapterStateDictConverter"
},
{
# Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
"model_hash": "47dbeab5e560db3180adf51dc0232fb1",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24, 'require_clip_embedding': False}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
"model_hash": "2267d489f0ceb9f21836532952852ee5",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 52, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True, 'require_clip_embedding': False},
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
"model_hash": "5b013604280dd715f8457c6ed6d6a626",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'require_clip_embedding': False}
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "966cffdcc52f9c46c391768b27637614",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit_s2v.WanS2VModel",
"extra_kwargs": {'dim': 5120, 'in_dim': 16, 'ffn_dim': 13824, 'out_dim': 16, 'text_dim': 4096, 'freq_dim': 256, 'eps': 1e-06, 'patch_size': (1, 2, 2), 'num_heads': 40, 'num_layers': 40, 'cond_dim': 16, 'audio_dim': 1024, 'num_audio_token': 4}
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "1f5ab7703c6fc803fdded85ff040c316",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True}
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth")
"model_hash": "e1de6c02cdac79f8b739f4d3698cd216",
"model_name": "wan_video_vae",
"model_class": "diffsynth.models.wan_video_vae.WanVideoVAE38",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors")
"model_hash": "06be60f3a4526586d8431cd038a71486",
"model_name": "wans2v_audio_encoder",
"model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
},
]
flux_series = [
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
"model_hash": "a29710fea6dddb0314663ee823598e50",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
{
# Supported due to historical reasons.
"model_hash": "605c56eab23e9e2af863ad8f0813a25d",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
"model_name": "flux_text_encoder_clip",
"model_class": "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_clip.FluxTextEncoderClipStateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors")
"model_hash": "22540b49eaedbc2f2784b2091a234c7c",
"model_name": "flux_text_encoder_t5",
"model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
"model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
"model_name": "flux_vae_encoder",
"model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
"model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
"model_name": "flux_vae_decoder",
"model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors")
"model_hash": "d02f41c13549fa5093d3521f62a5570a",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"extra_kwargs": {'input_dim': 196, 'num_blocks': 8},
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
"model_hash": "0629116fce1472503a66992f96f3eb1a",
"model_name": "flux_value_controller",
"model_class": "diffsynth.models.flux_value_control.SingleValueEncoder",
},
{
# Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors")
"model_hash": "52357cb26250681367488a8954c271e8",
"model_name": "flux_controlnet",
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
"extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4},
},
{
# Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors")
"model_hash": "78d18b9101345ff695f312e7e62538c0",
"model_name": "flux_controlnet",
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
"extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}},
},
{
# Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors")
"model_hash": "b001c89139b5f053c715fe772362dd2a",
"model_name": "flux_controlnet",
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
"extra_kwargs": {"num_single_blocks": 0},
},
{
# Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin")
"model_hash": "c07c0f04f5ff55e86b4e937c7a40d481",
"model_name": "infiniteyou_image_projector",
"model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter",
},
{
# Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors")
"model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16",
"model_name": "flux_controlnet",
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
"extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10},
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors")
"model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab",
"model_name": "flux_lora_encoder",
"model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors")
"model_hash": "30143afb2dea73d1ac580e0787628f8c",
"model_name": "flux_lora_patcher",
"model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors")
"model_hash": "2bd19e845116e4f875a0a048e27fc219",
"model_name": "nexus_gen_llm",
"model_class": "diffsynth.models.nexus_gen.NexusGenAutoregressiveModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
"model_hash": "63c969fd37cce769a90aa781fbff5f81",
"model_name": "nexus_gen_editing_adapter",
"model_class": "diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger",
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
"model_hash": "63c969fd37cce769a90aa781fbff5f81",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
"model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
"model_name": "nexus_gen_generation_adapter",
"model_class": "diffsynth.models.nexus_gen_projector.NexusGenAdapter",
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
"model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin")
"model_hash": "4daaa66cc656a8fe369908693dad0a35",
"model_name": "flux_ipadapter",
"model_class": "diffsynth.models.flux_ipadapter.FluxIpAdapter",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter",
},
{
# Example: ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors")
"model_hash": "04d8c1e20a1f1b25f7434f111992a33f",
"model_name": "siglip_vision_model",
"model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter",
},
{
# Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
"model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
"model_name": "step1x_connector",
"model_class": "diffsynth.models.step1x_connector.Qwen2Connector",
"state_dict_converter": "diffsynth.utils.state_dict_converters.step1x_connector.Qwen2ConnectorStateDictConverter",
},
{
# Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
"model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
"extra_kwargs": {"disable_guidance_embedder": True},
},
{
# Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors")
"model_hash": "3394f306c4cbf04334b712bf5aaed95f",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
]
flux2_series = [
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors")
"model_hash": "28fca3d8e5bf2a2d1271748a773f6757",
"model_name": "flux2_text_encoder",
"model_class": "diffsynth.models.flux2_text_encoder.Flux2TextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux2_text_encoder.Flux2TextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors")
"model_hash": "d38e1d5c5aec3b0a11e79327ac6e3b0f",
"model_name": "flux2_dit",
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "c54288e3ee12ca215898840682337b95",
"model_name": "flux2_vae",
"model_class": "diffsynth.models.flux2_vae.Flux2VAE",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
"model_hash": "3bde7b817fec8143028b6825a63180df",
"model_name": "flux2_dit",
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors")
"model_hash": "9195f3ea256fcd0ae6d929c203470754",
"model_name": "z_image_text_encoder",
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
"extra_kwargs": {"model_size": "8B"},
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors")
"model_hash": "39c6fc48f07bebecedbbaa971ff466c8",
"model_name": "flux2_dit",
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
},
]
z_image_series = [
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
"model_hash": "fc3a8a1247fe185ce116ccbe0e426c28",
"model_name": "z_image_dit",
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors")
"model_hash": "0f050f62a88876fea6eae0a18dac5a2e",
"model_name": "z_image_text_encoder",
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
"model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
"model_name": "flux_vae_encoder",
"model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers",
"extra_kwargs": {"use_conv_attention": False},
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
"model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
"model_name": "flux_vae_decoder",
"model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
"extra_kwargs": {"use_conv_attention": False},
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors")
"model_hash": "aa3563718e5c3ecde3dfbb020ca61180",
"model_name": "z_image_dit",
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
"extra_kwargs": {"siglip_feat_dim": 1152},
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors")
"model_hash": "89d48e420f45cff95115a9f3e698d44a",
"model_name": "siglip_vision_model_428m",
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
},
{
# Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
"model_hash": "1677708d40029ab380a95f6c731a57d7",
"model_name": "z_image_controlnet",
"model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
},
{
# Example: ???
"model_hash": "9510cb8cd1dd34ee0e4f111c24905510",
"model_name": "z_image_image2lora_style",
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
"extra_kwargs": {"compress_dim": 128},
},
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series

View File

@@ -0,0 +1,213 @@
flux_general_vram_config = {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.general_modules.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.flux_lora_encoder.LoRALayerBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.flux_lora_patcher.LoraMerger": "diffsynth.core.vram.layers.AutoWrappedModule",
}
VRAM_MANAGEMENT_MODULE_MAPS = {
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionPatchEmbed": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.qwen_image_vae.QwenImageVAE": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.qwen_image_vae.QwenImageRMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.qwen_image_controlnet.BlockWiseControlBlock": {
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": {
"transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": {
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": {
"diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_animate_adapter.ConvLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_animate_adapter.FusedLeakyReLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_animate_adapter.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_dit_s2v.WanS2VModel": {
"diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit_s2v.WanS2VDiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit_s2v.CausalAudioEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_dit.WanModel": {
"diffsynth.models.wan_video_dit.MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
"diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_image_encoder.WanImageEncoder": {
"diffsynth.models.wan_video_image_encoder.VisionTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_mot.MotWanModel": {
"diffsynth.models.wan_video_mot.MotWanAttentionBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_motion_controller.WanMotionControllerModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.wan_video_text_encoder.WanTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_text_encoder.T5RelativeEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_text_encoder.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_vace.VaceWanModel": {
"diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_vae.WanVideoVAE": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_vae.WanVideoVAE38": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wav2vec.WanS2VAudioEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.longcat_video_dit.RMSNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.longcat_video_dit.LayerNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux_dit.FluxDiT": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.flux_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip": flux_general_vram_config,
"diffsynth.models.flux_vae.FluxVAEEncoder": flux_general_vram_config,
"diffsynth.models.flux_vae.FluxVAEDecoder": flux_general_vram_config,
"diffsynth.models.flux_controlnet.FluxControlNet": flux_general_vram_config,
"diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector": flux_general_vram_config,
"diffsynth.models.flux_ipadapter.FluxIpAdapter": flux_general_vram_config,
"diffsynth.models.flux_lora_patcher.FluxLoraPatcher": flux_general_vram_config,
"diffsynth.models.step1x_connector.Qwen2Connector": flux_general_vram_config,
"diffsynth.models.flux_lora_encoder.FluxLoRAEncoder": flux_general_vram_config,
"diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.t5.modeling_t5.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.t5.modeling_t5.T5DenseActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.t5.modeling_t5.T5DenseGatedActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M": {
"transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.siglip.modeling_siglip.SiglipEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.MultiheadAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux2_dit.Flux2DiT": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux2_text_encoder.Flux2TextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.mistral.modeling_mistral.MistralRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux2_vae.Flux2VAE": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.z_image_text_encoder.ZImageTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.z_image_dit.ZImageDiT": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.z_image_controlnet.ZImageControlNet": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": {
"transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
}

View File

@@ -1,2 +0,0 @@
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
from .processors import Annotator

View File

@@ -1,53 +0,0 @@
import torch
import numpy as np
from .processors import Processor_id
class ControlNetConfigUnit:
def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
self.processor_id = processor_id
self.model_path = model_path
self.scale = scale
class ControlNetUnit:
def __init__(self, processor, model, scale=1.0):
self.processor = processor
self.model = model
self.scale = scale
class MultiControlNetManager:
def __init__(self, controlnet_units=[]):
self.processors = [unit.processor for unit in controlnet_units]
self.models = [unit.model for unit in controlnet_units]
self.scales = [unit.scale for unit in controlnet_units]
def process_image(self, image, processor_id=None):
if processor_id is None:
processed_image = [processor(image) for processor in self.processors]
else:
processed_image = [self.processors[processor_id](image)]
processed_image = torch.concat([
torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
for image_ in processed_image
], dim=0)
return processed_image
def __call__(
self,
sample, timestep, encoder_hidden_states, conditionings,
tiled=False, tile_size=64, tile_stride=32
):
res_stack = None
for conditioning, model, scale in zip(conditionings, self.models, self.scales):
res_stack_ = model(
sample, timestep, encoder_hidden_states, conditioning,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
res_stack_ = [res * scale for res in res_stack_]
if res_stack is None:
res_stack = res_stack_
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
return res_stack

View File

@@ -1,51 +0,0 @@
from typing_extensions import Literal, TypeAlias
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from controlnet_aux.processor import (
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
)
Processor_id: TypeAlias = Literal[
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
]
class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
if processor_id == "canny":
self.processor = CannyDetector()
elif processor_id == "depth":
self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
elif processor_id == "softedge":
self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
elif processor_id == "lineart":
self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
elif processor_id == "lineart_anime":
self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
elif processor_id == "openpose":
self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
elif processor_id == "tile":
self.processor = None
else:
raise ValueError(f"Unsupported processor_id: {processor_id}")
self.processor_id = processor_id
self.detect_resolution = detect_resolution
def __call__(self, image):
width, height = image.size
if self.processor_id == "openpose":
kwargs = {
"include_body": True,
"include_hand": True,
"include_face": True
}
else:
kwargs = {}
if self.processor is not None:
detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
image = image.resize((width, height))
return image

View File

@@ -0,0 +1,6 @@
from .attention import *
from .data import *
from .gradient import *
from .loader import *
from .vram import *
from .device import *

View File

@@ -0,0 +1 @@
from .attention import attention_forward

View File

@@ -0,0 +1,121 @@
import torch, os
from einops import rearrange
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
try:
from sageattention import sageattn
SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
SAGE_ATTN_AVAILABLE = False
try:
import xformers.ops as xops
XFORMERS_AVAILABLE = True
except ModuleNotFoundError:
XFORMERS_AVAILABLE = False
def initialize_attention_priority():
if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None:
return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower()
elif FLASH_ATTN_3_AVAILABLE:
return "flash_attention_3"
elif FLASH_ATTN_2_AVAILABLE:
return "flash_attention_2"
elif SAGE_ATTN_AVAILABLE:
return "sage_attention"
elif XFORMERS_AVAILABLE:
return "xformers"
else:
return "torch"
ATTENTION_IMPLEMENTATION = initialize_attention_priority()
def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
dims = {} if dims is None else dims
if q_pattern != required_in_pattern:
q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims)
if k_pattern != required_in_pattern:
k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
if v_pattern != required_in_pattern:
v = rearrange(v, f"{q_pattern} -> {required_in_pattern}", **dims)
return q, k, v
def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
dims = {} if dims is None else dims
if out_pattern != required_out_pattern:
out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims)
return out
def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
required_in_pattern, required_out_pattern= "b n s d", "b n s d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale)
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale)
if isinstance(out, tuple):
out = out[0]
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
required_in_pattern, required_out_pattern= "b n s d", "b n s d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = sageattn(q, k, v, sm_scale=scale)
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = xops.memory_efficient_attention(q, k, v, scale=scale)
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
if compatibility_mode or (attn_mask is not None):
return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)
else:
if ATTENTION_IMPLEMENTATION == "flash_attention_3":
return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
elif ATTENTION_IMPLEMENTATION == "flash_attention_2":
return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
elif ATTENTION_IMPLEMENTATION == "sage_attention":
return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
elif ATTENTION_IMPLEMENTATION == "xformers":
return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
else:
return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)

View File

@@ -0,0 +1 @@
from .unified_dataset import UnifiedDataset

View File

@@ -0,0 +1,220 @@
import torch, torchvision, imageio, os
import imageio.v3 as iio
from PIL import Image
class DataProcessingPipeline:
def __init__(self, operators=None):
self.operators: list[DataProcessingOperator] = [] if operators is None else operators
def __call__(self, data):
for operator in self.operators:
data = operator(data)
return data
def __rshift__(self, pipe):
if isinstance(pipe, DataProcessingOperator):
pipe = DataProcessingPipeline([pipe])
return DataProcessingPipeline(self.operators + pipe.operators)
class DataProcessingOperator:
def __call__(self, data):
raise NotImplementedError("DataProcessingOperator cannot be called directly.")
def __rshift__(self, pipe):
if isinstance(pipe, DataProcessingOperator):
pipe = DataProcessingPipeline([pipe])
return DataProcessingPipeline([self]).__rshift__(pipe)
class DataProcessingOperatorRaw(DataProcessingOperator):
def __call__(self, data):
return data
class ToInt(DataProcessingOperator):
def __call__(self, data):
return int(data)
class ToFloat(DataProcessingOperator):
def __call__(self, data):
return float(data)
class ToStr(DataProcessingOperator):
def __init__(self, none_value=""):
self.none_value = none_value
def __call__(self, data):
if data is None: data = self.none_value
return str(data)
class LoadImage(DataProcessingOperator):
def __init__(self, convert_RGB=True, convert_RGBA=False):
self.convert_RGB = convert_RGB
self.convert_RGBA = convert_RGBA
def __call__(self, data: str):
image = Image.open(data)
if self.convert_RGB: image = image.convert("RGB")
if self.convert_RGBA: image = image.convert("RGBA")
return image
class ImageCropAndResize(DataProcessingOperator):
def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1):
self.height = height
self.width = width
self.max_pixels = max_pixels
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
def crop_and_resize(self, image, target_height, target_width):
width, height = image.size
scale = max(target_width / width, target_height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
return image
def get_height_width(self, image):
if self.height is None or self.width is None:
width, height = image.size
if width * height > self.max_pixels:
scale = (width * height / self.max_pixels) ** 0.5
height, width = int(height / scale), int(width / scale)
height = height // self.height_division_factor * self.height_division_factor
width = width // self.width_division_factor * self.width_division_factor
else:
height, width = self.height, self.width
return height, width
def __call__(self, data: Image.Image):
image = self.crop_and_resize(data, *self.get_height_width(data))
return image
class ToList(DataProcessingOperator):
def __call__(self, data):
return [data]
class LoadVideo(DataProcessingOperator):
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
self.num_frames = num_frames
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
# frame_processor is build in the video loader for high efficiency.
self.frame_processor = frame_processor
def get_num_frames(self, reader):
num_frames = self.num_frames
if int(reader.count_frames()) < num_frames:
num_frames = int(reader.count_frames())
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
num_frames -= 1
return num_frames
def __call__(self, data: str):
reader = imageio.get_reader(data)
num_frames = self.get_num_frames(reader)
frames = []
for frame_id in range(num_frames):
frame = reader.get_data(frame_id)
frame = Image.fromarray(frame)
frame = self.frame_processor(frame)
frames.append(frame)
reader.close()
return frames
class SequencialProcess(DataProcessingOperator):
def __init__(self, operator=lambda x: x):
self.operator = operator
def __call__(self, data):
return [self.operator(i) for i in data]
class LoadGIF(DataProcessingOperator):
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
self.num_frames = num_frames
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
# frame_processor is build in the video loader for high efficiency.
self.frame_processor = frame_processor
def get_num_frames(self, path):
num_frames = self.num_frames
images = iio.imread(path, mode="RGB")
if len(images) < num_frames:
num_frames = len(images)
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
num_frames -= 1
return num_frames
def __call__(self, data: str):
num_frames = self.get_num_frames(data)
frames = []
images = iio.imread(data, mode="RGB")
for img in images:
frame = Image.fromarray(img)
frame = self.frame_processor(frame)
frames.append(frame)
if len(frames) >= num_frames:
break
return frames
class RouteByExtensionName(DataProcessingOperator):
def __init__(self, operator_map):
self.operator_map = operator_map
def __call__(self, data: str):
file_ext_name = data.split(".")[-1].lower()
for ext_names, operator in self.operator_map:
if ext_names is None or file_ext_name in ext_names:
return operator(data)
raise ValueError(f"Unsupported file: {data}")
class RouteByType(DataProcessingOperator):
def __init__(self, operator_map):
self.operator_map = operator_map
def __call__(self, data):
for dtype, operator in self.operator_map:
if dtype is None or isinstance(data, dtype):
return operator(data)
raise ValueError(f"Unsupported data: {data}")
class LoadTorchPickle(DataProcessingOperator):
def __init__(self, map_location="cpu"):
self.map_location = map_location
def __call__(self, data):
return torch.load(data, map_location=self.map_location, weights_only=False)
class ToAbsolutePath(DataProcessingOperator):
def __init__(self, base_path=""):
self.base_path = base_path
def __call__(self, data):
return os.path.join(self.base_path, data)
class LoadAudio(DataProcessingOperator):
def __init__(self, sr=16000):
self.sr = sr
def __call__(self, data: str):
import librosa
input_audio, sample_rate = librosa.load(data, sr=self.sr)
return input_audio

View File

@@ -0,0 +1,116 @@
from .operators import *
import torch, json, pandas
class UnifiedDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path=None, metadata_path=None,
repeat=1,
data_file_keys=tuple(),
main_data_operator=lambda x: x,
special_operator_map=None,
max_data_items=None,
):
self.base_path = base_path
self.metadata_path = metadata_path
self.repeat = repeat
self.data_file_keys = data_file_keys
self.main_data_operator = main_data_operator
self.cached_data_operator = LoadTorchPickle()
self.special_operator_map = {} if special_operator_map is None else special_operator_map
self.max_data_items = max_data_items
self.data = []
self.cached_data = []
self.load_from_cache = metadata_path is None
self.load_metadata(metadata_path)
@staticmethod
def default_image_operator(
base_path="",
max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
):
return RouteByType(operator_map=[
(str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
(list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
])
@staticmethod
def default_video_operator(
base_path="",
max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
num_frames=81, time_division_factor=4, time_division_remainder=1,
):
return RouteByType(operator_map=[
(str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
(("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
(("gif",), LoadGIF(
num_frames, time_division_factor, time_division_remainder,
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
)),
(("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
num_frames, time_division_factor, time_division_remainder,
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
)),
])),
])
def search_for_cached_data_files(self, path):
for file_name in os.listdir(path):
subpath = os.path.join(path, file_name)
if os.path.isdir(subpath):
self.search_for_cached_data_files(subpath)
elif subpath.endswith(".pth"):
self.cached_data.append(subpath)
def load_metadata(self, metadata_path):
if metadata_path is None:
print("No metadata_path. Searching for cached data files.")
self.search_for_cached_data_files(self.base_path)
print(f"{len(self.cached_data)} cached data files found.")
elif metadata_path.endswith(".json"):
with open(metadata_path, "r") as f:
metadata = json.load(f)
self.data = metadata
elif metadata_path.endswith(".jsonl"):
metadata = []
with open(metadata_path, 'r') as f:
for line in f:
metadata.append(json.loads(line.strip()))
self.data = metadata
else:
metadata = pandas.read_csv(metadata_path)
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
def __getitem__(self, data_id):
if self.load_from_cache:
data = self.cached_data[data_id % len(self.cached_data)]
data = self.cached_data_operator(data)
else:
data = self.data[data_id % len(self.data)].copy()
for key in self.data_file_keys:
if key in data:
if key in self.special_operator_map:
data[key] = self.special_operator_map[key](data[key])
elif key in self.data_file_keys:
data[key] = self.main_data_operator(data[key])
return data
def __len__(self):
if self.max_data_items is not None:
return self.max_data_items
elif self.load_from_cache:
return len(self.cached_data) * self.repeat
else:
return len(self.data) * self.repeat
def check_data_equal(self, data1, data2):
# Debug only
if len(data1) != len(data2):
return False
for k in data1:
if data1[k] != data2[k]:
return False
return True

View File

@@ -0,0 +1,2 @@
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE

View File

@@ -0,0 +1,107 @@
import importlib
import torch
from typing import Any
def is_torch_npu_available():
return importlib.util.find_spec("torch_npu") is not None
IS_CUDA_AVAILABLE = torch.cuda.is_available()
IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available()
if IS_NPU_AVAILABLE:
import torch_npu
torch.npu.config.allow_internal_format = False
def get_device_type() -> str:
"""Get device type based on current machine, currently only support CPU, CUDA, NPU."""
if IS_CUDA_AVAILABLE:
device = "cuda"
elif IS_NPU_AVAILABLE:
device = "npu"
else:
device = "cpu"
return device
def get_torch_device() -> Any:
"""Get torch attribute based on device type, e.g. torch.cuda or torch.npu"""
device_name = get_device_type()
try:
return getattr(torch, device_name)
except AttributeError:
print(f"Device namespace '{device_name}' not found in torch, try to load 'torch.cuda'.")
return torch.cuda
def get_device_id() -> int:
"""Get current device id based on device type."""
return get_torch_device().current_device()
def get_device_name() -> str:
"""Get current device name based on device type."""
return f"{get_device_type()}:{get_device_id()}"
def synchronize() -> None:
"""Execute torch synchronize operation."""
get_torch_device().synchronize()
def empty_cache() -> None:
"""Execute torch empty cache operation."""
get_torch_device().empty_cache()
def get_nccl_backend() -> str:
"""Return distributed communication backend type based on device type."""
if IS_CUDA_AVAILABLE:
return "nccl"
elif IS_NPU_AVAILABLE:
return "hccl"
else:
raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")
def enable_high_precision_for_bf16():
"""
Set high accumulation dtype for matmul and reduction.
"""
if IS_CUDA_AVAILABLE:
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
if IS_NPU_AVAILABLE:
torch.npu.matmul.allow_tf32 = False
torch.npu.matmul.allow_bf16_reduced_precision_reduction = False
def parse_device_type(device):
if isinstance(device, str):
if device.startswith("cuda"):
return "cuda"
elif device.startswith("npu"):
return "npu"
else:
return "cpu"
elif isinstance(device, torch.device):
return device.type
def parse_nccl_backend(device_type):
if device_type == "cuda":
return "nccl"
elif device_type == "npu":
return "hccl"
else:
raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.")
def get_available_device_type():
return get_device_type()

View File

@@ -0,0 +1 @@
from .gradient_checkpoint import gradient_checkpoint_forward

View File

@@ -0,0 +1,34 @@
import torch
def create_custom_forward(module):
def custom_forward(*inputs, **kwargs):
return module(*inputs, **kwargs)
return custom_forward
def gradient_checkpoint_forward(
model,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
*args,
**kwargs,
):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
model_output = torch.utils.checkpoint.checkpoint(
create_custom_forward(model),
*args,
**kwargs,
use_reentrant=False,
)
elif use_gradient_checkpointing:
model_output = torch.utils.checkpoint.checkpoint(
create_custom_forward(model),
*args,
**kwargs,
use_reentrant=False,
)
else:
model_output = model(*args, **kwargs)
return model_output

View File

@@ -0,0 +1,3 @@
from .file import load_state_dict, hash_state_dict_keys, hash_model_file
from .model import load_model, load_model_with_disk_offload
from .config import ModelConfig

View File

@@ -0,0 +1,118 @@
import torch, glob, os
from typing import Optional, Union
from dataclasses import dataclass
from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download
from typing import Optional
@dataclass
class ModelConfig:
path: Union[str, list[str]] = None
model_id: str = None
origin_file_pattern: Union[str, list[str]] = None
download_source: str = None
local_model_path: str = None
skip_download: bool = None
offload_device: Optional[Union[str, torch.device]] = None
offload_dtype: Optional[torch.dtype] = None
onload_device: Optional[Union[str, torch.device]] = None
onload_dtype: Optional[torch.dtype] = None
preparing_device: Optional[Union[str, torch.device]] = None
preparing_dtype: Optional[torch.dtype] = None
computation_device: Optional[Union[str, torch.device]] = None
computation_dtype: Optional[torch.dtype] = None
clear_parameters: bool = False
def check_input(self):
if self.path is None and self.model_id is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
def parse_original_file_pattern(self):
if self.origin_file_pattern is None or self.origin_file_pattern == "":
return "*"
elif self.origin_file_pattern.endswith("/"):
return self.origin_file_pattern + "*"
else:
return self.origin_file_pattern
def parse_download_source(self):
if self.download_source is None:
if os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') is not None:
return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE')
else:
return "modelscope"
else:
return self.download_source
def parse_skip_download(self):
if self.skip_download is None:
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None:
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true":
return True
elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false":
return False
else:
return False
else:
return self.skip_download
def download(self):
origin_file_pattern = self.parse_original_file_pattern()
downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
download_source = self.parse_download_source()
if download_source.lower() == "modelscope":
snapshot_download(
self.model_id,
local_dir=os.path.join(self.local_model_path, self.model_id),
allow_file_pattern=origin_file_pattern,
ignore_file_pattern=downloaded_files,
local_files_only=False
)
elif download_source.lower() == "huggingface":
hf_snapshot_download(
self.model_id,
local_dir=os.path.join(self.local_model_path, self.model_id),
allow_patterns=origin_file_pattern,
ignore_patterns=downloaded_files,
local_files_only=False
)
else:
raise ValueError("`download_source` should be `modelscope` or `huggingface`.")
def require_downloading(self):
if self.path is not None:
return False
skip_download = self.parse_skip_download()
return not skip_download
def reset_local_model_path(self):
if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:
self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH')
elif self.local_model_path is None:
self.local_model_path = "./models"
def download_if_necessary(self):
self.check_input()
self.reset_local_model_path()
if self.require_downloading():
self.download()
if self.path is None:
if self.origin_file_pattern is None or self.origin_file_pattern == "":
self.path = os.path.join(self.local_model_path, self.model_id)
else:
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
if isinstance(self.path, list) and len(self.path) == 1:
self.path = self.path[0]
def vram_config(self):
return {
"offload_device": self.offload_device,
"offload_dtype": self.offload_dtype,
"onload_device": self.onload_device,
"onload_dtype": self.onload_dtype,
"preparing_device": self.preparing_device,
"preparing_dtype": self.preparing_dtype,
"computation_device": self.computation_device,
"computation_dtype": self.computation_dtype,
}

View File

@@ -0,0 +1,121 @@
from safetensors import safe_open
import torch, hashlib
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
if isinstance(file_path, list):
state_dict = {}
for file_path_ in file_path:
state_dict.update(load_state_dict(file_path_, torch_dtype, device))
return state_dict
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
state_dict = {}
with safe_open(file_path, framework="pt", device=str(device)) as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
state_dict = torch.load(file_path, map_location=device, weights_only=True)
if len(state_dict) == 1:
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
elif "module" in state_dict:
state_dict = state_dict["module"]
elif "model_state" in state_dict:
state_dict = state_dict["model_state"]
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, torch.Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
def load_keys_dict(file_path):
if isinstance(file_path, list):
state_dict = {}
for file_path_ in file_path:
state_dict.update(load_keys_dict(file_path_))
return state_dict
if file_path.endswith(".safetensors"):
return load_keys_dict_from_safetensors(file_path)
else:
return load_keys_dict_from_bin(file_path)
def load_keys_dict_from_safetensors(file_path):
keys_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
keys_dict[k] = f.get_slice(k).get_shape()
return keys_dict
def convert_state_dict_to_keys_dict(state_dict):
keys_dict = {}
for k, v in state_dict.items():
if isinstance(v, torch.Tensor):
keys_dict[k] = list(v.shape)
else:
keys_dict[k] = convert_state_dict_to_keys_dict(v)
return keys_dict
def load_keys_dict_from_bin(file_path):
state_dict = load_state_dict_from_bin(file_path)
keys_dict = convert_state_dict_to_keys_dict(state_dict)
return keys_dict
def convert_keys_dict_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, dict):
keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape))
else:
if with_shape:
shape = "_".join(map(str, list(value)))
keys.append(key + ":" + shape)
keys.append(key)
keys.sort()
keys_str = ",".join(keys)
return keys_str
def hash_model_file(path, with_shape=True):
keys_dict = load_keys_dict(path)
keys_str = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()

View File

@@ -0,0 +1,79 @@
from ..vram.initialization import skip_model_initialization
from ..vram.disk_map import DiskMap
from ..vram.layers import enable_vram_management
from .file import load_state_dict
import torch
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
config = {} if config is None else config
# Why do we use `skip_model_initialization`?
# It skips the random initialization of model parameters,
# thereby speeding up model loading and avoiding excessive memory usage.
with skip_model_initialization():
model = model_class(**config)
# What is `module_map`?
# This is a module mapping table for VRAM management.
if module_map is not None:
devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
device = [d for d in devices if d != "disk"][0]
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
dtype = [d for d in dtypes if d != "disk"][0]
if vram_config["offload_device"] != "disk":
state_dict = DiskMap(path, device, torch_dtype=dtype)
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
else:
state_dict = {i: state_dict[i] for i in state_dict}
model.load_state_dict(state_dict, assign=True)
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
else:
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
else:
# Why do we use `DiskMap`?
# Sometimes a model file contains multiple models,
# and DiskMap can load only the parameters of a single model,
# avoiding the need to load all parameters in the file.
if use_disk_map:
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
else:
state_dict = load_state_dict(path, torch_dtype, device)
# Why do we use `state_dict_converter`?
# Some models are saved in complex formats,
# and we need to convert the state dict into the appropriate format.
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
else:
state_dict = {i: state_dict[i] for i in state_dict}
model.load_state_dict(state_dict, assign=True)
# Why do we call `to()`?
# Because some models override the behavior of `to()`,
# especially those from libraries like Transformers.
model = model.to(dtype=torch_dtype, device=device)
if hasattr(model, "eval"):
model = model.eval()
return model
def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
if isinstance(path, str):
path = [path]
config = {} if config is None else config
with skip_model_initialization():
model = model_class(**config)
if hasattr(model, "eval"):
model = model.eval()
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": "disk",
"onload_device": "disk",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": device,
"computation_dtype": torch_dtype,
"computation_device": device,
}
enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
return model

View File

@@ -0,0 +1,2 @@
from .initialization import skip_model_initialization
from .layers import *

View File

@@ -0,0 +1,93 @@
from safetensors import safe_open
import torch, os
class SafetensorsCompatibleTensor:
def __init__(self, tensor):
self.tensor = tensor
def get_shape(self):
return list(self.tensor.shape)
class SafetensorsCompatibleBinaryLoader:
def __init__(self, path, device):
print("Detected non-safetensors files, which may cause slower loading. It's recommended to convert it to a safetensors file.")
self.state_dict = torch.load(path, weights_only=True, map_location=device)
def keys(self):
return self.state_dict.keys()
def get_tensor(self, name):
return self.state_dict[name]
def get_slice(self, name):
return SafetensorsCompatibleTensor(self.state_dict[name])
class DiskMap:
def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
self.path = path if isinstance(path, list) else [path]
self.device = device
self.torch_dtype = torch_dtype
if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None:
self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE'))
else:
self.buffer_size = buffer_size
self.files = []
self.flush_files()
self.name_map = {}
for file_id, file in enumerate(self.files):
for name in file.keys():
self.name_map[name] = file_id
self.rename_dict = self.fetch_rename_dict(state_dict_converter)
def flush_files(self):
if len(self.files) == 0:
for path in self.path:
if path.endswith(".safetensors"):
self.files.append(safe_open(path, framework="pt", device=str(self.device)))
else:
self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
else:
for i, path in enumerate(self.path):
if path.endswith(".safetensors"):
self.files[i] = safe_open(path, framework="pt", device=str(self.device))
self.num_params = 0
def __getitem__(self, name):
if self.rename_dict is not None: name = self.rename_dict[name]
file_id = self.name_map[name]
param = self.files[file_id].get_tensor(name)
if self.torch_dtype is not None and isinstance(param, torch.Tensor):
param = param.to(self.torch_dtype)
if isinstance(param, torch.Tensor) and param.device == "cpu":
param = param.clone()
if isinstance(param, torch.Tensor):
self.num_params += param.numel()
if self.num_params > self.buffer_size:
self.flush_files()
return param
def fetch_rename_dict(self, state_dict_converter):
if state_dict_converter is None:
return None
state_dict = {}
for file in self.files:
for name in file.keys():
state_dict[name] = name
state_dict = state_dict_converter(state_dict)
return state_dict
def __iter__(self):
if self.rename_dict is not None:
return self.rename_dict.__iter__()
else:
return self.name_map.__iter__()
def __contains__(self, x):
if self.rename_dict is not None:
return x in self.rename_dict
else:
return x in self.name_map

View File

@@ -0,0 +1,21 @@
import torch
from contextlib import contextmanager
@contextmanager
def skip_model_initialization(device=torch.device("meta")):
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
old_register_parameter = torch.nn.Module.register_parameter
torch.nn.Module.register_parameter = register_empty_parameter
try:
yield
finally:
torch.nn.Module.register_parameter = old_register_parameter

View File

@@ -0,0 +1,479 @@
import torch, copy
from typing import Union
from .initialization import skip_model_initialization
from .disk_map import DiskMap
from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE
class AutoTorchModule(torch.nn.Module):
def __init__(
self,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
):
super().__init__()
self.set_dtype_and_device(
offload_dtype,
offload_device,
onload_dtype,
onload_device,
preparing_dtype,
preparing_device,
computation_dtype,
computation_device,
vram_limit,
)
self.state = 0
self.name = ""
self.computation_device_type = parse_device_type(self.computation_device)
def set_dtype_and_device(
self,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
):
self.offload_dtype = offload_dtype or computation_dtype
self.offload_device = offload_device or computation_dtype
self.onload_dtype = onload_dtype or computation_dtype
self.onload_device = onload_device or computation_dtype
self.preparing_dtype = preparing_dtype or computation_dtype
self.preparing_device = preparing_device or computation_dtype
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.vram_limit = vram_limit
def cast_to(self, weight, dtype, device):
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
def check_free_vram(self):
device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
return used_memory < self.vram_limit
def offload(self):
if self.state != 0:
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state != 1:
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def param_name(self, name):
if self.name == "":
return name
else:
return self.name + "." + name
class AutoWrappedModule(AutoTorchModule):
def __init__(
self,
module: torch.nn.Module,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
name: str = "",
disk_map: DiskMap = None,
**kwargs
):
super().__init__(
offload_dtype,
offload_device,
onload_dtype,
onload_device,
preparing_dtype,
preparing_device,
computation_dtype,
computation_device,
vram_limit,
)
self.module = module
if offload_dtype == "disk":
self.name = name
self.disk_map = disk_map
self.required_params = [name for name, _ in self.module.named_parameters()]
self.disk_offload = True
else:
self.disk_offload = False
def load_from_disk(self, torch_dtype, device, copy_module=False):
if copy_module:
module = copy.deepcopy(self.module)
else:
module = self.module
state_dict = {}
for name in self.required_params:
param = self.disk_map[self.param_name(name)]
param = param.to(dtype=torch_dtype, device=device)
state_dict[name] = param
module.load_state_dict(state_dict, assign=True)
module.to(dtype=torch_dtype, device=device)
return module
def offload_to_disk(self, model: torch.nn.Module):
for buf in model.buffers():
# If there are some parameters are registed in buffers (not in state dict),
# We cannot offload the model.
for children in model.children():
self.offload_to_disk(children)
break
else:
model.to("meta")
def offload(self):
# offload / onload / preparing -> offload
if self.state != 0:
if self.disk_offload:
self.offload_to_disk(self.module)
else:
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
# offload / onload / preparing -> onload
if self.state < 1:
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
self.load_from_disk(self.onload_dtype, self.onload_device)
elif self.onload_device != "disk":
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def preparing(self):
# onload / preparing -> preparing
if self.state != 2:
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
self.load_from_disk(self.preparing_dtype, self.preparing_device)
elif self.preparing_device != "disk":
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
self.state = 2
def cast_to(self, module, dtype, device):
return copy.deepcopy(module).to(dtype=dtype, device=device)
def computation(self):
# onload / preparing -> computation (temporary)
if self.state == 2:
torch_dtype, device = self.preparing_dtype, self.preparing_device
else:
torch_dtype, device = self.onload_dtype, self.onload_device
if torch_dtype == self.computation_dtype and device == self.computation_device:
module = self.module
elif self.disk_offload and device == "disk":
module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
else:
module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)
return module
def forward(self, *args, **kwargs):
if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
self.preparing()
module = self.computation()
return module(*args, **kwargs)
def __getattr__(self, name):
if name in self.__dict__ or name == "module":
return super().__getattr__(name)
else:
return getattr(self.module, name)
class AutoWrappedNonRecurseModule(AutoWrappedModule):
def __init__(
self,
module: torch.nn.Module,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
name: str = "",
disk_map: DiskMap = None,
**kwargs
):
super().__init__(
module,
offload_dtype,
offload_device,
onload_dtype,
onload_device,
preparing_dtype,
preparing_device,
computation_dtype,
computation_device,
vram_limit,
name,
disk_map,
**kwargs
)
if self.disk_offload:
self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)]
def load_from_disk(self, torch_dtype, device, copy_module=False):
if copy_module:
module = copy.deepcopy(self.module)
else:
module = self.module
state_dict = {}
for name in self.required_params:
param = self.disk_map[self.param_name(name)]
param = param.to(dtype=torch_dtype, device=device)
state_dict[name] = param
module.load_state_dict(state_dict, assign=True, strict=False)
return module
def offload_to_disk(self, model: torch.nn.Module):
for name in self.required_params:
getattr(self, name).to("meta")
def cast_to(self, module, dtype, device):
# Parameter casting is implemented in the model architecture.
return module
def __getattr__(self, name):
if name in self.__dict__ or name == "module":
return super().__getattr__(name)
else:
return getattr(self.module, name)
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
def __init__(
self,
module: torch.nn.Linear,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
name: str = "",
disk_map: DiskMap = None,
**kwargs
):
with skip_model_initialization():
super().__init__(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
)
self.set_dtype_and_device(
offload_dtype,
offload_device,
onload_dtype,
onload_device,
preparing_dtype,
preparing_device,
computation_dtype,
computation_device,
vram_limit,
)
self.weight = module.weight
self.bias = module.bias
self.state = 0
self.name = name
self.lora_A_weights = []
self.lora_B_weights = []
self.lora_merger = None
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
self.computation_device_type = parse_device_type(self.computation_device)
if offload_dtype == "disk":
self.disk_map = disk_map
self.disk_offload = True
else:
self.disk_offload = False
def fp8_linear(
self,
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor = None,
) -> torch.Tensor:
device = input.device
origin_dtype = input.dtype
origin_shape = input.shape
input = input.reshape(-1, origin_shape[-1])
x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
fp8_max = 448.0
# For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
# To avoid overflow and ensure numerical compatibility during FP8 computation,
# we scale down the input by 2.0 in advance.
# This scaling will be compensated later during the final result scaling.
if self.computation_dtype == torch.float8_e4m3fnuz:
fp8_max = fp8_max / 2.0
scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
input = input / (scale_a + 1e-8)
input = input.to(self.computation_dtype)
weight = weight.to(self.computation_dtype)
bias = bias.to(torch.bfloat16)
result = torch._scaled_mm(
input,
weight.T,
scale_a=scale_a,
scale_b=scale_b.T,
bias=bias,
out_dtype=origin_dtype,
)
new_shape = origin_shape[:-1] + result.shape[-1:]
result = result.reshape(new_shape)
return result
def load_from_disk(self, torch_dtype, device, assign=True):
weight = self.disk_map[self.name + ".weight"].to(dtype=torch_dtype, device=device)
bias = None if self.bias is None else self.disk_map[self.name + ".bias"].to(dtype=torch_dtype, device=device)
if assign:
state_dict = {"weight": weight}
if bias is not None: state_dict["bias"] = bias
self.load_state_dict(state_dict, assign=True)
return weight, bias
def offload(self):
# offload / onload / preparing -> offload
if self.state != 0:
if self.disk_offload:
self.to("meta")
else:
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
# offload / onload / preparing -> onload
if self.state < 1:
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
self.load_from_disk(self.onload_dtype, self.onload_device)
elif self.onload_device != "disk":
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def preparing(self):
# onload / preparing -> preparing
if self.state != 2:
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
self.load_from_disk(self.preparing_dtype, self.preparing_device)
elif self.preparing_device != "disk":
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
self.state = 2
def computation(self):
# onload / preparing -> computation (temporary)
if self.state == 2:
torch_dtype, device = self.preparing_dtype, self.preparing_device
else:
torch_dtype, device = self.onload_dtype, self.onload_device
if torch_dtype == self.computation_dtype and device == self.computation_device:
weight, bias = self.weight, self.bias
elif self.disk_offload and device == "disk":
weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False)
else:
weight = self.cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device)
return weight, bias
def linear_forward(self, x, weight, bias):
if self.enable_fp8:
out = self.fp8_linear(x, weight, bias)
else:
out = torch.nn.functional.linear(x, weight, bias)
return out
def lora_forward(self, x, out):
if self.lora_merger is None:
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
out = out + x @ lora_A.T @ lora_B.T
else:
lora_output = []
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
lora_output.append(x @ lora_A.T @ lora_B.T)
lora_output = torch.stack(lora_output)
out = self.lora_merger(out, lora_output)
return out
def forward(self, x, *args, **kwargs):
if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
self.preparing()
weight, bias = self.computation()
out = self.linear_forward(x, weight, bias)
if len(self.lora_A_weights) > 0:
out = self.lora_forward(x, out)
return out
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs):
if isinstance(model, AutoWrappedNonRecurseModule):
model = model.module
for name, module in model.named_children():
layer_name = name if name_prefix == "" else name_prefix + "." + name
for source_module, target_module in module_map.items():
if isinstance(module, source_module):
module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs)
if isinstance(module_, AutoWrappedNonRecurseModule):
enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
setattr(model, name, module_)
break
else:
enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
def fill_vram_config(model, vram_config):
vram_config_ = vram_config.copy()
vram_config_["onload_dtype"] = vram_config["computation_dtype"]
vram_config_["onload_device"] = vram_config["computation_device"]
vram_config_["preparing_dtype"] = vram_config["computation_dtype"]
vram_config_["preparing_device"] = vram_config["computation_device"]
for k in vram_config:
if vram_config[k] != vram_config_[k]:
print(f"No fine-grained VRAM configuration is provided for {model.__class__.__name__}. [`onload`, `preparing`, `computation`] will be the same state. `vram_config` is set to {vram_config_}")
break
return vram_config_
def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs):
for source_module, target_module in module_map.items():
# If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly.
if isinstance(model, source_module):
vram_config = fill_vram_config(model, vram_config)
model = target_module(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
break
else:
enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
# `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled.
model.vram_management_enabled = True
return model

View File

@@ -1 +0,0 @@
from .video import VideoData, save_video, save_frames

View File

@@ -0,0 +1,6 @@
from .flow_match import FlowMatchScheduler
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger
from .runner import launch_training_task, launch_data_process_task
from .parsers import *
from .loss import *

View File

@@ -0,0 +1,451 @@
from PIL import Image
import torch
import numpy as np
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput
from ..core.device import get_device_name, IS_NPU_AVAILABLE
class PipelineUnit:
def __init__(
self,
seperate_cfg: bool = False,
take_over: bool = False,
input_params: tuple[str] = None,
output_params: tuple[str] = None,
input_params_posi: dict[str, str] = None,
input_params_nega: dict[str, str] = None,
onload_model_names: tuple[str] = None
):
self.seperate_cfg = seperate_cfg
self.take_over = take_over
self.input_params = input_params
self.output_params = output_params
self.input_params_posi = input_params_posi
self.input_params_nega = input_params_nega
self.onload_model_names = onload_model_names
def fetch_input_params(self):
params = []
if self.input_params is not None:
for param in self.input_params:
params.append(param)
if self.input_params_posi is not None:
for _, param in self.input_params_posi.items():
params.append(param)
if self.input_params_nega is not None:
for _, param in self.input_params_nega.items():
params.append(param)
params = sorted(list(set(params)))
return params
def fetch_output_params(self):
params = []
if self.output_params is not None:
for param in self.output_params:
params.append(param)
return params
def process(self, pipe, **kwargs) -> dict:
return {}
def post_process(self, pipe, **kwargs) -> dict:
return {}
class BasePipeline(torch.nn.Module):
def __init__(
self,
device=get_device_type(), torch_dtype=torch.float16,
height_division_factor=64, width_division_factor=64,
time_division_factor=None, time_division_remainder=None,
):
super().__init__()
# The device and torch_dtype is used for the storage of intermediate variables, not models.
self.device = device
self.torch_dtype = torch_dtype
self.device_type = parse_device_type(device)
# The following parameters are used for shape check.
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
# VRAM management
self.vram_management_enabled = False
# Pipeline Unit Runner
self.unit_runner = PipelineUnitRunner()
# LoRA Loader
self.lora_loader = GeneralLoRALoader
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None:
self.device = device
if dtype is not None:
self.torch_dtype = dtype
super().to(*args, **kwargs)
return self
def check_resize_height_width(self, height, width, num_frames=None):
# Shape check
if height % self.height_division_factor != 0:
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
if width % self.width_division_factor != 0:
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
if num_frames is None:
return height, width
else:
if num_frames % self.time_division_factor != self.time_division_remainder:
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
return height, width, num_frames
def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
# Transform a PIL.Image to torch.Tensor
image = torch.Tensor(np.array(image, dtype=np.float32))
image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
image = image * ((max_value - min_value) / 255) + min_value
image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
return image
def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
# Transform a list of PIL.Image to torch.Tensor
video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
video = torch.stack(video, dim=pattern.index("T") // 2)
return video
def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
# Transform a torch.Tensor to PIL.Image
if pattern != "H W C":
vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
image = image.to(device="cpu", dtype=torch.uint8)
image = Image.fromarray(image.numpy())
return image
def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
# Transform a torch.Tensor to list of PIL.Image
if pattern != "T H W C":
vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
return video
def load_models_to_device(self, model_names):
if self.vram_management_enabled:
# offload models
for name, model in self.named_children():
if name not in model_names:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
if hasattr(model, "offload"):
model.offload()
else:
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
getattr(torch, self.device_type).empty_cache()
# onload models
for name, model in self.named_children():
if name in model_names:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
if hasattr(model, "onload"):
model.onload()
else:
for module in model.modules():
if hasattr(module, "onload"):
module.onload()
def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
# Initialize Gaussian noise
generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
return noise
def get_vram(self):
device = self.device if not IS_NPU_AVAILABLE else get_device_name()
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
def get_module(self, model, name):
if "." in name:
name, suffix = name[:name.index(".")], name[name.index(".") + 1:]
if name.isdigit():
return self.get_module(model[int(name)], suffix)
else:
return self.get_module(getattr(model, name), suffix)
else:
return getattr(model, name)
def freeze_except(self, model_names):
self.eval()
self.requires_grad_(False)
for name in model_names:
module = self.get_module(self, name)
if module is None:
print(f"No {name} models in the pipeline. We cannot enable training on the model. If this occurs during the data processing stage, it is normal.")
continue
module.train()
module.requires_grad_(True)
def blend_with_mask(self, base, addition, mask):
return base * (1 - mask) + addition * mask
def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
timestep = scheduler.timesteps[progress_id]
if inpaint_mask is not None:
noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
latents_next = scheduler.step(noise_pred, timestep, latents)
return latents_next
def split_pipeline_units(self, model_names: list[str]):
return PipelineUnitGraph().split_pipeline_units(self.units, model_names)
def flush_vram_management_device(self, device):
for module in self.modules():
if isinstance(module, AutoTorchModule):
module.offload_device = device
module.onload_device = device
module.preparing_device = device
module.computation_device = device
def load_lora(
self,
module: torch.nn.Module,
lora_config: Union[ModelConfig, str] = None,
alpha=1,
hotload=None,
state_dict=None,
verbose=1,
):
if state_dict is None:
if isinstance(lora_config, str):
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
else:
lora_config.download_if_necessary()
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
else:
lora = state_dict
lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
lora = lora_loader.convert_state_dict(lora)
if hotload is None:
hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
if hotload:
if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.")
updated_num = 0
for _, module in module.named_modules():
if isinstance(module, AutoWrappedLinear):
name = module.name
lora_a_name = f'{name}.lora_A.weight'
lora_b_name = f'{name}.lora_B.weight'
if lora_a_name in lora and lora_b_name in lora:
updated_num += 1
module.lora_A_weights.append(lora[lora_a_name] * alpha)
module.lora_B_weights.append(lora[lora_b_name])
if verbose >= 1:
print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
else:
lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)
def clear_lora(self, verbose=1):
cleared_num = 0
for name, module in self.named_modules():
if isinstance(module, AutoWrappedLinear):
if hasattr(module, "lora_A_weights"):
if len(module.lora_A_weights) > 0:
cleared_num += 1
module.lora_A_weights.clear()
if hasattr(module, "lora_B_weights"):
module.lora_B_weights.clear()
if verbose >= 1:
print(f"{cleared_num} LoRA layers are cleared.")
def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
model_pool = ModelPool()
for model_config in model_configs:
model_config.download_if_necessary()
vram_config = model_config.vram_config()
vram_config["computation_dtype"] = vram_config["computation_dtype"] or self.torch_dtype
vram_config["computation_device"] = vram_config["computation_device"] or self.device
model_pool.auto_load_model(
model_config.path,
vram_config=vram_config,
vram_limit=vram_limit,
clear_parameters=model_config.clear_parameters,
)
return model_pool
def check_vram_management_state(self):
vram_management_enabled = False
for module in self.children():
if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"):
vram_management_enabled = True
return vram_management_enabled
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
if cfg_scale != 1.0:
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
return noise_pred
class PipelineUnitGraph:
def __init__(self):
pass
def build_edges(self, units: list[PipelineUnit]):
# Establish dependencies between units
# to search for subsequent related computation units.
last_compute_unit_id = {}
edges = []
for unit_id, unit in enumerate(units):
for input_param in unit.fetch_input_params():
if input_param in last_compute_unit_id:
edges.append((last_compute_unit_id[input_param], unit_id))
for output_param in unit.fetch_output_params():
last_compute_unit_id[output_param] = unit_id
return edges
def build_chains(self, units: list[PipelineUnit]):
# Establish updating chains for each variable
# to track their computation process.
params = sum([unit.fetch_input_params() + unit.fetch_output_params() for unit in units], [])
params = sorted(list(set(params)))
chains = {param: [] for param in params}
for unit_id, unit in enumerate(units):
for param in unit.fetch_output_params():
chains[param].append(unit_id)
return chains
def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]):
# Search for units that directly participate in the model's computation.
related_unit_ids = []
for unit_id, unit in enumerate(units):
for model_name in model_names:
if unit.onload_model_names is not None and model_name in unit.onload_model_names:
related_unit_ids.append(unit_id)
break
return related_unit_ids
def search_related_unit_ids(self, edges, start_unit_ids, direction="target"):
# Search for subsequent related computation units.
related_unit_ids = [unit_id for unit_id in start_unit_ids]
while True:
neighbors = []
for source, target in edges:
if direction == "target" and source in related_unit_ids and target not in related_unit_ids:
neighbors.append(target)
elif direction == "source" and source not in related_unit_ids and target in related_unit_ids:
neighbors.append(source)
neighbors = sorted(list(set(neighbors)))
if len(neighbors) == 0:
break
else:
related_unit_ids.extend(neighbors)
related_unit_ids = sorted(list(set(related_unit_ids)))
return related_unit_ids
def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids):
# If the input parameters of this subgraph are updated outside the subgraph,
# search for the units where these updates occur.
first_compute_unit_id = {}
for unit_id in related_unit_ids:
for param in units[unit_id].fetch_input_params():
if param not in first_compute_unit_id:
first_compute_unit_id[param] = unit_id
updating_unit_ids = []
for param in first_compute_unit_id:
unit_id = first_compute_unit_id[param]
chain = chains[param]
if unit_id in chain and chain.index(unit_id) != len(chain) - 1:
for unit_id_ in chain[chain.index(unit_id) + 1:]:
if unit_id_ not in related_unit_ids:
updating_unit_ids.append(unit_id_)
related_unit_ids.extend(updating_unit_ids)
related_unit_ids = sorted(list(set(related_unit_ids)))
return related_unit_ids
def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]):
# Split the computation graph,
# separating all model-related computations.
related_unit_ids = self.search_direct_unit_ids(units, model_names)
edges = self.build_edges(units)
chains = self.build_chains(units)
while True:
num_related_unit_ids = len(related_unit_ids)
related_unit_ids = self.search_related_unit_ids(edges, related_unit_ids, "target")
related_unit_ids = self.search_updating_unit_ids(units, chains, related_unit_ids)
if len(related_unit_ids) == num_related_unit_ids:
break
else:
num_related_unit_ids = len(related_unit_ids)
related_units = [units[i] for i in related_unit_ids]
unrelated_units = [units[i] for i in range(len(units)) if i not in related_unit_ids]
return related_units, unrelated_units
class PipelineUnitRunner:
def __init__(self):
pass
def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
if unit.take_over:
# Let the pipeline unit take over this function.
inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
elif unit.seperate_cfg:
# Positive side
processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_posi.update(processor_outputs)
# Negative side
if inputs_shared["cfg_scale"] != 1:
processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_nega.update(processor_outputs)
else:
inputs_nega.update(processor_outputs)
else:
processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_shared.update(processor_outputs)
return inputs_shared, inputs_posi, inputs_nega

View File

@@ -0,0 +1,184 @@
import torch, math
from typing_extensions import Literal
class FlowMatchScheduler():
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"):
self.set_timesteps_fn = {
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
"Wan": FlowMatchScheduler.set_timesteps_wan,
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
}.get(template, FlowMatchScheduler.set_timesteps_flux)
self.num_train_timesteps = 1000
@staticmethod
def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
sigma_min = 0.003/1.002
sigma_max = 1.0
shift = 3 if shift is None else shift
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None):
sigma_min = 0.0
sigma_max = 1.0
shift = 5 if shift is None else shift
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
@staticmethod
def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
sigma_min = 0.0
sigma_max = 1.0
num_train_timesteps = 1000
shift_terminal = 0.02
# Sigmas
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
# Mu
if exponential_shift_mu is not None:
mu = exponential_shift_mu
elif dynamic_shift_len is not None:
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len)
else:
mu = 0.8
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
# Shift terminal
one_minus_z = 1 - sigmas
scale_factor = one_minus_z[-1] / (1 - shift_terminal)
sigmas = 1 - (one_minus_z / scale_factor)
# Timesteps
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def compute_empirical_mu(image_seq_len, num_steps):
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
if image_seq_len > 4300:
mu = a2 * image_seq_len + b2
return float(mu)
m_200 = a2 * image_seq_len + b2
m_10 = a1 * image_seq_len + b1
a = (m_200 - m_10) / 190.0
b = m_200 - 200.0 * a
mu = a * num_steps + b
return float(mu)
@staticmethod
def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):
sigma_min = 1 / num_inference_steps
sigma_max = 1.0
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
if dynamic_shift_len is None:
# If you ask me why I set mu=0.8,
# I can only say that it yields better training results.
mu = 0.8
else:
mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
sigma_min = 0.0
sigma_max = 1.0
shift = 3 if shift is None else shift
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
if target_timesteps is not None:
target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device)
for timestep in target_timesteps:
timestep_id = torch.argmin((timesteps - timestep).abs())
timesteps[timestep_id] = timestep
return sigmas, timesteps
def set_training_weight(self):
steps = 1000
x = self.timesteps
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
if len(self.timesteps) != 1000:
# This is an empirical formula.
bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
self.linear_timesteps_weights = bsmntw_weighing
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
self.sigmas, self.timesteps = self.set_timesteps_fn(
num_inference_steps=num_inference_steps,
denoising_strength=denoising_strength,
**kwargs,
)
if training:
self.set_training_weight()
self.training = True
else:
self.training = False
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
if to_final or timestep_id + 1 >= len(self.timesteps):
sigma_ = 0
else:
sigma_ = self.sigmas[timestep_id + 1]
prev_sample = sample + model_output * (sigma_ - sigma)
return prev_sample
def return_to_timestep(self, timestep, sample, sample_stablized):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
model_output = (sample - sample_stablized) / sigma
return model_output
def add_noise(self, original_samples, noise, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
sample = (1 - sigma) * original_samples + sigma * noise
return sample
def training_target(self, sample, noise, timestep):
target = noise - sample
return target
def training_weight(self, timestep):
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
weights = self.linear_timesteps_weights[timestep_id]
return weights

View File

@@ -0,0 +1,43 @@
import os, torch
from accelerate import Accelerator
class ModelLogger:
def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
self.output_path = output_path
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
self.state_dict_converter = state_dict_converter
self.num_steps = 0
def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):
self.num_steps += 1
if save_steps is not None and self.num_steps % save_steps == 0:
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
accelerator.save(state_dict, path, safe_serialization=True)
def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
if save_steps is not None and self.num_steps % save_steps != 0:
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
path = os.path.join(self.output_path, file_name)
accelerator.save(state_dict, path, safe_serialization=True)

119
diffsynth/diffusion/loss.py Normal file
View File

@@ -0,0 +1,119 @@
from .base_pipeline import BasePipeline
import torch
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
noise = torch.randn_like(inputs["input_latents"])
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * pipe.scheduler.training_weight(timestep)
return loss
def DirectDistillLoss(pipe: BasePipeline, **inputs):
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
pipe.scheduler.training = True
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
return loss
class TrajectoryImitationLoss(torch.nn.Module):
def __init__(self):
super().__init__()
self.initialized = False
def initialize(self, device):
import lpips # TODO: remove it
self.loss_fn = lpips.LPIPS(net='alex').to(device)
self.initialized = True
def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
trajectory = [inputs_shared["latents"].clone()]
pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
noise_pred = pipe.cfg_guided_model_fn(
pipe.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
trajectory.append(inputs_shared["latents"].clone())
return pipe.scheduler.timesteps, trajectory
def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
loss = 0
pipe.scheduler.set_timesteps(num_inference_steps, training=True)
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]
noise_pred = pipe.cfg_guided_model_fn(
pipe.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
sigma = pipe.scheduler.sigmas[progress_id]
sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]
if progress_id + 1 >= len(pipe.scheduler.timesteps):
latents_ = trajectory_teacher[-1]
else:
progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
latents_ = trajectory_teacher[progress_id_teacher]
target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
return loss
def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
inputs_shared["latents"] = trajectory_teacher[0]
pipe.scheduler.set_timesteps(num_inference_steps)
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
noise_pred = pipe.cfg_guided_model_fn(
pipe.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
image_pred = pipe.vae_decoder(inputs_shared["latents"])
image_real = pipe.vae_decoder(trajectory_teacher[-1])
loss = self.loss_fn(image_pred.float(), image_real.float())
return loss
def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
if not self.initialized:
self.initialize(pipe.device)
with torch.no_grad():
pipe.scheduler.set_timesteps(8)
timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)
timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
loss = loss_1 + loss_2
return loss

View File

@@ -0,0 +1,70 @@
import argparse
def add_dataset_base_config(parser: argparse.ArgumentParser):
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
return parser
def add_image_size_config(parser: argparse.ArgumentParser):
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
return parser
def add_video_size_config(parser: argparse.ArgumentParser):
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
return parser
def add_model_config(parser: argparse.ArgumentParser):
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
parser.add_argument("--fp8_models", default=None, help="Models with FP8 precision, comma-separated.")
parser.add_argument("--offload_models", default=None, help="Models with offload, comma-separated. Only used in splited training.")
return parser
def add_training_config(parser: argparse.ArgumentParser):
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
return parser
def add_output_config(parser: argparse.ArgumentParser):
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
return parser
def add_lora_config(parser: argparse.ArgumentParser):
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
parser.add_argument("--preset_lora_path", type=str, default=None, help="Path to the preset LoRA checkpoint. If provided, this LoRA will be fused to the base model.")
parser.add_argument("--preset_lora_model", type=str, default=None, help="Which model the preset LoRA is fused to.")
return parser
def add_gradient_config(parser: argparse.ArgumentParser):
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
return parser
def add_general_config(parser: argparse.ArgumentParser):
parser = add_dataset_base_config(parser)
parser = add_model_config(parser)
parser = add_training_config(parser)
parser = add_output_config(parser)
parser = add_lora_config(parser)
parser = add_gradient_config(parser)
return parser

View File

@@ -0,0 +1,71 @@
import os, torch
from tqdm import tqdm
from accelerate import Accelerator
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger
def launch_training_task(
accelerator: Accelerator,
dataset: torch.utils.data.Dataset,
model: DiffusionTrainingModule,
model_logger: ModelLogger,
learning_rate: float = 1e-5,
weight_decay: float = 1e-2,
num_workers: int = 1,
save_steps: int = None,
num_epochs: int = 1,
args = None,
):
if args is not None:
learning_rate = args.learning_rate
weight_decay = args.weight_decay
num_workers = args.dataset_num_workers
save_steps = args.save_steps
num_epochs = args.num_epochs
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
for epoch_id in range(num_epochs):
for data in tqdm(dataloader):
with accelerator.accumulate(model):
optimizer.zero_grad()
if dataset.load_from_cache:
loss = model({}, inputs=data)
else:
loss = model(data)
accelerator.backward(loss)
optimizer.step()
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
scheduler.step()
if save_steps is None:
model_logger.on_epoch_end(accelerator, model, epoch_id)
model_logger.on_training_end(accelerator, model, save_steps)
def launch_data_process_task(
accelerator: Accelerator,
dataset: torch.utils.data.Dataset,
model: DiffusionTrainingModule,
model_logger: ModelLogger,
num_workers: int = 8,
args = None,
):
if args is not None:
num_workers = args.dataset_num_workers
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
model, dataloader = accelerator.prepare(model, dataloader)
for data_id, data in enumerate(tqdm(dataloader)):
with accelerator.accumulate(model):
with torch.no_grad():
folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
os.makedirs(folder, exist_ok=True)
save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
data = model(data)
torch.save(data, save_path)

View File

@@ -0,0 +1,263 @@
import torch, json, os
from ..core import ModelConfig, load_state_dict
from ..utils.controlnet import ControlNetInput
from peft import LoraConfig, inject_adapter_in_model
class DiffusionTrainingModule(torch.nn.Module):
def __init__(self):
super().__init__()
def to(self, *args, **kwargs):
for name, model in self.named_children():
model.to(*args, **kwargs)
return self
def trainable_modules(self):
trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
return trainable_modules
def trainable_param_names(self):
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
return trainable_param_names
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
if lora_alpha is None:
lora_alpha = lora_rank
if isinstance(target_modules, list) and len(target_modules) == 1:
target_modules = target_modules[0]
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
model = inject_adapter_in_model(lora_config, model)
if upcast_dtype is not None:
for param in model.parameters():
if param.requires_grad:
param.data = param.to(upcast_dtype)
return model
def mapping_lora_state_dict(self, state_dict):
new_state_dict = {}
for key, value in state_dict.items():
if "lora_A.weight" in key or "lora_B.weight" in key:
new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
new_state_dict[new_key] = value
elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
new_state_dict[key] = value
return new_state_dict
def export_trainable_state_dict(self, state_dict, remove_prefix=None):
trainable_param_names = self.trainable_param_names()
state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
if remove_prefix is not None:
state_dict_ = {}
for name, param in state_dict.items():
if name.startswith(remove_prefix):
name = name[len(remove_prefix):]
state_dict_[name] = param
state_dict = state_dict_
return state_dict
def transfer_data_to_device(self, data, device, torch_float_dtype=None):
if data is None:
return data
elif isinstance(data, torch.Tensor):
data = data.to(device)
if torch_float_dtype is not None and data.dtype in [torch.float, torch.float16, torch.bfloat16]:
data = data.to(torch_float_dtype)
return data
elif isinstance(data, tuple):
data = tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
return data
elif isinstance(data, list):
data = list(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
return data
elif isinstance(data, dict):
data = {i: self.transfer_data_to_device(data[i], device, torch_float_dtype) for i in data}
return data
else:
return data
def parse_vram_config(self, fp8=False, offload=False, device="cpu"):
if fp8:
return {
"offload_dtype": torch.float8_e4m3fn,
"offload_device": device,
"onload_dtype": torch.float8_e4m3fn,
"onload_device": device,
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": device,
"computation_dtype": torch.bfloat16,
"computation_device": device,
}
elif offload:
return {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": "disk",
"onload_device": "disk",
"preparing_dtype": torch.bfloat16,
"preparing_device": device,
"computation_dtype": torch.bfloat16,
"computation_device": device,
"clear_parameters": True,
}
else:
return {}
def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"):
fp8_models = [] if fp8_models is None else fp8_models.split(",")
offload_models = [] if offload_models is None else offload_models.split(",")
model_configs = []
if model_paths is not None:
model_paths = json.loads(model_paths)
for path in model_paths:
vram_config = self.parse_vram_config(
fp8=path in fp8_models,
offload=path in offload_models,
device=device
)
model_configs.append(ModelConfig(path=path, **vram_config))
if model_id_with_origin_paths is not None:
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
for model_id_with_origin_path in model_id_with_origin_paths:
vram_config = self.parse_vram_config(
fp8=model_id_with_origin_path in fp8_models,
offload=model_id_with_origin_path in offload_models,
device=device
)
config = self.parse_path_or_model_id(model_id_with_origin_path)
model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))
return model_configs
def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):
if model_id_with_origin_path is None:
return default_value
elif os.path.exists(model_id_with_origin_path):
return ModelConfig(path=model_id_with_origin_path)
else:
if ":" not in model_id_with_origin_path:
raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id/origin_file_pattern`.")
split_id = model_id_with_origin_path.rfind(":")
model_id = model_id_with_origin_path[:split_id]
origin_file_pattern = model_id_with_origin_path[split_id + 1:]
return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
def auto_detect_lora_target_modules(
self,
model: torch.nn.Module,
search_for_linear=False,
linear_detector=lambda x: min(x.weight.shape) >= 512,
block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
name_prefix="",
):
lora_target_modules = []
if search_for_linear:
for name, module in model.named_modules():
module_name = name_prefix + ["", "."][name_prefix != ""] + name
if isinstance(module, torch.nn.Linear) and linear_detector(module):
lora_target_modules.append(module_name)
else:
for name, module in model.named_children():
module_name = name_prefix + ["", "."][name_prefix != ""] + name
lora_target_modules += self.auto_detect_lora_target_modules(
module,
search_for_linear=block_list_detector(module),
linear_detector=linear_detector,
block_list_detector=block_list_detector,
name_prefix=module_name,
)
return lora_target_modules
def parse_lora_target_modules(self, model, lora_target_modules):
if lora_target_modules == "":
print("No LoRA target modules specified. The framework will automatically search for them.")
lora_target_modules = self.auto_detect_lora_target_modules(model)
print(f"LoRA will be patched at {lora_target_modules}.")
else:
lora_target_modules = lora_target_modules.split(",")
return lora_target_modules
def switch_pipe_to_training_mode(
self,
pipe,
trainable_models=None,
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
preset_lora_path=None, preset_lora_model=None,
task="sft",
):
# Scheduler
pipe.scheduler.set_timesteps(1000, training=True)
# Freeze untrainable models
pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
# Preset LoRA
if preset_lora_path is not None:
pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path)
# FP8
# FP8 relies on a model-specific memory management scheme.
# It is delegated to the subclass.
# Add LoRA to the base models
if lora_base_model is not None and not task.endswith(":data_process"):
if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None:
print(f"No {lora_base_model} models in the pipeline. We cannot patch LoRA on the model. If this occurs during the data processing stage, it is normal.")
return
model = self.add_lora_to_model(
getattr(pipe, lora_base_model),
target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
lora_rank=lora_rank,
upcast_dtype=pipe.torch_dtype,
)
if lora_checkpoint is not None:
state_dict = load_state_dict(lora_checkpoint)
state_dict = self.mapping_lora_state_dict(state_dict)
load_result = model.load_state_dict(state_dict, strict=False)
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
if len(load_result[1]) > 0:
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
setattr(pipe, lora_base_model, model)
def split_pipeline_units(self, task, pipe, trainable_models=None, lora_base_model=None):
models_require_backward = []
if trainable_models is not None:
models_require_backward += trainable_models.split(",")
if lora_base_model is not None:
models_require_backward += [lora_base_model]
if task.endswith(":data_process"):
_, pipe.units = pipe.split_pipeline_units(models_require_backward)
elif task.endswith(":train"):
pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
return pipe
def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
controlnet_keys_map = (
("blockwise_controlnet_", "blockwise_controlnet_inputs",),
("controlnet_", "controlnet_inputs"),
)
controlnet_inputs = {}
for extra_input in extra_inputs:
for prefix, name in controlnet_keys_map:
if extra_input.startswith(prefix):
if name not in controlnet_inputs:
controlnet_inputs[name] = {}
controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input]
break
else:
inputs_shared[extra_input] = data[extra_input]
for name, params in controlnet_inputs.items():
inputs_shared[name] = [ControlNetInput(**params)]
return inputs_shared

View File

@@ -1,118 +0,0 @@
import torch
from einops import repeat
from PIL import Image
import numpy as np
class ResidualDenseBlock(torch.nn.Module):
def __init__(self, num_feat=64, num_grow_ch=32):
super(ResidualDenseBlock, self).__init__()
self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
class RRDB(torch.nn.Module):
def __init__(self, num_feat, num_grow_ch=32):
super(RRDB, self).__init__()
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
def forward(self, x):
out = self.rdb1(x)
out = self.rdb2(out)
out = self.rdb3(out)
return out * 0.2 + x
class RRDBNet(torch.nn.Module):
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
super(RRDBNet, self).__init__()
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
# upsample
self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
feat = x
feat = self.conv_first(feat)
body_feat = self.conv_body(self.body(feat))
feat = feat + body_feat
# upsample
feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
feat = self.lrelu(self.conv_up1(feat))
feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
feat = self.lrelu(self.conv_up2(feat))
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
return out
class ESRGAN(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
@staticmethod
def from_pretrained(model_path):
model = RRDBNet()
state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
model.load_state_dict(state_dict)
model.eval()
return ESRGAN(model)
def process_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
return image
def process_images(self, images):
images = [self.process_image(image) for image in images]
images = torch.stack(images)
return images
def decode_images(self, images):
images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
images = [Image.fromarray(image) for image in images]
return images
@torch.no_grad()
def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
# Preprocess
input_tensor = self.process_images(images)
# Interpolate
output_tensor = []
for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
batch_input_tensor = input_tensor[batch_id: batch_id_]
batch_input_tensor = batch_input_tensor.to(
device=self.model.conv_first.weight.device,
dtype=self.model.conv_first.weight.dtype)
batch_output_tensor = self.model(batch_input_tensor)
output_tensor.append(batch_output_tensor.cpu())
# Output
output_tensor = torch.concat(output_tensor, dim=0)
# To images
output_images = self.decode_images(output_tensor)
return output_images

View File

@@ -1,63 +0,0 @@
from .runners.fast import TableManager, PyramidPatchMatcher
from PIL import Image
import numpy as np
import cupy as cp
class FastBlendSmoother:
def __init__(self):
self.batch_size = 8
self.window_size = 64
self.ebsynth_config = {
"minimum_patch_size": 5,
"threads_per_block": 8,
"num_iter": 5,
"gpu_id": 0,
"guide_weight": 10.0,
"initialize": "identity",
"tracking_window_size": 0,
}
@staticmethod
def from_model_manager(model_manager):
# TODO: fetch GPU ID from model_manager
return FastBlendSmoother()
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
frames_guide = [np.array(frame) for frame in frames_guide]
frames_style = [np.array(frame) for frame in frames_style]
table_manager = TableManager()
patch_match_engine = PyramidPatchMatcher(
image_height=frames_style[0].shape[0],
image_width=frames_style[0].shape[1],
channel=3,
**ebsynth_config
)
# left part
table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
table_l = table_manager.remapping_table_to_blending_table(table_l)
table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
# right part
table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
table_r = table_manager.remapping_table_to_blending_table(table_r)
table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
# merge
frames = []
for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
weight_m = -1
weight = weight_l + weight_m + weight_r
frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
frames.append(frame)
frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
return frames
def __call__(self, rendered_frames, original_frames=None, **kwargs):
frames = self.run(
original_frames, rendered_frames,
self.batch_size, self.window_size, self.ebsynth_config
)
mempool = cp.get_default_memory_pool()
pinned_mempool = cp.get_default_pinned_memory_pool()
mempool.free_all_blocks()
pinned_mempool.free_all_blocks()
return frames

View File

@@ -1,397 +0,0 @@
from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
from .data import VideoData, get_video_fps, save_video, search_for_images
import os
import gradio as gr
def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
frames_guide = VideoData(video_guide, video_guide_folder)
frames_style = VideoData(video_style, video_style_folder)
message = ""
if len(frames_guide) < len(frames_style):
message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
frames_style.set_length(len(frames_guide))
elif len(frames_guide) > len(frames_style):
message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
frames_guide.set_length(len(frames_style))
height_guide, width_guide = frames_guide.shape()
height_style, width_style = frames_style.shape()
if height_guide != height_style or width_guide != width_style:
message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
frames_style.set_shape(height_guide, width_guide)
return frames_guide, frames_style, message
def smooth_video(
video_guide,
video_guide_folder,
video_style,
video_style_folder,
mode,
window_size,
batch_size,
tracking_window_size,
output_path,
fps,
minimum_patch_size,
num_iter,
guide_weight,
initialize,
progress = None,
):
# input
frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
if len(message) > 0:
print(message)
# output
if output_path == "":
if video_style is None:
output_path = os.path.join(video_style_folder, "output")
else:
output_path = os.path.join(os.path.split(video_style)[0], "output")
os.makedirs(output_path, exist_ok=True)
print("No valid output_path. Your video will be saved here:", output_path)
elif not os.path.exists(output_path):
os.makedirs(output_path, exist_ok=True)
print("Your video will be saved here:", output_path)
frames_path = os.path.join(output_path, "frames")
video_path = os.path.join(output_path, "video.mp4")
os.makedirs(frames_path, exist_ok=True)
# process
if mode == "Fast" or mode == "Balanced":
tracking_window_size = 0
ebsynth_config = {
"minimum_patch_size": minimum_patch_size,
"threads_per_block": 8,
"num_iter": num_iter,
"gpu_id": 0,
"guide_weight": guide_weight,
"initialize": initialize,
"tracking_window_size": tracking_window_size,
}
if mode == "Fast":
FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
elif mode == "Balanced":
BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
elif mode == "Accurate":
AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
# output
try:
fps = int(fps)
except:
fps = get_video_fps(video_style) if video_style is not None else 30
print("Fps:", fps)
print("Saving video...")
video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
print("Success!")
print("Your frames are here:", frames_path)
print("Your video is here:", video_path)
return output_path, fps, video_path
class KeyFrameMatcher:
def __init__(self):
pass
def extract_number_from_filename(self, file_name):
result = []
number = -1
for i in file_name:
if ord(i)>=ord("0") and ord(i)<=ord("9"):
if number == -1:
number = 0
number = number*10 + ord(i) - ord("0")
else:
if number != -1:
result.append(number)
number = -1
if number != -1:
result.append(number)
result = tuple(result)
return result
def extract_number_from_filenames(self, file_names):
numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
min_length = min(len(i) for i in numbers)
for i in range(min_length-1, -1, -1):
if len(set(number[i] for number in numbers))==len(file_names):
return [number[i] for number in numbers]
return list(range(len(file_names)))
def match_using_filename(self, file_names_a, file_names_b):
file_names_b_set = set(file_names_b)
matched_file_name = []
for file_name in file_names_a:
if file_name not in file_names_b_set:
matched_file_name.append(None)
else:
matched_file_name.append(file_name)
return matched_file_name
def match_using_numbers(self, file_names_a, file_names_b):
numbers_a = self.extract_number_from_filenames(file_names_a)
numbers_b = self.extract_number_from_filenames(file_names_b)
numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
matched_file_name = []
for number in numbers_a:
if number in numbers_b_dict:
matched_file_name.append(numbers_b_dict[number])
else:
matched_file_name.append(None)
return matched_file_name
def match_filenames(self, file_names_a, file_names_b):
matched_file_name = self.match_using_filename(file_names_a, file_names_b)
if sum([i is not None for i in matched_file_name]) > 0:
return matched_file_name
matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
return matched_file_name
def detect_frames(frames_path, keyframes_path):
if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
return "Please input the directory of guide video and rendered frames"
elif not os.path.exists(frames_path):
return "Please input the directory of guide video"
elif not os.path.exists(keyframes_path):
return "Please input the directory of rendered frames"
frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
if len(frames)==0:
return f"No images detected in {frames_path}"
if len(keyframes)==0:
return f"No images detected in {keyframes_path}"
matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
max_filename_length = max([len(i) for i in frames])
if sum([i is not None for i in matched_keyframes])==0:
message = ""
for frame, matched_keyframe in zip(frames, matched_keyframes):
message += frame + " " * (max_filename_length - len(frame) + 1)
message += "--> No matched keyframes\n"
else:
message = ""
for frame, matched_keyframe in zip(frames, matched_keyframes):
message += frame + " " * (max_filename_length - len(frame) + 1)
if matched_keyframe is None:
message += "--> [to be rendered]\n"
else:
message += f"--> {matched_keyframe}\n"
return message
def check_input_for_interpolating(frames_path, keyframes_path):
# search for images
frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
# match frames
matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
file_list = [file_name for file_name in matched_keyframes if file_name is not None]
index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
frames_guide = VideoData(None, frames_path)
frames_style = VideoData(None, keyframes_path, file_list=file_list)
# match shape
message = ""
height_guide, width_guide = frames_guide.shape()
height_style, width_style = frames_style.shape()
if height_guide != height_style or width_guide != width_style:
message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
frames_style.set_shape(height_guide, width_guide)
return frames_guide, frames_style, index_style, message
def interpolate_video(
frames_path,
keyframes_path,
output_path,
fps,
batch_size,
tracking_window_size,
minimum_patch_size,
num_iter,
guide_weight,
initialize,
progress = None,
):
# input
frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
if len(message) > 0:
print(message)
# output
if output_path == "":
output_path = os.path.join(keyframes_path, "output")
os.makedirs(output_path, exist_ok=True)
print("No valid output_path. Your video will be saved here:", output_path)
elif not os.path.exists(output_path):
os.makedirs(output_path, exist_ok=True)
print("Your video will be saved here:", output_path)
output_frames_path = os.path.join(output_path, "frames")
output_video_path = os.path.join(output_path, "video.mp4")
os.makedirs(output_frames_path, exist_ok=True)
# process
ebsynth_config = {
"minimum_patch_size": minimum_patch_size,
"threads_per_block": 8,
"num_iter": num_iter,
"gpu_id": 0,
"guide_weight": guide_weight,
"initialize": initialize,
"tracking_window_size": tracking_window_size
}
if len(index_style)==1:
InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
else:
InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
try:
fps = int(fps)
except:
fps = 30
print("Fps:", fps)
print("Saving video...")
video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
print("Success!")
print("Your frames are here:", output_frames_path)
print("Your video is here:", video_path)
return output_path, fps, video_path
def on_ui_tabs():
with gr.Blocks(analytics_enabled=False) as ui_component:
with gr.Tab("Blend"):
gr.Markdown("""
# Blend
Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution.
""")
with gr.Row():
with gr.Column():
with gr.Tab("Guide video"):
video_guide = gr.Video(label="Guide video")
with gr.Tab("Guide video (images format)"):
video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
with gr.Column():
with gr.Tab("Style video"):
video_style = gr.Video(label="Style video")
with gr.Tab("Style video (images format)"):
video_style_folder = gr.Textbox(label="Style video (images format)", value="")
with gr.Column():
output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
btn = gr.Button(value="Blend")
with gr.Row():
with gr.Column():
gr.Markdown("# Settings")
mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
gr.Markdown("## Advanced Settings")
minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
with gr.Column():
gr.Markdown("""
# Reference
* Output directory: the directory to save the video.
* Inference mode
|Mode|Time|Memory|Quality|Frame by frame output|Description|
|-|-|-|-|-|-|
|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
|Balanced|■■|■|■■|Yes|Blend the frames naively.|
|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is the best.|
* Sliding window size: our algorithm will blend the frames in a sliding windows. If the size is n, each frame will be blended with the last n frames and the next n frames. A large sliding window can make the video fluent but sometimes smoggy.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
* Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
* Number of iterations: the number of iterations of patch matching. (Default: 5)
* Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
* NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
""")
btn.click(
smooth_video,
inputs=[
video_guide,
video_guide_folder,
video_style,
video_style_folder,
mode,
window_size,
batch_size,
tracking_window_size,
output_path,
fps,
minimum_patch_size,
num_iter,
guide_weight,
initialize
],
outputs=[output_path, fps, video_output]
)
with gr.Tab("Interpolate"):
gr.Markdown("""
# Interpolate
Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see the example. The algorithm is experimental and is only tested for 512*512 resolution.
""")
with gr.Row():
with gr.Column():
with gr.Row():
with gr.Column():
video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
with gr.Column():
rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
with gr.Row():
detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
with gr.Column():
output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
btn_ = gr.Button(value="Interpolate")
with gr.Row():
with gr.Column():
gr.Markdown("# Settings")
batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
gr.Markdown("## Advanced Settings")
minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
with gr.Column():
gr.Markdown("""
# Reference
* Output directory: the directory to save the video.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
* Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
* Number of iterations: the number of iterations of patch matching. (Default: 5)
* Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
* NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
""")
btn_.click(
interpolate_video,
inputs=[
video_guide_folder_,
rendered_keyframes_,
output_path_,
fps_,
batch_size_,
tracking_window_size_,
minimum_patch_size_,
num_iter_,
guide_weight_,
initialize_,
],
outputs=[output_path_, fps_, video_output_]
)
return [(ui_component, "FastBlend", "FastBlend_ui")]

View File

@@ -1,119 +0,0 @@
import cupy as cp
remapping_kernel = cp.RawKernel(r'''
extern "C" __global__
void remap(
const int height,
const int width,
const int channel,
const int patch_size,
const int pad_size,
const float* source_style,
const int* nnf,
float* target_style
) {
const int r = (patch_size - 1) / 2;
const int x = blockDim.x * blockIdx.x + threadIdx.x;
const int y = blockDim.y * blockIdx.y + threadIdx.y;
if (x >= height or y >= width) return;
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
const int min_px = x < r ? -x : -r;
const int max_px = x + r > height - 1 ? height - 1 - x : r;
const int min_py = y < r ? -y : -r;
const int max_py = y + r > width - 1 ? width - 1 - y : r;
int num = 0;
for (int px = min_px; px <= max_px; px++){
for (int py = min_py; py <= max_py; py++){
const int nid = (x + px) * width + y + py;
const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
num++;
for (int c = 0; c < channel; c++){
target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
}
}
}
for (int c = 0; c < channel; c++){
target_style[z + pid * channel + c] /= num;
}
}
''', 'remap')
patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void patch_error(
const int height,
const int width,
const int channel,
const int patch_size,
const int pad_size,
const float* source,
const int* nnf,
const float* target,
float* error
) {
const int r = (patch_size - 1) / 2;
const int x = blockDim.x * blockIdx.x + threadIdx.x;
const int y = blockDim.y * blockIdx.y + threadIdx.y;
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
if (x >= height or y >= width) return;
const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
float e = 0;
for (int px = -r; px <= r; px++){
for (int py = -r; py <= r; py++){
const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
for (int c = 0; c < channel; c++){
const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
e += diff * diff;
}
}
}
error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'patch_error')
pairwise_patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void pairwise_patch_error(
const int height,
const int width,
const int channel,
const int patch_size,
const int pad_size,
const float* source_a,
const int* nnf_a,
const float* source_b,
const int* nnf_b,
float* error
) {
const int r = (patch_size - 1) / 2;
const int x = blockDim.x * blockIdx.x + threadIdx.x;
const int y = blockDim.y * blockIdx.y + threadIdx.y;
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
if (x >= height or y >= width) return;
const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
const int x_a = nnf_a[z_nnf + 0];
const int y_a = nnf_a[z_nnf + 1];
const int x_b = nnf_b[z_nnf + 0];
const int y_b = nnf_b[z_nnf + 1];
float e = 0;
for (int px = -r; px <= r; px++){
for (int py = -r; py <= r; py++){
const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
for (int c = 0; c < channel; c++){
const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
e += diff * diff;
}
}
}
error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'pairwise_patch_error')

View File

@@ -1,146 +0,0 @@
import imageio, os
import numpy as np
from PIL import Image
def read_video(file_name):
reader = imageio.get_reader(file_name)
video = []
for frame in reader:
frame = np.array(frame)
video.append(frame)
reader.close()
return video
def get_video_fps(file_name):
reader = imageio.get_reader(file_name)
fps = reader.get_meta_data()["fps"]
reader.close()
return fps
def save_video(frames_path, video_path, num_frames, fps):
writer = imageio.get_writer(video_path, fps=fps, quality=9)
for i in range(num_frames):
frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
writer.append_data(frame)
writer.close()
return video_path
class LowMemoryVideo:
def __init__(self, file_name):
self.reader = imageio.get_reader(file_name)
def __len__(self):
return self.reader.count_frames()
def __getitem__(self, item):
return np.array(self.reader.get_data(item))
def __del__(self):
self.reader.close()
def split_file_name(file_name):
result = []
number = -1
for i in file_name:
if ord(i)>=ord("0") and ord(i)<=ord("9"):
if number == -1:
number = 0
number = number*10 + ord(i) - ord("0")
else:
if number != -1:
result.append(number)
number = -1
result.append(i)
if number != -1:
result.append(number)
result = tuple(result)
return result
def search_for_images(folder):
file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
file_list = [i[1] for i in sorted(file_list)]
file_list = [os.path.join(folder, i) for i in file_list]
return file_list
def read_images(folder):
file_list = search_for_images(folder)
frames = [np.array(Image.open(i)) for i in file_list]
return frames
class LowMemoryImageFolder:
def __init__(self, folder, file_list=None):
if file_list is None:
self.file_list = search_for_images(folder)
else:
self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
def __len__(self):
return len(self.file_list)
def __getitem__(self, item):
return np.array(Image.open(self.file_list[item]))
def __del__(self):
pass
class VideoData:
def __init__(self, video_file, image_folder, **kwargs):
if video_file is not None:
self.data_type = "video"
self.data = LowMemoryVideo(video_file, **kwargs)
elif image_folder is not None:
self.data_type = "images"
self.data = LowMemoryImageFolder(image_folder, **kwargs)
else:
raise ValueError("Cannot open video or image folder")
self.length = None
self.height = None
self.width = None
def raw_data(self):
frames = []
for i in range(self.__len__()):
frames.append(self.__getitem__(i))
return frames
def set_length(self, length):
self.length = length
def set_shape(self, height, width):
self.height = height
self.width = width
def __len__(self):
if self.length is None:
return len(self.data)
else:
return self.length
def shape(self):
if self.height is not None and self.width is not None:
return self.height, self.width
else:
height, width, _ = self.__getitem__(0).shape
return height, width
def __getitem__(self, item):
frame = self.data.__getitem__(item)
height, width, _ = frame.shape
if self.height is not None and self.width is not None:
if self.height != height or self.width != width:
frame = Image.fromarray(frame).resize((self.width, self.height))
frame = np.array(frame)
return frame
def __del__(self):
pass

View File

@@ -1,298 +0,0 @@
from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
import numpy as np
import cupy as cp
import cv2
class PatchMatcher:
def __init__(
self, height, width, channel, minimum_patch_size,
threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
random_search_steps=3, random_search_range=4,
use_mean_target_style=False, use_pairwise_patch_error=False,
tracking_window_size=0
):
self.height = height
self.width = width
self.channel = channel
self.minimum_patch_size = minimum_patch_size
self.threads_per_block = threads_per_block
self.num_iter = num_iter
self.gpu_id = gpu_id
self.guide_weight = guide_weight
self.random_search_steps = random_search_steps
self.random_search_range = random_search_range
self.use_mean_target_style = use_mean_target_style
self.use_pairwise_patch_error = use_pairwise_patch_error
self.tracking_window_size = tracking_window_size
self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
self.pad_size = self.patch_size_list[0] // 2
self.grid = (
(height + threads_per_block - 1) // threads_per_block,
(width + threads_per_block - 1) // threads_per_block
)
self.block = (threads_per_block, threads_per_block)
def pad_image(self, image):
return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
def unpad_image(self, image):
return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
def apply_nnf_to_image(self, nnf, source):
batch_size = source.shape[0]
target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
remapping_kernel(
self.grid + (batch_size,),
self.block,
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
)
return target
def get_patch_error(self, source, nnf, target):
batch_size = source.shape[0]
error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
patch_error_kernel(
self.grid + (batch_size,),
self.block,
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
)
return error
def get_pairwise_patch_error(self, source, nnf):
batch_size = source.shape[0]//2
error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
pairwise_patch_error_kernel(
self.grid + (batch_size,),
self.block,
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
)
error = error.repeat(2, axis=0)
return error
def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
error_guide = self.get_patch_error(source_guide, nnf, target_guide)
if self.use_mean_target_style:
target_style = self.apply_nnf_to_image(nnf, source_style)
target_style = target_style.mean(axis=0, keepdims=True)
target_style = target_style.repeat(source_guide.shape[0], axis=0)
if self.use_pairwise_patch_error:
error_style = self.get_pairwise_patch_error(source_style, nnf)
else:
error_style = self.get_patch_error(source_style, nnf, target_style)
error = error_guide * self.guide_weight + error_style
return error
def clamp_bound(self, nnf):
nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
return nnf
def random_step(self, nnf, r):
batch_size = nnf.shape[0]
step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
upd_nnf = self.clamp_bound(nnf + step)
return upd_nnf
def neighboor_step(self, nnf, d):
if d==0:
upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
upd_nnf[:, :, :, 0] += 1
elif d==1:
upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
upd_nnf[:, :, :, 1] += 1
elif d==2:
upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
upd_nnf[:, :, :, 0] -= 1
elif d==3:
upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
upd_nnf[:, :, :, 1] -= 1
upd_nnf = self.clamp_bound(upd_nnf)
return upd_nnf
def shift_nnf(self, nnf, d):
if d>0:
d = min(nnf.shape[0], d)
upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
else:
d = max(-nnf.shape[0], d)
upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
return upd_nnf
def track_step(self, nnf, d):
if self.use_pairwise_patch_error:
upd_nnf = cp.zeros_like(nnf)
upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
else:
upd_nnf = self.shift_nnf(nnf, d)
return upd_nnf
def C(self, n, m):
# not used
c = 1
for i in range(1, n+1):
c *= i
for i in range(1, m+1):
c //= i
for i in range(1, n-m+1):
c //= i
return c
def bezier_step(self, nnf, r):
# not used
n = r * 2 - 1
upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
if d>0:
ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
elif d<0:
ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
return upd_nnf
def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
upd_idx = (upd_err < err)
nnf[upd_idx] = upd_nnf[upd_idx]
err[upd_idx] = upd_err[upd_idx]
return nnf, err
def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
for d in cp.random.permutation(4):
upd_nnf = self.neighboor_step(nnf, d)
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
return nnf, err
def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
for i in range(self.random_search_steps):
upd_nnf = self.random_step(nnf, self.random_search_range)
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
return nnf, err
def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
for d in range(1, self.tracking_window_size + 1):
upd_nnf = self.track_step(nnf, d)
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
upd_nnf = self.track_step(nnf, -d)
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
return nnf, err
def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
return nnf, err
def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
with cp.cuda.Device(self.gpu_id):
source_guide = self.pad_image(source_guide)
target_guide = self.pad_image(target_guide)
source_style = self.pad_image(source_style)
for it in range(self.num_iter):
self.patch_size = self.patch_size_list[it]
target_style = self.apply_nnf_to_image(nnf, source_style)
err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
return nnf, target_style
class PyramidPatchMatcher:
def __init__(
self, image_height, image_width, channel, minimum_patch_size,
threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
use_mean_target_style=False, use_pairwise_patch_error=False,
tracking_window_size=0,
initialize="identity"
):
maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
self.pyramid_heights = []
self.pyramid_widths = []
self.patch_matchers = []
self.minimum_patch_size = minimum_patch_size
self.num_iter = num_iter
self.gpu_id = gpu_id
self.initialize = initialize
for level in range(self.pyramid_level):
height = image_height//(2**(self.pyramid_level - 1 - level))
width = image_width//(2**(self.pyramid_level - 1 - level))
self.pyramid_heights.append(height)
self.pyramid_widths.append(width)
self.patch_matchers.append(PatchMatcher(
height, width, channel, minimum_patch_size=minimum_patch_size,
threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
tracking_window_size=tracking_window_size
))
def resample_image(self, images, level):
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
images = images.get()
images_resample = []
for image in images:
image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
images_resample.append(image_resample)
images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
return images_resample
def initialize_nnf(self, batch_size):
if self.initialize == "random":
height, width = self.pyramid_heights[0], self.pyramid_widths[0]
nnf = cp.stack([
cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
], axis=3)
elif self.initialize == "identity":
height, width = self.pyramid_heights[0], self.pyramid_widths[0]
nnf = cp.stack([
cp.repeat(cp.arange(height), width).reshape(height, width),
cp.tile(cp.arange(width), height).reshape(height, width)
], axis=2)
nnf = cp.stack([nnf] * batch_size)
else:
raise NotImplementedError()
return nnf
def update_nnf(self, nnf, level):
# upscale
nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
# check if scale is 2
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
nnf = nnf.get().astype(np.float32)
nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
nnf = cp.array(np.stack(nnf), dtype=cp.int32)
nnf = self.patch_matchers[level].clamp_bound(nnf)
return nnf
def apply_nnf_to_image(self, nnf, image):
with cp.cuda.Device(self.gpu_id):
image = self.patch_matchers[-1].pad_image(image)
image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
return image
def estimate_nnf(self, source_guide, target_guide, source_style):
with cp.cuda.Device(self.gpu_id):
if not isinstance(source_guide, cp.ndarray):
source_guide = cp.array(source_guide, dtype=cp.float32)
if not isinstance(target_guide, cp.ndarray):
target_guide = cp.array(target_guide, dtype=cp.float32)
if not isinstance(source_style, cp.ndarray):
source_style = cp.array(source_style, dtype=cp.float32)
for level in range(self.pyramid_level):
nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
source_guide_ = self.resample_image(source_guide, level)
target_guide_ = self.resample_image(target_guide, level)
source_style_ = self.resample_image(source_style, level)
nnf, target_style = self.patch_matchers[level].estimate_nnf(
source_guide_, target_guide_, source_style_, nnf
)
return nnf.get(), target_style.get()

View File

@@ -1,4 +0,0 @@
from .accurate import AccurateModeRunner
from .fast import FastModeRunner
from .balanced import BalancedModeRunner
from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner

View File

@@ -1,35 +0,0 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
class AccurateModeRunner:
def __init__(self):
pass
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
patch_match_engine = PyramidPatchMatcher(
image_height=frames_style[0].shape[0],
image_width=frames_style[0].shape[1],
channel=3,
use_mean_target_style=True,
**ebsynth_config
)
# run
n = len(frames_style)
for target in tqdm(range(n), desc=desc):
l, r = max(target - window_size, 0), min(target + window_size + 1, n)
remapped_frames = []
for i in range(l, r, batch_size):
j = min(i + batch_size, r)
source_guide = np.stack([frames_guide[source] for source in range(i, j)])
target_guide = np.stack([frames_guide[target]] * (j - i))
source_style = np.stack([frames_style[source] for source in range(i, j)])
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
remapped_frames.append(target_style)
frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
frame = frame.clip(0, 255).astype("uint8")
if save_path is not None:
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))

View File

@@ -1,46 +0,0 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
class BalancedModeRunner:
def __init__(self):
pass
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
patch_match_engine = PyramidPatchMatcher(
image_height=frames_style[0].shape[0],
image_width=frames_style[0].shape[1],
channel=3,
**ebsynth_config
)
# tasks
n = len(frames_style)
tasks = []
for target in range(n):
for source in range(target - window_size, target + window_size + 1):
if source >= 0 and source < n and source != target:
tasks.append((source, target))
# run
frames = [(None, 1) for i in range(n)]
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
source_style = np.stack([frames_style[source] for source, target in tasks_batch])
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
for (source, target), result in zip(tasks_batch, target_style):
frame, weight = frames[target]
if frame is None:
frame = frames_style[target]
frames[target] = (
frame * (weight / (weight + 1)) + result / (weight + 1),
weight + 1
)
if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
frame = frame.clip(0, 255).astype("uint8")
if save_path is not None:
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
frames[target] = (None, 1)

View File

@@ -1,141 +0,0 @@
from ..patch_match import PyramidPatchMatcher
import functools, os
import numpy as np
from PIL import Image
from tqdm import tqdm
class TableManager:
def __init__(self):
pass
def task_list(self, n):
tasks = []
max_level = 1
while (1<<max_level)<=n:
max_level += 1
for i in range(n):
j = i
for level in range(max_level):
if i&(1<<level):
continue
j |= 1<<level
if j>=n:
break
meta_data = {
"source": i,
"target": j,
"level": level + 1
}
tasks.append(meta_data)
tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
return tasks
def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
n = len(frames_guide)
tasks = self.task_list(n)
remapping_table = [[(frames_style[i], 1)] for i in range(n)]
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
for task, result in zip(tasks_batch, target_style):
target, level = task["target"], task["level"]
if len(remapping_table[target])==level:
remapping_table[target].append((result, 1))
else:
frame, weight = remapping_table[target][level]
remapping_table[target][level] = (
frame * (weight / (weight + 1)) + result / (weight + 1),
weight + 1
)
return remapping_table
def remapping_table_to_blending_table(self, table):
for i in range(len(table)):
for j in range(1, len(table[i])):
frame_1, weight_1 = table[i][j-1]
frame_2, weight_2 = table[i][j]
frame = (frame_1 + frame_2) / 2
weight = weight_1 + weight_2
table[i][j] = (frame, weight)
return table
def tree_query(self, leftbound, rightbound):
node_list = []
node_index = rightbound
while node_index>=leftbound:
node_level = 0
while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
node_level += 1
node_list.append((node_index, node_level))
node_index -= 1<<node_level
return node_list
def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
n = len(blending_table)
tasks = []
frames_result = []
for target in range(n):
node_list = self.tree_query(max(target-window_size, 0), target)
for source, level in node_list:
if source!=target:
meta_data = {
"source": source,
"target": target,
"level": level
}
tasks.append(meta_data)
else:
frames_result.append(blending_table[target][level])
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
for task, frame_2 in zip(tasks_batch, target_style):
source, target, level = task["source"], task["target"], task["level"]
frame_1, weight_1 = frames_result[target]
weight_2 = blending_table[source][level][1]
weight = weight_1 + weight_2
frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
frames_result[target] = (frame, weight)
return frames_result
class FastModeRunner:
def __init__(self):
pass
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
frames_guide = frames_guide.raw_data()
frames_style = frames_style.raw_data()
table_manager = TableManager()
patch_match_engine = PyramidPatchMatcher(
image_height=frames_style[0].shape[0],
image_width=frames_style[0].shape[1],
channel=3,
**ebsynth_config
)
# left part
table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
table_l = table_manager.remapping_table_to_blending_table(table_l)
table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
# right part
table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
table_r = table_manager.remapping_table_to_blending_table(table_r)
table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
# merge
frames = []
for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
weight_m = -1
weight = weight_l + weight_m + weight_r
frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
frames.append(frame)
frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
if save_path is not None:
for target, frame in enumerate(frames):
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))

View File

@@ -1,121 +0,0 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
class InterpolationModeRunner:
def __init__(self):
pass
def get_index_dict(self, index_style):
index_dict = {}
for i, index in enumerate(index_style):
index_dict[index] = i
return index_dict
def get_weight(self, l, m, r):
weight_l, weight_r = abs(m - r), abs(m - l)
if weight_l + weight_r == 0:
weight_l, weight_r = 0.5, 0.5
else:
weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
return weight_l, weight_r
def get_task_group(self, index_style, n):
task_group = []
index_style = sorted(index_style)
# first frame
if index_style[0]>0:
tasks = []
for m in range(index_style[0]):
tasks.append((index_style[0], m, index_style[0]))
task_group.append(tasks)
# middle frames
for l, r in zip(index_style[:-1], index_style[1:]):
tasks = []
for m in range(l, r):
tasks.append((l, m, r))
task_group.append(tasks)
# last frame
tasks = []
for m in range(index_style[-1], n):
tasks.append((index_style[-1], m, index_style[-1]))
task_group.append(tasks)
return task_group
def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
patch_match_engine = PyramidPatchMatcher(
image_height=frames_style[0].shape[0],
image_width=frames_style[0].shape[1],
channel=3,
use_mean_target_style=False,
use_pairwise_patch_error=True,
**ebsynth_config
)
# task
index_dict = self.get_index_dict(index_style)
task_group = self.get_task_group(index_style, len(frames_guide))
# run
for tasks in task_group:
index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
source_guide, target_guide, source_style = [], [], []
for l, m, r in tasks_batch:
# l -> m
source_guide.append(frames_guide[l])
target_guide.append(frames_guide[m])
source_style.append(frames_style[index_dict[l]])
# r -> m
source_guide.append(frames_guide[r])
target_guide.append(frames_guide[m])
source_style.append(frames_style[index_dict[r]])
source_guide = np.stack(source_guide)
target_guide = np.stack(target_guide)
source_style = np.stack(source_style)
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
if save_path is not None:
for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
weight_l, weight_r = self.get_weight(l, m, r)
frame = frame_l * weight_l + frame_r * weight_r
frame = frame.clip(0, 255).astype("uint8")
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
class InterpolationModeSingleFrameRunner:
def __init__(self):
pass
def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
# check input
tracking_window_size = ebsynth_config["tracking_window_size"]
if tracking_window_size * 2 >= batch_size:
raise ValueError("batch_size should be larger than track_window_size * 2")
frame_style = frames_style[0]
frame_guide = frames_guide[index_style[0]]
patch_match_engine = PyramidPatchMatcher(
image_height=frame_style.shape[0],
image_width=frame_style.shape[1],
channel=3,
**ebsynth_config
)
# run
frame_id, n = 0, len(frames_guide)
for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
if i + batch_size > n:
l, r = max(n - batch_size, 0), n
else:
l, r = i, i + batch_size
source_guide = np.stack([frame_guide] * (r-l))
target_guide = np.stack([frames_guide[i] for i in range(l, r)])
source_style = np.stack([frame_style] * (r-l))
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
for i, frame in zip(range(l, r), target_style):
if i==frame_id:
frame = frame.clip(0, 255).astype("uint8")
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
frame_id += 1
if r < n and r-frame_id <= tracking_window_size:
break

View File

@@ -1,241 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
def warp(tenInput, tenFlow, device):
backwarp_tenGrid = {}
k = (str(tenFlow.device), str(tenFlow.size()))
if k not in backwarp_tenGrid:
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
backwarp_tenGrid[k] = torch.cat(
[tenHorizontal, tenVertical], 1).to(device)
tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
class IFBlock(nn.Module):
def __init__(self, in_planes, c=64):
super(IFBlock, self).__init__()
self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))
def forward(self, x, flow, scale=1):
x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
feat = self.conv0(torch.cat((x, flow), 1))
feat = self.convblock0(feat) + feat
feat = self.convblock1(feat) + feat
feat = self.convblock2(feat) + feat
feat = self.convblock3(feat) + feat
flow = self.conv1(feat)
mask = self.conv2(feat)
flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
return flow, mask
class IFNet(nn.Module):
def __init__(self):
super(IFNet, self).__init__()
self.block0 = IFBlock(7+4, c=90)
self.block1 = IFBlock(7+4, c=90)
self.block2 = IFBlock(7+4, c=90)
self.block_tea = IFBlock(10+4, c=90)
def forward(self, x, scale_list=[4, 2, 1], training=False):
if training == False:
channel = x.shape[1] // 2
img0 = x[:, :channel]
img1 = x[:, channel:]
flow_list = []
merged = []
mask_list = []
warped_img0 = img0
warped_img1 = img1
flow = (x[:, :4]).detach() * 0
mask = (x[:, :1]).detach() * 0
block = [self.block0, self.block1, self.block2]
for i in range(3):
f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
mask = mask + (m0 + (-m1)) / 2
mask_list.append(mask)
flow_list.append(flow)
warped_img0 = warp(img0, flow[:, :2], device=x.device)
warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
merged.append((warped_img0, warped_img1))
'''
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
res = tmp[:, 1:4] * 2 - 1
'''
for i in range(3):
mask_list[i] = torch.sigmoid(mask_list[i])
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
return flow_list, mask_list[2], merged
def state_dict_converter(self):
return IFNetStateDictConverter()
class IFNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
class RIFEInterpolater:
def __init__(self, model, device="cuda"):
self.model = model
self.device = device
# IFNet only does not support float16
self.torch_dtype = torch.float32
@staticmethod
def from_model_manager(model_manager):
return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)
def process_image(self, image):
width, height = image.size
if width % 32 != 0 or height % 32 != 0:
width = (width + 31) // 32
height = (height + 31) // 32
image = image.resize((width, height))
image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
return image
def process_images(self, images):
images = [self.process_image(image) for image in images]
images = torch.stack(images)
return images
def decode_images(self, images):
images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
images = [Image.fromarray(image) for image in images]
return images
def add_interpolated_images(self, images, interpolated_images):
output_images = []
for image, interpolated_image in zip(images, interpolated_images):
output_images.append(image)
output_images.append(interpolated_image)
output_images.append(images[-1])
return output_images
@torch.no_grad()
def interpolate_(self, images, scale=1.0):
input_tensor = self.process_images(images)
input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
output_images = self.decode_images(merged[2].cpu())
if output_images[0].size != images[0].size:
output_images = [image.resize(images[0].size) for image in output_images]
return output_images
@torch.no_grad()
def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
# Preprocess
processed_images = self.process_images(images)
for iter in range(num_iter):
# Input
input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)
# Interpolate
output_tensor = []
for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
batch_input_tensor = input_tensor[batch_id: batch_id_]
batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
output_tensor.append(merged[2].cpu())
# Output
output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
processed_images = self.add_interpolated_images(processed_images, output_tensor)
processed_images = torch.stack(processed_images)
# To images
output_images = self.decode_images(processed_images)
if output_images[0].size != images[0].size:
output_images = [image.resize(images[0].size) for image in output_images]
return output_images
class RIFESmoother(RIFEInterpolater):
def __init__(self, model, device="cuda"):
super(RIFESmoother, self).__init__(model, device=device)
@staticmethod
def from_model_manager(model_manager):
return RIFESmoother(model_manager.RIFE, device=model_manager.device)
def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
output_tensor = []
for batch_id in range(0, input_tensor.shape[0], batch_size):
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
batch_input_tensor = input_tensor[batch_id: batch_id_]
batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
output_tensor.append(merged[2].cpu())
output_tensor = torch.concat(output_tensor, dim=0)
return output_tensor
@torch.no_grad()
def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
# Preprocess
processed_images = self.process_images(rendered_frames)
for iter in range(num_iter):
# Input
input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
# Interpolate
output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
# Blend
input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
# Add to frames
processed_images[1:-1] = output_tensor
# To images
output_images = self.decode_images(processed_images)
if output_images[0].size != rendered_frames[0].size:
output_images = [image.resize(rendered_frames[0].size) for image in output_images]
return output_images

View File

@@ -1,482 +0,0 @@
import torch, os
from safetensors import safe_open
from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder
from .sd_lora import SDLoRA
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet
from .sdxl_vae_decoder import SDXLVAEDecoder
from .sdxl_vae_encoder import SDXLVAEEncoder
from .sd_controlnet import SDControlNet
from .sd_motion import SDMotionModel
from .sdxl_motion import SDXLMotionModel
from .svd_image_encoder import SVDImageEncoder
from .svd_unet import SVDUNet
from .svd_vae_decoder import SVDVAEDecoder
from .svd_vae_encoder import SVDVAEEncoder
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
class ModelManager:
def __init__(self, torch_dtype=torch.float16, device="cuda"):
self.torch_dtype = torch_dtype
self.device = device
self.model = {}
self.model_path = {}
self.textual_inversion_dict = {}
def is_stable_video_diffusion(self, state_dict):
param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
return param_name in state_dict
def is_RIFE(self, state_dict):
param_name = "block_tea.convblock3.0.1.weight"
return param_name in state_dict or ("module." + param_name) in state_dict
def is_beautiful_prompt(self, state_dict):
param_name = "transformer.h.9.self_attention.query_key_value.weight"
return param_name in state_dict
def is_stabe_diffusion_xl(self, state_dict):
param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
return param_name in state_dict
def is_stable_diffusion(self, state_dict):
if self.is_stabe_diffusion_xl(state_dict):
return False
param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
return param_name in state_dict
def is_controlnet(self, state_dict):
param_name = "control_model.time_embed.0.weight"
param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format
return param_name in state_dict or param_name_2 in state_dict
def is_animatediff(self, state_dict):
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
return param_name in state_dict
def is_animatediff_xl(self, state_dict):
param_name = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight"
return param_name in state_dict
def is_sd_lora(self, state_dict):
param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
return param_name in state_dict
def is_translator(self, state_dict):
param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
return param_name in state_dict and len(state_dict) == 254
def is_ipadapter(self, state_dict):
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([3072, 1024])
def is_ipadapter_image_encoder(self, state_dict):
param_name = "vision_model.encoder.layers.31.self_attn.v_proj.weight"
return param_name in state_dict and len(state_dict) == 521
def is_ipadapter_xl(self, state_dict):
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([8192, 1280])
def is_ipadapter_xl_image_encoder(self, state_dict):
param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
return param_name in state_dict and len(state_dict) == 777
def is_hunyuan_dit_clip_text_encoder(self, state_dict):
param_name = "bert.encoder.layer.23.attention.output.dense.weight"
return param_name in state_dict
def is_hunyuan_dit_t5_text_encoder(self, state_dict):
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
return param_name in state_dict
def is_hunyuan_dit(self, state_dict):
param_name = "final_layer.adaLN_modulation.1.weight"
return param_name in state_dict
def is_diffusers_vae(self, state_dict):
param_name = "quant_conv.weight"
return param_name in state_dict
def is_ExVideo_StableVideoDiffusion(self, state_dict):
param_name = "blocks.185.positional_embedding.embeddings"
return param_name in state_dict
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
component_dict = {
"image_encoder": SVDImageEncoder,
"unet": SVDUNet,
"vae_decoder": SVDVAEDecoder,
"vae_encoder": SVDVAEEncoder,
}
if components is None:
components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
for component in components:
if component == "unet":
self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
else:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
component_dict = {
"text_encoder": SDTextEncoder,
"unet": SDUNet,
"vae_decoder": SDVAEDecoder,
"vae_encoder": SDVAEEncoder,
"refiner": SDXLUNet,
}
if components is None:
components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
for component in components:
if component == "text_encoder":
# Add additional token embeddings to text encoder
token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]]
for keyword in self.textual_inversion_dict:
_, embeddings = self.textual_inversion_dict[keyword]
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
token_embeddings = torch.concat(token_embeddings, dim=0)
state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
else:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
component_dict = {
"text_encoder": SDXLTextEncoder,
"text_encoder_2": SDXLTextEncoder2,
"unet": SDXLUNet,
"vae_decoder": SDXLVAEDecoder,
"vae_encoder": SDXLVAEEncoder,
}
if components is None:
components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
for component in components:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
if component in ["vae_decoder", "vae_encoder"]:
# These two model will output nan when float16 is enabled.
# The precision problem happens in the last three resnet blocks.
# I do not know how to solve this problem.
self.model[component].to(torch.float32).to(self.device)
else:
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_controlnet(self, state_dict, file_path=""):
component = "controlnet"
if component not in self.model:
self.model[component] = []
self.model_path[component] = []
model = SDControlNet()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component].append(model)
self.model_path[component].append(file_path)
def load_animatediff(self, state_dict, file_path=""):
component = "motion_modules"
model = SDMotionModel()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_animatediff_xl(self, state_dict, file_path=""):
component = "motion_modules_xl"
model = SDXLMotionModel()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_beautiful_prompt(self, state_dict, file_path=""):
component = "beautiful_prompt"
from transformers import AutoModelForCausalLM
model_folder = os.path.dirname(file_path)
model = AutoModelForCausalLM.from_pretrained(
model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
).to(self.device).eval()
self.model[component] = model
self.model_path[component] = file_path
def load_RIFE(self, state_dict, file_path=""):
component = "RIFE"
from ..extensions.RIFE import IFNet
model = IFNet().eval()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(torch.float32).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_sd_lora(self, state_dict, alpha):
SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
def load_translator(self, state_dict, file_path=""):
# This model is lightweight, we do not place it on GPU.
component = "translator"
from transformers import AutoModelForSeq2SeqLM
model_folder = os.path.dirname(file_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter(self, state_dict, file_path=""):
component = "ipadapter"
model = SDIpAdapter()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter_image_encoder(self, state_dict, file_path=""):
component = "ipadapter_image_encoder"
model = IpAdapterCLIPImageEmbedder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter_xl(self, state_dict, file_path=""):
component = "ipadapter_xl"
model = SDXLIpAdapter()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
component = "ipadapter_xl_image_encoder"
model = IpAdapterXLCLIPImageEmbedder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
component = "hunyuan_dit_clip_text_encoder"
model = HunyuanDiTCLIPTextEncoder()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
component = "hunyuan_dit_t5_text_encoder"
model = HunyuanDiTT5TextEncoder()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit(self, state_dict, file_path=""):
component = "hunyuan_dit"
model = HunyuanDiT()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_diffusers_vae(self, state_dict, file_path=""):
# TODO: detect SD and SDXL
component = "vae_encoder"
model = SDXLVAEEncoder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
component = "vae_decoder"
model = SDXLVAEDecoder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
unet_state_dict = self.model["unet"].state_dict()
self.model["unet"].to("cpu")
del self.model["unet"]
add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
self.model["unet"].load_state_dict(unet_state_dict, strict=False)
self.model["unet"].load_state_dict(state_dict, strict=False)
self.model["unet"].to(self.torch_dtype).to(self.device)
def search_for_embeddings(self, state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += self.search_for_embeddings(state_dict[k])
return embeddings
def load_textual_inversions(self, folder):
# Store additional tokens here
self.textual_inversion_dict = {}
# Load every textual inversion file
for file_name in os.listdir(folder):
if file_name.endswith(".txt"):
continue
keyword = os.path.splitext(file_name)[0]
state_dict = load_state_dict(os.path.join(folder, file_name))
# Search for embeddings
for embeddings in self.search_for_embeddings(state_dict):
if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
self.textual_inversion_dict[keyword] = (tokens, embeddings)
break
def load_model(self, file_path, components=None, lora_alphas=[]):
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
if self.is_stable_video_diffusion(state_dict):
self.load_stable_video_diffusion(state_dict, file_path=file_path)
elif self.is_animatediff(state_dict):
self.load_animatediff(state_dict, file_path=file_path)
elif self.is_animatediff_xl(state_dict):
self.load_animatediff_xl(state_dict, file_path=file_path)
elif self.is_controlnet(state_dict):
self.load_controlnet(state_dict, file_path=file_path)
elif self.is_stabe_diffusion_xl(state_dict):
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
elif self.is_stable_diffusion(state_dict):
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
elif self.is_sd_lora(state_dict):
self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
elif self.is_beautiful_prompt(state_dict):
self.load_beautiful_prompt(state_dict, file_path=file_path)
elif self.is_RIFE(state_dict):
self.load_RIFE(state_dict, file_path=file_path)
elif self.is_translator(state_dict):
self.load_translator(state_dict, file_path=file_path)
elif self.is_ipadapter(state_dict):
self.load_ipadapter(state_dict, file_path=file_path)
elif self.is_ipadapter_image_encoder(state_dict):
self.load_ipadapter_image_encoder(state_dict, file_path=file_path)
elif self.is_ipadapter_xl(state_dict):
self.load_ipadapter_xl(state_dict, file_path=file_path)
elif self.is_ipadapter_xl_image_encoder(state_dict):
self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit(state_dict):
self.load_hunyuan_dit(state_dict, file_path=file_path)
elif self.is_diffusers_vae(state_dict):
self.load_diffusers_vae(state_dict, file_path=file_path)
elif self.is_ExVideo_StableVideoDiffusion(state_dict):
self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
def load_models(self, file_path_list, lora_alphas=[]):
for file_path in file_path_list:
self.load_model(file_path, lora_alphas=lora_alphas)
def to(self, device):
for component in self.model:
if isinstance(self.model[component], list):
for model in self.model[component]:
model.to(device)
else:
self.model[component].to(device)
torch.cuda.empty_cache()
def get_model_with_model_path(self, model_path):
for component in self.model_path:
if isinstance(self.model_path[component], str):
if os.path.samefile(self.model_path[component], model_path):
return self.model[component]
elif isinstance(self.model_path[component], list):
for i, model_path_ in enumerate(self.model_path[component]):
if os.path.samefile(model_path_, model_path):
return self.model[component][i]
raise ValueError(f"Please load model {model_path} before you use it.")
def __getattr__(self, __name):
if __name in self.model:
return self.model[__name]
else:
return super.__getattribute__(__name)
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu")
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-6:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
return name
return None
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
matched_keys = set()
with torch.no_grad():
for name in source_state_dict:
rename = search_parameter(source_state_dict[name], target_state_dict)
if rename is not None:
print(f'"{name}": "{rename}",')
matched_keys.add(rename)
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
length = source_state_dict[name].shape[0] // 3
rename = []
for i in range(3):
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
if None not in rename:
print(f'"{name}": {rename},')
for rename_ in rename:
matched_keys.add(rename_)
for name in target_state_dict:
if name not in matched_keys:
print("Cannot find", name, target_state_dict[name].shape)

View File

@@ -1,89 +0,0 @@
import torch
from einops import rearrange
def low_version_attention(query, key, value, attn_bias=None):
scale = 1 / query.shape[-1] ** 0.5
query = query * scale
attn = torch.matmul(query, key.transpose(-2, -1))
if attn_bias is not None:
attn = attn + attn_bias
attn = attn.softmax(-1)
return attn @ value
class Attention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
batch_size = q.shape[0]
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
hidden_states = hidden_states + scale * ip_hidden_states
return hidden_states
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if qkv_preprocessor is not None:
q, k, v = qkv_preprocessor(q, k, v)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
if ipadapter_kwargs is not None:
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
if attn_mask is not None:
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
else:
import xformers.ops as xops
hidden_states = xops.memory_efficient_attention(q, k, v)
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)

View File

@@ -0,0 +1,96 @@
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch
from ..core.device.npu_compatible_device import get_device_type
class DINOv3ImageEncoder(DINOv3ViTModel):
def __init__(self):
config = DINOv3ViTConfig(
architectures = [
"DINOv3ViTModel"
],
attention_dropout = 0.0,
drop_path_rate = 0.0,
dtype = "float32",
hidden_act = "silu",
hidden_size = 4096,
image_size = 224,
initializer_range = 0.02,
intermediate_size = 8192,
key_bias = False,
layer_norm_eps = 1e-05,
layerscale_value = 1.0,
mlp_bias = True,
model_type = "dinov3_vit",
num_attention_heads = 32,
num_channels = 3,
num_hidden_layers = 40,
num_register_tokens = 4,
patch_size = 16,
pos_embed_jitter = None,
pos_embed_rescale = 2.0,
pos_embed_shift = None,
proj_bias = True,
query_bias = False,
rope_theta = 100.0,
transformers_version = "4.56.1",
use_gated_mlp = True,
value_bias = False
)
super().__init__(config)
self.processor = DINOv3ViTImageProcessorFast(
crop_size = None,
data_format = "channels_first",
default_to_square = True,
device = None,
disable_grouping = None,
do_center_crop = None,
do_convert_rgb = None,
do_normalize = True,
do_rescale = True,
do_resize = True,
image_mean = [
0.485,
0.456,
0.406
],
image_processor_type = "DINOv3ViTImageProcessorFast",
image_std = [
0.229,
0.224,
0.225
],
input_data_format = None,
resample = 2,
rescale_factor = 0.00392156862745098,
return_tensors = None,
size = {
"height": 224,
"width": 224
}
)
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
inputs = self.processor(images=image, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
bool_masked_pos = None
head_mask = None
pixel_values = pixel_values.to(torch_dtype)
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
position_embeddings = self.rope_embeddings(pixel_values)
for i, layer_module in enumerate(self.layer):
layer_head_mask = head_mask[i] if head_mask is not None else None
hidden_states = layer_module(
hidden_states,
attention_mask=layer_head_mask,
position_embeddings=position_embeddings,
)
sequence_output = self.norm(hidden_states)
pooled_output = sequence_output[:, 0, :]
return pooled_output

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,58 @@
from transformers import Mistral3ForConditionalGeneration, Mistral3Config
class Flux2TextEncoder(Mistral3ForConditionalGeneration):
def __init__(self):
config = Mistral3Config(**{
"architectures": [
"Mistral3ForConditionalGeneration"
],
"dtype": "bfloat16",
"image_token_index": 10,
"model_type": "mistral3",
"multimodal_projector_bias": False,
"projector_hidden_act": "gelu",
"spatial_merge_size": 2,
"text_config": {
"attention_dropout": 0.0,
"dtype": "bfloat16",
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 5120,
"initializer_range": 0.02,
"intermediate_size": 32768,
"max_position_embeddings": 131072,
"model_type": "mistral",
"num_attention_heads": 32,
"num_hidden_layers": 40,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-05,
"rope_theta": 1000000000.0,
"sliding_window": None,
"use_cache": True,
"vocab_size": 131072
},
"transformers_version": "4.57.1",
"vision_config": {
"attention_dropout": 0.0,
"dtype": "bfloat16",
"head_dim": 64,
"hidden_act": "silu",
"hidden_size": 1024,
"image_size": 1540,
"initializer_range": 0.02,
"intermediate_size": 4096,
"model_type": "pixtral",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"rope_theta": 10000.0
},
"vision_feature_layer": -1
})
super().__init__(config)
def forward(self, input_ids = None, pixel_values = None, attention_mask = None, position_ids = None, past_key_values = None, inputs_embeds = None, labels = None, use_cache = None, output_attentions = None, output_hidden_states = None, return_dict = None, cache_position = None, logits_to_keep = 0, image_sizes = None, **kwargs):
return super().forward(input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, image_sizes, **kwargs)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,384 @@
import torch
from einops import rearrange, repeat
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
# from .utils import hash_state_dict_keys, init_weights_on_device
from contextlib import contextmanager
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
@contextmanager
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
old_register_parameter = torch.nn.Module.register_parameter
if include_buffers:
old_register_buffer = torch.nn.Module.register_buffer
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(device)
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)
return wrapper
if include_buffers:
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ["empty", "zeros", "ones", "full"]
}
else:
tensor_constructors_to_patch = {}
try:
torch.nn.Module.register_parameter = register_empty_parameter
if include_buffers:
torch.nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch.keys():
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
yield
finally:
torch.nn.Module.register_parameter = old_register_parameter
if include_buffers:
torch.nn.Module.register_buffer = old_register_buffer
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)
class FluxControlNet(torch.nn.Module):
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(64, 3072)
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])
self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])
self.mode_dict = mode_dict
self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)
def prepare_image_ids(self, latents):
batch_size, _, height, width = latents.shape
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
return latent_image_ids
def patchify(self, hidden_states):
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
return hidden_states
def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
if len(res_stack) == 0:
return [torch.zeros_like(hidden_states)] * num_blocks
interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
return aligned_res_stack
def forward(
self,
hidden_states,
controlnet_conditioning,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
processor_id=None,
tiled=False, tile_size=128, tile_stride=64,
**kwargs
):
if image_ids is None:
image_ids = self.prepare_image_ids(hidden_states)
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
if self.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
if self.controlnet_mode_embedder is not None: # Different from FluxDiT
processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states)
controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT
controlnet_res_stack = []
for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
controlnet_res_stack.append(controlnet_block(hidden_states))
controlnet_single_res_stack = []
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))
controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])
return controlnet_res_stack, controlnet_single_res_stack
# @staticmethod
# def state_dict_converter():
# return FluxControlNetStateDictConverter()
def quantize(self):
def cast_to(weight, dtype=None, device=None, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
def cast_weight(s, input=None, dtype=None, device=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if device is None:
device = input.device
weight = cast_to(s.weight, dtype, device)
return weight
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
bias = None
weight = cast_to(s.weight, dtype, device)
bias = cast_to(s.bias, bias_dtype, device)
return weight, bias
class quantized_layer:
class QLinear(torch.nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self,input,**kwargs):
weight,bias= cast_bias_weight(self,input)
return torch.nn.functional.linear(input,weight,bias)
class QRMSNorm(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self,hidden_states,**kwargs):
weight= cast_weight(self.module,hidden_states)
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
hidden_states = hidden_states.to(input_dtype) * weight
return hidden_states
class QEmbedding(torch.nn.Embedding):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self,input,**kwargs):
weight= cast_weight(self,input)
return torch.nn.functional.embedding(
input, weight, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)
def replace_layer(model):
for name, module in model.named_children():
if isinstance(module,quantized_layer.QRMSNorm):
continue
if isinstance(module, torch.nn.Linear):
with init_weights_on_device():
new_layer = quantized_layer.QLinear(module.in_features,module.out_features)
new_layer.weight = module.weight
if module.bias is not None:
new_layer.bias = module.bias
setattr(model, name, new_layer)
elif isinstance(module, RMSNorm):
if hasattr(module,"quantized"):
continue
module.quantized= True
new_layer = quantized_layer.QRMSNorm(module)
setattr(model, name, new_layer)
elif isinstance(module,torch.nn.Embedding):
rows, cols = module.weight.shape
new_layer = quantized_layer.QEmbedding(
num_embeddings=rows,
embedding_dim=cols,
_weight=module.weight,
# _freeze=module.freeze,
padding_idx=module.padding_idx,
max_norm=module.max_norm,
norm_type=module.norm_type,
scale_grad_by_freq=module.scale_grad_by_freq,
sparse=module.sparse)
setattr(model, name, new_layer)
else:
replace_layer(module)
replace_layer(self)
class FluxControlNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
hash_value = hash_state_dict_keys(state_dict)
global_rename_dict = {
"context_embedder": "context_embedder",
"x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "final_norm_out.linear",
"proj_out": "final_proj_out",
}
rename_dict = {
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
}
rename_dict_single = {
"attn.to_q": "a_to_q",
"attn.to_k": "a_to_k",
"attn.to_v": "a_to_v",
"attn.norm_q": "norm_q_a",
"attn.norm_k": "norm_k_a",
"norm.linear": "norm.linear",
"proj_mlp": "proj_in_besides_attn",
"proj_out": "proj_out",
}
state_dict_ = {}
for name, param in state_dict.items():
if name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in global_rename_dict:
state_dict_[global_rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
elif prefix.startswith("single_transformer_blocks."):
names = prefix.split(".")
names[0] = "single_blocks"
middle = ".".join(names[2:])
if middle in rename_dict_single:
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
state_dict_[name_] = param
else:
state_dict_[name] = param
else:
state_dict_[name] = param
for name in list(state_dict_.keys()):
if ".proj_in_besides_attn." in name:
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
param = torch.concat([
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
state_dict_[name],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
state_dict_.pop(name)
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
if hash_value == "78d18b9101345ff695f312e7e62538c0":
extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
elif hash_value == "b001c89139b5f053c715fe772362dd2a":
extra_kwargs = {"num_single_blocks": 0}
elif hash_value == "52357cb26250681367488a8954c271e8":
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52":
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0}
else:
extra_kwargs = {}
return state_dict_, extra_kwargs
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -0,0 +1,395 @@
import torch
from .general_modules import TimestepEmbeddings, AdaLayerNorm, RMSNorm
from einops import rearrange
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
batch_size, num_tokens = hidden_states.shape[0:2]
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1)
hidden_states = hidden_states + scale * ip_hidden_states
return hidden_states
class RoPEEmbedding(torch.nn.Module):
def __init__(self, dim, theta, axes_dim):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()
def forward(self, ids):
n_axes = ids.shape[-1]
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
return emb.unsqueeze(1)
class FluxJointAttention(torch.nn.Module):
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.only_out_a = only_out_a
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
if not only_out_a:
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
def apply_rope(self, xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
batch_size = hidden_states_a.shape[0]
# Part A
qkv_a = self.a_to_qkv(hidden_states_a)
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
# Part B
qkv_b = self.b_to_qkv(hidden_states_b)
qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)
q = torch.concat([q_b, q_a], dim=2)
k = torch.concat([k_b, k_a], dim=2)
v = torch.concat([v_b, v_a], dim=2)
q, k = self.apply_rope(q, k, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
if ipadapter_kwargs_list is not None:
hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
hidden_states_a = self.a_to_out(hidden_states_a)
if self.only_out_a:
return hidden_states_a
else:
hidden_states_b = self.b_to_out(hidden_states_b)
return hidden_states_a, hidden_states_b
class FluxJointTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads):
super().__init__()
self.norm1_a = AdaLayerNorm(dim)
self.norm1_b = AdaLayerNorm(dim)
self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_a = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_b = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
# Attention
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
# Part A
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
# Part B
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
return hidden_states_a, hidden_states_b
class FluxSingleAttention(torch.nn.Module):
def __init__(self, dim_a, dim_b, num_heads, head_dim):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
def apply_rope(self, xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def forward(self, hidden_states, image_rotary_emb):
batch_size = hidden_states.shape[0]
qkv_a = self.a_to_qkv(hidden_states)
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q_a, k_a, v = qkv_a.chunk(3, dim=1)
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
q, k = self.apply_rope(q_a, k_a, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
return hidden_states
class AdaLayerNormSingle(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = torch.nn.SiLU()
self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa
class FluxSingleTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads):
super().__init__()
self.num_heads = num_attention_heads
self.head_dim = dim // num_attention_heads
self.dim = dim
self.norm = AdaLayerNormSingle(dim)
self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4))
self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)
self.proj_out = torch.nn.Linear(dim * 5, dim)
def apply_rope(self, xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
batch_size = hidden_states.shape[0]
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q, k, v = qkv.chunk(3, dim=1)
q, k = self.norm_q_a(q), self.norm_k_a(k)
q, k = self.apply_rope(q, k, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
if ipadapter_kwargs_list is not None:
hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
return hidden_states
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
residual = hidden_states_a
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
hidden_states_a = residual + hidden_states_a
return hidden_states_a, hidden_states_b
class AdaLayerNormContinuous(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = torch.nn.SiLU()
self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
def forward(self, x, conditioning):
emb = self.linear(self.silu(conditioning))
shift, scale = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
return x
class FluxDiT(torch.nn.Module):
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(input_dim, 3072)
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
self.final_norm_out = AdaLayerNormContinuous(3072)
self.final_proj_out = torch.nn.Linear(3072, 64)
self.input_dim = input_dim
def patchify(self, hidden_states):
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
return hidden_states
def unpatchify(self, hidden_states, height, width):
hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
return hidden_states
def prepare_image_ids(self, latents):
batch_size, _, height, width = latents.shape
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
return latent_image_ids
def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
N = len(entity_masks)
batch_size = entity_masks[0].shape[0]
total_seq_len = N * prompt_seq_len + image_seq_len
patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
image_start = N * prompt_seq_len
image_end = N * prompt_seq_len + image_seq_len
# prompt-image mask
for i in range(N):
prompt_start = i * prompt_seq_len
prompt_end = (i + 1) * prompt_seq_len
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
# prompt update with image
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
# image update with prompt
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
# prompt-prompt mask
for i in range(N):
for j in range(N):
if i != j:
prompt_start_i = i * prompt_seq_len
prompt_end_i = (i + 1) * prompt_seq_len
prompt_start_j = j * prompt_seq_len
prompt_end_j = (j + 1) * prompt_seq_len
attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False
attention_mask = attention_mask.float()
attention_mask[attention_mask == 0] = float('-inf')
attention_mask[attention_mask == 1] = 0
return attention_mask
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
max_masks = 0
attention_mask = None
prompt_embs = [prompt_emb]
if entity_masks is not None:
# entity_masks
batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
# global mask
global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
entity_masks = entity_masks + [global_mask] # append global to last
# attention mask
attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
attention_mask = attention_mask.unsqueeze(1)
# embds: n_masks * b * seq * d
local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
prompt_embs = local_embs + prompt_embs # append global to last
prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
prompt_emb = torch.cat(prompt_embs, dim=1)
# positional embedding
text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
return prompt_emb, image_rotary_emb, attention_mask
def forward(
self,
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
use_gradient_checkpointing=False,
**kwargs
):
# (Deprecated) The real forward is in `pipelines.flux_image`.
return None

View File

@@ -0,0 +1,129 @@
import math
import torch
import torch.nn as nn
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
class InfiniteYouImageProjector(nn.Module):
def __init__(
self,
dim=1280,
depth=4,
dim_head=64,
heads=20,
num_queries=8,
embedding_dim=512,
output_dim=4096,
ff_mult=4,
):
super().__init__()
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList([
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]))
def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1)
latents = latents.to(dtype=x.dtype, device=x.device)
x = self.proj_in(x)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
@staticmethod
def state_dict_converter():
return FluxInfiniteYouImageProjectorStateDictConverter()
class FluxInfiniteYouImageProjectorStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict['image_proj']

View File

@@ -0,0 +1,110 @@
from .general_modules import RMSNorm
from transformers import SiglipVisionModel, SiglipVisionConfig
import torch
class SiglipVisionModelSO400M(SiglipVisionModel):
def __init__(self):
config = SiglipVisionConfig(
hidden_size=1152,
image_size=384,
intermediate_size=4304,
model_type="siglip_vision_model",
num_attention_heads=16,
num_hidden_layers=27,
patch_size=14,
architectures=["SiglipModel"],
initializer_factor=1.0,
torch_dtype="float32",
transformers_version="4.37.0.dev0"
)
super().__init__(config)
class MLPProjModel(torch.nn.Module):
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.num_tokens = num_tokens
self.proj = torch.nn.Sequential(
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
torch.nn.GELU(),
torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
def forward(self, id_embeds):
x = self.proj(id_embeds)
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
x = self.norm(x)
return x
class IpAdapterModule(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, input_dim):
super().__init__()
self.num_heads = num_attention_heads
self.head_dim = attention_head_dim
output_dim = num_attention_heads * attention_head_dim
self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False)
def forward(self, hidden_states):
batch_size = hidden_states.shape[0]
# ip_k
ip_k = self.to_k_ip(hidden_states)
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
ip_k = self.norm_added_k(ip_k)
# ip_v
ip_v = self.to_v_ip(hidden_states)
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
return ip_k, ip_v
class FluxIpAdapter(torch.nn.Module):
def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):
super().__init__()
self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)])
self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens)
self.set_adapter()
def set_adapter(self):
self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))}
def forward(self, hidden_states, scale=1.0):
hidden_states = self.image_proj(hidden_states)
hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
ip_kv_dict = {}
for block_id in self.call_block_id:
ipadapter_id = self.call_block_id[block_id]
ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
ip_kv_dict[block_id] = {
"ip_k": ip_k,
"ip_v": ip_v,
"scale": scale
}
return ip_kv_dict
@staticmethod
def state_dict_converter():
return FluxIpAdapterStateDictConverter()
class FluxIpAdapterStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {}
for name in state_dict["ip_adapter"]:
name_ = 'ipadapter_modules.' + name
state_dict_[name_] = state_dict["ip_adapter"][name]
for name in state_dict["image_proj"]:
name_ = "image_proj." + name
state_dict_[name_] = state_dict["image_proj"][name]
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -0,0 +1,521 @@
import torch
from einops import rearrange
def low_version_attention(query, key, value, attn_bias=None):
scale = 1 / query.shape[-1] ** 0.5
query = query * scale
attn = torch.matmul(query, key.transpose(-2, -1))
if attn_bias is not None:
attn = attn + attn_bias
attn = attn.softmax(-1)
return attn @ value
class Attention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
batch_size = q.shape[0]
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
hidden_states = hidden_states + scale * ip_hidden_states
return hidden_states
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if qkv_preprocessor is not None:
q, k, v = qkv_preprocessor(q, k, v)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
if ipadapter_kwargs is not None:
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
if attn_mask is not None:
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
else:
import xformers.ops as xops
hidden_states = xops.memory_efficient_attention(q, k, v)
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
class CLIPEncoderLayer(torch.nn.Module):
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
super().__init__()
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
self.use_quick_gelu = use_quick_gelu
def quickGELU(self, x):
return x * torch.sigmoid(1.702 * x)
def forward(self, hidden_states, attn_mask=None):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.fc1(hidden_states)
if self.use_quick_gelu:
hidden_states = self.quickGELU(hidden_states)
else:
hidden_states = torch.nn.functional.gelu(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class SDTextEncoder(torch.nn.Module):
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
super().__init__()
# token_embedding
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
# position_embeds (This is a fixed tensor)
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
# encoders
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
# attn_mask
self.attn_mask = self.attention_mask(max_position_embeddings)
# final_layer_norm
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
def attention_mask(self, length):
mask = torch.empty(length, length)
mask.fill_(float("-inf"))
mask.triu_(1)
return mask
def forward(self, input_ids, clip_skip=1):
embeds = self.token_embedding(input_ids) + self.position_embeds
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask)
if encoder_id + clip_skip == len(self.encoders):
break
embeds = self.final_layer_norm(embeds)
return embeds
@staticmethod
def state_dict_converter():
return SDTextEncoderStateDictConverter()
class SDTextEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"text_model.embeddings.position_embedding.weight": "position_embeds",
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
}
attn_rename_dict = {
"self_attn.q_proj": "attn.to_q",
"self_attn.k_proj": "attn.to_k",
"self_attn.v_proj": "attn.to_v",
"self_attn.out_proj": "attn.to_out",
"layer_norm1": "layer_norm1",
"layer_norm2": "layer_norm2",
"mlp.fc1": "fc1",
"mlp.fc2": "fc2",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name == "text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
elif name.startswith("text_model.encoder.layers."):
param = state_dict[name]
names = name.split(".")
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
state_dict_[name_] = param
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
"cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
"cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
"cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
"cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
return state_dict_
class LoRALayerBlock(torch.nn.Module):
def __init__(self, L, dim_in, dim_out):
super().__init__()
self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))
self.layer_norm = torch.nn.LayerNorm(dim_out)
def forward(self, lora_A, lora_B):
x = self.x @ lora_A.T @ lora_B.T
x = self.layer_norm(x)
return x
class LoRAEmbedder(torch.nn.Module):
def __init__(self, lora_patterns=None, L=1, out_dim=2048):
super().__init__()
if lora_patterns is None:
lora_patterns = self.default_lora_patterns()
model_dict = {}
for lora_pattern in lora_patterns:
name, dim = lora_pattern["name"], lora_pattern["dim"]
model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1])
self.model_dict = torch.nn.ModuleDict(model_dict)
proj_dict = {}
for lora_pattern in lora_patterns:
layer_type, dim = lora_pattern["type"], lora_pattern["dim"]
if layer_type not in proj_dict:
proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim[1], out_dim)
self.proj_dict = torch.nn.ModuleDict(proj_dict)
self.lora_patterns = lora_patterns
def default_lora_patterns(self):
lora_patterns = []
lora_dict = {
"attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432),
"attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432),
}
for i in range(19):
for suffix in lora_dict:
lora_patterns.append({
"name": f"blocks.{i}.{suffix}",
"dim": lora_dict[suffix],
"type": suffix,
})
lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)}
for i in range(38):
for suffix in lora_dict:
lora_patterns.append({
"name": f"single_blocks.{i}.{suffix}",
"dim": lora_dict[suffix],
"type": suffix,
})
return lora_patterns
def forward(self, lora):
lora_emb = []
for lora_pattern in self.lora_patterns:
name, layer_type = lora_pattern["name"], lora_pattern["type"]
lora_A = lora[name + ".lora_A.weight"]
lora_B = lora[name + ".lora_B.weight"]
lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
lora_emb.append(lora_out)
lora_emb = torch.concat(lora_emb, dim=1)
return lora_emb
class FluxLoRAEncoder(torch.nn.Module):
def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1):
super().__init__()
self.num_embeds_per_lora = num_embeds_per_lora
# embedder
self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim)
# encoders
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)])
# special embedding
self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim))
self.num_special_embeds = num_special_embeds
# final layer
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
self.final_linear = torch.nn.Linear(embed_dim, embed_dim)
def forward(self, lora):
lora_embeds = self.embedder(lora)
special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device)
embeds = torch.concat([special_embeds, lora_embeds], dim=1)
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds)
embeds = embeds[:, :self.num_special_embeds]
embeds = self.final_layer_norm(embeds)
embeds = self.final_linear(embeds)
return embeds
@staticmethod
def state_dict_converter():
return FluxLoRAEncoderStateDictConverter()
class FluxLoRAEncoderStateDictConverter:
def from_civitai(self, state_dict):
return state_dict

View File

@@ -0,0 +1,306 @@
import torch, math
from ..core.loader import load_state_dict
from typing import Union
class GeneralLoRALoader:
def __init__(self, device="cpu", torch_dtype=torch.float32):
self.device = device
self.torch_dtype = torch_dtype
def get_name_dict(self, lora_state_dict):
lora_name_dict = {}
for key in lora_state_dict:
if ".lora_B." not in key:
continue
keys = key.split(".")
if len(keys) > keys.index("lora_B") + 2:
keys.pop(keys.index("lora_B") + 1)
keys.pop(keys.index("lora_B"))
if keys[0] == "diffusion_model":
keys.pop(0)
keys.pop(-1)
target_name = ".".join(keys)
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
return lora_name_dict
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
updated_num = 0
lora_name_dict = self.get_name_dict(state_dict_lora)
for name, module in model.named_modules():
if name in lora_name_dict:
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
weight_lora = alpha * torch.mm(weight_up, weight_down)
state_dict = module.state_dict()
state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
module.load_state_dict(state_dict)
updated_num += 1
print(f"{updated_num} tensors are updated by LoRA.")
class FluxLoRALoader(GeneralLoRALoader):
def __init__(self, device="cpu", torch_dtype=torch.float32):
super().__init__(device=device, torch_dtype=torch_dtype)
self.diffusers_rename_dict = {
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.default.weight",
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.default.weight",
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.default.weight",
"transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.default.weight",
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight",
"transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight",
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight",
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight",
"transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight",
"transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight",
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight",
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight",
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight",
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight",
"transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight",
"transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight",
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight",
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight",
}
self.civitai_rename_dict = {
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
}
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
super().load(model, state_dict_lora, alpha)
def convert_state_dict(self,state_dict):
def guess_block_id(name,model_resource):
if model_resource == 'civitai':
names = name.split("_")
for i in names:
if i.isdigit():
return i, name.replace(f"_{i}_", "_blockid_")
if model_resource == 'diffusers':
names = name.split(".")
for i in names:
if i.isdigit():
return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.")
return None, None
def guess_resource(state_dict):
for k in state_dict:
if "lora_unet_" in k:
return 'civitai'
elif k.startswith("transformer."):
return 'diffusers'
else:
None
model_resource = guess_resource(state_dict)
if model_resource is None:
return state_dict
rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict
def guess_alpha(state_dict):
for name, param in state_dict.items():
if ".alpha" in name:
for suffix in [".lora_down.weight", ".lora_A.weight"]:
name_ = name.replace(".alpha", suffix)
if name_ in state_dict:
lora_alpha = param.item() / state_dict[name_].shape[0]
lora_alpha = math.sqrt(lora_alpha)
return lora_alpha
return 1
alpha = guess_alpha(state_dict)
state_dict_ = {}
for name, param in state_dict.items():
block_id, source_name = guess_block_id(name,model_resource)
if alpha != 1:
param *= alpha
if source_name in rename_dict:
target_name = rename_dict[source_name]
target_name = target_name.replace(".blockid.", f".{block_id}.")
state_dict_[target_name] = param
else:
state_dict_[name] = param
if model_resource == 'diffusers':
for name in list(state_dict_.keys()):
if "single_blocks." in name and ".a_to_q." in name:
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
if mlp is None:
dim = 4
if 'lora_A' in name:
dim = 1
mlp = torch.zeros(dim * state_dict_[name].shape[0],
*state_dict_[name].shape[1:],
dtype=state_dict_[name].dtype)
else:
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
if 'lora_A' in name:
param = torch.concat([
state_dict_.pop(name),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
mlp,
], dim=0)
elif 'lora_B' in name:
d, r = state_dict_[name].shape
param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
param[:d, :r] = state_dict_.pop(name)
param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
param[3*d:, 3*r:] = mlp
else:
param = torch.concat([
state_dict_.pop(name),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
mlp,
], dim=0)
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
state_dict_[name_] = param
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
concat_dim = 0
if 'lora_A' in name:
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
elif 'lora_B' in name:
origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
d, r = origin.shape
# print(d, r)
param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
else:
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
return state_dict_
class LoraMerger(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
self.bias = torch.nn.Parameter(torch.randn((dim,)))
self.activation = torch.nn.Sigmoid()
self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
def forward(self, base_output, lora_outputs):
norm_base_output = self.norm_base(base_output)
norm_lora_outputs = self.norm_lora(lora_outputs)
gate = self.activation(
norm_base_output * self.weight_base \
+ norm_lora_outputs * self.weight_lora \
+ norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
)
output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
return output
class FluxLoraPatcher(torch.nn.Module):
def __init__(self, lora_patterns=None):
super().__init__()
if lora_patterns is None:
lora_patterns = self.default_lora_patterns()
model_dict = {}
for lora_pattern in lora_patterns:
name, dim = lora_pattern["name"], lora_pattern["dim"]
model_dict[name.replace(".", "___")] = LoraMerger(dim)
self.model_dict = torch.nn.ModuleDict(model_dict)
def default_lora_patterns(self):
lora_patterns = []
lora_dict = {
"attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
"attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
}
for i in range(19):
for suffix in lora_dict:
lora_patterns.append({
"name": f"blocks.{i}.{suffix}",
"dim": lora_dict[suffix]
})
lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
for i in range(38):
for suffix in lora_dict:
lora_patterns.append({
"name": f"single_blocks.{i}.{suffix}",
"dim": lora_dict[suffix]
})
return lora_patterns
def forward(self, base_output, lora_outputs, name):
return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)

View File

@@ -0,0 +1,112 @@
import torch
class Attention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
class CLIPEncoderLayer(torch.nn.Module):
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
super().__init__()
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
self.use_quick_gelu = use_quick_gelu
def quickGELU(self, x):
return x * torch.sigmoid(1.702 * x)
def forward(self, hidden_states, attn_mask=None):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.fc1(hidden_states)
if self.use_quick_gelu:
hidden_states = self.quickGELU(hidden_states)
else:
hidden_states = torch.nn.functional.gelu(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class FluxTextEncoderClip(torch.nn.Module):
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
super().__init__()
# token_embedding
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
# position_embeds (This is a fixed tensor)
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
# encoders
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
# attn_mask
self.attn_mask = self.attention_mask(max_position_embeddings)
# final_layer_norm
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
def attention_mask(self, length):
mask = torch.empty(length, length)
mask.fill_(float("-inf"))
mask.triu_(1)
return mask
def forward(self, input_ids, clip_skip=2, extra_mask=None):
embeds = self.token_embedding(input_ids)
embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
if extra_mask is not None:
attn_mask[:, extra_mask[0]==0] = float("-inf")
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask)
if encoder_id + clip_skip == len(self.encoders):
hidden_states = embeds
embeds = self.final_layer_norm(embeds)
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
return pooled_embeds, hidden_states

View File

@@ -0,0 +1,43 @@
import torch
from transformers import T5EncoderModel, T5Config
class FluxTextEncoderT5(T5EncoderModel):
def __init__(self):
config = T5Config(**{
"architectures": [
"T5EncoderModel"
],
"classifier_dropout": 0.0,
"d_ff": 10240,
"d_kv": 64,
"d_model": 4096,
"decoder_start_token_id": 0,
"dense_act_fn": "gelu_new",
"dropout_rate": 0.1,
"dtype": "bfloat16",
"eos_token_id": 1,
"feed_forward_proj": "gated-gelu",
"initializer_factor": 1.0,
"is_encoder_decoder": True,
"is_gated_act": True,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 24,
"num_heads": 64,
"num_layers": 24,
"output_past": True,
"pad_token_id": 0,
"relative_attention_max_distance": 128,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": False,
"transformers_version": "4.57.1",
"use_cache": True,
"vocab_size": 32128
})
super().__init__(config)
def forward(self, input_ids):
outputs = super().forward(input_ids=input_ids)
prompt_emb = outputs.last_hidden_state
return prompt_emb

View File

@@ -0,0 +1,451 @@
import torch
from einops import rearrange, repeat
class TileWorker:
def __init__(self):
pass
def mask(self, height, width, border_width):
# Create a mask with shape (height, width).
# The centre area is filled with 1, and the border line is filled with values in range (0, 1].
x = torch.arange(height).repeat(width, 1).T
y = torch.arange(width).repeat(height, 1)
mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
mask = (mask / border_width).clip(0, 1)
return mask
def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
# Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
batch_size, channel, _, _ = model_input.shape
model_input = model_input.to(device=tile_device, dtype=tile_dtype)
unfold_operator = torch.nn.Unfold(
kernel_size=(tile_size, tile_size),
stride=(tile_stride, tile_stride)
)
model_input = unfold_operator(model_input)
model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
return model_input
def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
# Call y=forward_fn(x) for each tile
tile_num = model_input.shape[-1]
model_output_stack = []
for tile_id in range(0, tile_num, tile_batch_size):
# process input
tile_id_ = min(tile_id + tile_batch_size, tile_num)
x = model_input[:, :, :, :, tile_id: tile_id_]
x = x.to(device=inference_device, dtype=inference_dtype)
x = rearrange(x, "b c h w n -> (n b) c h w")
# process output
y = forward_fn(x)
y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
y = y.to(device=tile_device, dtype=tile_dtype)
model_output_stack.append(y)
model_output = torch.concat(model_output_stack, dim=-1)
return model_output
def io_scale(self, model_output, tile_size):
# Determine the size modification happened in forward_fn
# We only consider the same scale on height and width.
io_scale = model_output.shape[2] / tile_size
return io_scale
def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
# The reversed function of tile
mask = self.mask(tile_size, tile_size, border_width)
mask = mask.to(device=tile_device, dtype=tile_dtype)
mask = rearrange(mask, "h w -> 1 1 h w 1")
model_output = model_output * mask
fold_operator = torch.nn.Fold(
output_size=(height, width),
kernel_size=(tile_size, tile_size),
stride=(tile_stride, tile_stride)
)
mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
model_output = fold_operator(model_output) / fold_operator(mask)
return model_output
def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
# Prepare
inference_device, inference_dtype = model_input.device, model_input.dtype
height, width = model_input.shape[2], model_input.shape[3]
border_width = int(tile_stride*0.5) if border_width is None else border_width
# tile
model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
# inference
model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
# resize
io_scale = self.io_scale(model_output, tile_size)
height, width = int(height*io_scale), int(width*io_scale)
tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
border_width = int(border_width*io_scale)
# untile
model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
# Done!
model_output = model_output.to(device=inference_device, dtype=inference_dtype)
return model_output
class ConvAttention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q)
self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out)
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
conv_input = rearrange(hidden_states, "B L C -> B C L 1")
q = self.to_q(conv_input)
q = rearrange(q[:, :, :, 0], "B C L -> B L C")
conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1")
k = self.to_k(conv_input)
v = self.to_v(conv_input)
k = rearrange(k[:, :, :, 0], "B C L -> B L C")
v = rearrange(v[:, :, :, 0], "B C L -> B L C")
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
conv_input = rearrange(hidden_states, "B L C -> B C L 1")
hidden_states = self.to_out(conv_input)
hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C")
return hidden_states
class Attention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
class VAEAttentionBlock(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, use_conv_attention=True):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
if use_conv_attention:
self.transformer_blocks = torch.nn.ModuleList([
ConvAttention(
inner_dim,
num_attention_heads,
attention_head_dim,
bias_q=True,
bias_kv=True,
bias_out=True
)
for d in range(num_layers)
])
else:
self.transformer_blocks = torch.nn.ModuleList([
Attention(
inner_dim,
num_attention_heads,
attention_head_dim,
bias_q=True,
bias_kv=True,
bias_out=True
)
for d in range(num_layers)
])
def forward(self, hidden_states, time_emb, text_emb, res_stack):
batch, _, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
for block in self.transformer_blocks:
hidden_states = block(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = hidden_states + residual
return hidden_states, time_emb, text_emb, res_stack
class ResnetBlock(torch.nn.Module):
def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5):
super().__init__()
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.nonlinearity = torch.nn.SiLU()
self.conv_shortcut = None
if in_channels != out_channels:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
x = hidden_states
x = self.norm1(x)
x = self.nonlinearity(x)
x = self.conv1(x)
if time_emb is not None:
emb = self.nonlinearity(time_emb)
emb = self.time_emb_proj(emb)[:, :, None, None]
x = x + emb
x = self.norm2(x)
x = self.nonlinearity(x)
x = self.conv2(x)
if self.conv_shortcut is not None:
hidden_states = self.conv_shortcut(hidden_states)
hidden_states = hidden_states + x
return hidden_states, time_emb, text_emb, res_stack
class UpSampler(torch.nn.Module):
def __init__(self, channels):
super().__init__()
self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1)
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
hidden_states = self.conv(hidden_states)
return hidden_states, time_emb, text_emb, res_stack
class DownSampler(torch.nn.Module):
def __init__(self, channels, padding=1, extra_padding=False):
super().__init__()
self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding)
self.extra_padding = extra_padding
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
if self.extra_padding:
hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
hidden_states = self.conv(hidden_states)
return hidden_states, time_emb, text_emb, res_stack
class FluxVAEDecoder(torch.nn.Module):
def __init__(self, use_conv_attention=True):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
self.blocks = torch.nn.ModuleList([
# UNetMidBlock2D
ResnetBlock(512, 512, eps=1e-6),
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
ResnetBlock(512, 512, eps=1e-6),
# UpDecoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
UpSampler(512),
# UpDecoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
UpSampler(512),
# UpDecoderBlock2D
ResnetBlock(512, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
UpSampler(256),
# UpDecoderBlock2D
ResnetBlock(256, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
])
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride,
tile_device=sample.device,
tile_dtype=sample.dtype
)
return hidden_states
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
# For VAE Decoder, we do not need to apply the tiler on each layer.
if tiled:
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
# 1. pre-process
hidden_states = sample / self.scaling_factor + self.shift_factor
hidden_states = self.conv_in(hidden_states)
time_emb = None
text_emb = None
res_stack = None
# 2. blocks
for i, block in enumerate(self.blocks):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
# 3. output
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class FluxVAEEncoder(torch.nn.Module):
def __init__(self, use_conv_attention=True):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
self.blocks = torch.nn.ModuleList([
# DownEncoderBlock2D
ResnetBlock(128, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
DownSampler(128, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(128, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
DownSampler(256, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(256, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
DownSampler(512, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
# UNetMidBlock2D
ResnetBlock(512, 512, eps=1e-6),
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
ResnetBlock(512, 512, eps=1e-6),
])
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride,
tile_device=sample.device,
tile_dtype=sample.dtype
)
return hidden_states
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
# For VAE Decoder, we do not need to apply the tiler on each layer.
if tiled:
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
# 1. pre-process
hidden_states = self.conv_in(sample)
time_emb = None
text_emb = None
res_stack = None
# 2. blocks
for i, block in enumerate(self.blocks):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
# 3. output
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states = hidden_states[:, :16]
hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
return hidden_states
def encode_video(self, sample, batch_size=8):
B = sample.shape[0]
hidden_states = []
for i in range(0, sample.shape[2], batch_size):
j = min(i + batch_size, sample.shape[2])
sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
hidden_states_batch = self(sample_batch)
hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
hidden_states.append(hidden_states_batch)
hidden_states = torch.concat(hidden_states, dim=2)
return hidden_states

View File

@@ -0,0 +1,56 @@
import torch
from .general_modules import TemporalTimesteps
class MultiValueEncoder(torch.nn.Module):
def __init__(self, encoders=()):
super().__init__()
if not isinstance(encoders, list):
encoders = [encoders]
self.encoders = torch.nn.ModuleList(encoders)
def __call__(self, values, dtype):
emb = []
for encoder, value in zip(self.encoders, values):
if value is not None:
value = value.unsqueeze(0)
emb.append(encoder(value, dtype))
emb = torch.concat(emb, dim=0)
return emb
class SingleValueEncoder(torch.nn.Module):
def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
super().__init__()
self.prefer_len = prefer_len
self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
self.prefer_value_embedder = torch.nn.Sequential(
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
)
self.positional_embedding = torch.nn.Parameter(
torch.randn(self.prefer_len, dim_out)
)
def forward(self, value, dtype):
value = value * 1000
emb = self.prefer_proj(value).to(dtype)
emb = self.prefer_value_embedder(emb).squeeze(0)
base_embeddings = emb.expand(self.prefer_len, -1)
positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device)
learned_embeddings = base_embeddings + positional_embedding
return learned_embeddings
@staticmethod
def state_dict_converter():
return SingleValueEncoderStateDictConverter()
class SingleValueEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict

View File

@@ -0,0 +1,146 @@
import torch, math
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
computation_device = None,
align_dtype_to_timestep = False,
):
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
if align_dtype_to_timestep:
emb = emb.to(timesteps.dtype)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class TemporalTimesteps(torch.nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.computation_device = computation_device
self.scale = scale
self.align_dtype_to_timestep = align_dtype_to_timestep
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
computation_device=self.computation_device,
scale=self.scale,
align_dtype_to_timestep=self.align_dtype_to_timestep,
)
return t_emb
class DiffusersCompatibleTimestepProj(torch.nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.linear_1 = torch.nn.Linear(dim_in, dim_out)
self.act = torch.nn.SiLU()
self.linear_2 = torch.nn.Linear(dim_out, dim_out)
def forward(self, x):
x = self.linear_1(x)
x = self.act(x)
x = self.linear_2(x)
return x
class TimestepEmbeddings(torch.nn.Module):
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):
super().__init__()
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
if diffusers_compatible_format:
self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out)
else:
self.timestep_embedder = torch.nn.Sequential(
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
)
self.use_additional_t_cond = use_additional_t_cond
if use_additional_t_cond:
self.addition_t_embedding = torch.nn.Embedding(2, dim_out)
def forward(self, timestep, dtype, addition_t_cond=None):
time_emb = self.time_proj(timestep).to(dtype)
time_emb = self.timestep_embedder(time_emb)
if addition_t_cond is not None:
addition_t_emb = self.addition_t_embedding(addition_t_cond)
addition_t_emb = addition_t_emb.to(dtype=dtype)
time_emb = time_emb + addition_t_emb
return time_emb
class RMSNorm(torch.nn.Module):
def __init__(self, dim, eps, elementwise_affine=True):
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = torch.nn.Parameter(torch.ones((dim,)))
else:
self.weight = None
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = hidden_states.to(input_dtype)
if self.weight is not None:
hidden_states = hidden_states * self.weight
return hidden_states
class AdaLayerNorm(torch.nn.Module):
def __init__(self, dim, single=False, dual=False):
super().__init__()
self.single = single
self.dual = dual
self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(torch.nn.functional.silu(emb))
if self.single:
scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
x = self.norm(x) * (1 + scale) + shift
return x
elif self.dual:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
norm_x = self.norm(x)
x = norm_x * (1 + scale_msa) + shift_msa
norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
x = self.norm(x) * (1 + scale_msa) + shift_msa
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

View File

@@ -1,451 +0,0 @@
from .attention import Attention
from .tiler import TileWorker
from einops import repeat, rearrange
import math
import torch
class HunyuanDiTRotaryEmbedding(torch.nn.Module):
def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
super().__init__()
self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
self.rotary_emb_on_k = rotary_emb_on_k
self.k_cache, self.v_cache = [], []
def reshape_for_broadcast(self, freqs_cis, x):
ndim = x.ndim
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
def rotate_half(self, x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(self, xq, xk, freqs_cis):
xk_out = None
cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
cos, sin = cos.to(xq.device), sin.to(xq.device)
xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
if xk is not None:
xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
return xq_out, xk_out
def forward(self, q, k, v, freqs_cis_img, to_cache=False):
# norm
q = self.q_norm(q)
k = self.k_norm(k)
# RoPE
if self.rotary_emb_on_k:
q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
else:
q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)
if to_cache:
self.k_cache.append(k)
self.v_cache.append(v)
elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
k = torch.concat([k] + self.k_cache, dim=2)
v = torch.concat([v] + self.v_cache, dim=2)
self.k_cache, self.v_cache = [], []
return q, k, v
class FP32_Layernorm(torch.nn.LayerNorm):
def forward(self, inputs):
origin_dtype = inputs.dtype
return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)
class FP32_SiLU(torch.nn.SiLU):
def forward(self, inputs):
origin_dtype = inputs.dtype
return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)
class HunyuanDiTFinalLayer(torch.nn.Module):
def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
super().__init__()
self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = torch.nn.Sequential(
FP32_SiLU(),
torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
)
def modulate(self, x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def forward(self, hidden_states, condition_emb):
shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
hidden_states = self.linear(hidden_states)
return hidden_states
class HunyuanDiTBlock(torch.nn.Module):
def __init__(
self,
hidden_dim=1408,
condition_dim=1408,
num_heads=16,
mlp_ratio=4.3637,
text_dim=1024,
skip_connection=False
):
super().__init__()
self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
self.mlp = torch.nn.Sequential(
torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
)
if skip_connection:
self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
else:
self.skip_norm, self.skip_linear = None, None
def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
# Long Skip Connection
if self.skip_norm is not None and self.skip_linear is not None:
hidden_states = torch.cat([hidden_states, residual], dim=-1)
hidden_states = self.skip_norm(hidden_states)
hidden_states = self.skip_linear(hidden_states)
# Self-Attention
shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
attn_input = self.norm1(hidden_states) + shift_msa
hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))
# Cross-Attention
attn_input = self.norm3(hidden_states)
hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))
# FFN Layer
mlp_input = self.norm2(hidden_states)
hidden_states = hidden_states + self.mlp(mlp_input)
return hidden_states
class AttentionPool(torch.nn.Module):
def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
super().__init__()
self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.permute(1, 0, 2) # NLC -> LNC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
x, _ = torch.nn.functional.multi_head_attention_forward(
query=x[:1], key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x.squeeze(0)
class PatchEmbed(torch.nn.Module):
def __init__(
self,
patch_size=(2, 2),
in_chans=4,
embed_dim=1408,
bias=True,
):
super().__init__()
self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
def forward(self, x):
x = self.proj(x)
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
return x
def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else:
embedding = repeat(t, "b -> b d", d=dim)
return embedding
class TimestepEmbedder(torch.nn.Module):
def __init__(self, hidden_size=1408, frequency_embedding_size=256):
super().__init__()
self.mlp = torch.nn.Sequential(
torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
def forward(self, t):
t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
t_emb = self.mlp(t_freq)
return t_emb
class HunyuanDiT(torch.nn.Module):
def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
super().__init__()
# Embedders
self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
self.t5_embedder = torch.nn.Sequential(
torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
FP32_SiLU(),
torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
)
self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
self.patch_embedder = PatchEmbed(in_chans=in_channels)
self.timestep_embedder = TimestepEmbedder()
self.extra_embedder = torch.nn.Sequential(
torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
FP32_SiLU(),
torch.nn.Linear(hidden_dim * 4, hidden_dim),
)
# Transformer blocks
self.num_layers_down = num_layers_down
self.num_layers_up = num_layers_up
self.blocks = torch.nn.ModuleList(
[HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
[HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
)
# Output layers
self.final_layer = HunyuanDiTFinalLayer()
self.out_channels = out_channels
def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
text_emb_mask = text_emb_mask.bool()
text_emb_mask_t5 = text_emb_mask_t5.bool()
text_emb_t5 = self.t5_embedder(text_emb_t5)
text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
return text_emb
def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
# Text embedding
pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)
# Timestep embedding
timestep_emb = self.timestep_embedder(timestep)
# Size embedding
size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
size_emb = size_emb.view(-1, 6 * 256)
# Style embedding
style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)
# Concatenate all extra vectors
extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
condition_emb = timestep_emb + self.extra_embedder(extra_emb)
return condition_emb
def unpatchify(self, x, h, w):
return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)
def build_mask(self, data, is_bound):
_, _, H, W = data.shape
h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
border_width = (H + W) // 4
pad = torch.ones_like(h) * border_width
mask = torch.stack([
pad if is_bound[0] else h + 1,
pad if is_bound[1] else H - h,
pad if is_bound[2] else w + 1,
pad if is_bound[3] else W - w
]).min(dim=0).values
mask = mask.clip(1, border_width)
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
mask = rearrange(mask, "H W -> 1 H W")
return mask
def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
B, C, H, W = hidden_states.shape
weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)
# Split tasks
tasks = []
for h in range(0, H, tile_stride):
for w in range(0, W, tile_stride):
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
continue
h_, w_ = h + tile_size, w + tile_size
if h_ > H: h, h_ = H - tile_size, H
if w_ > W: w, w_ = W - tile_size, W
tasks.append((h, h_, w, w_))
# Run
for hl, hr, wl, wr in tasks:
hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
if residual is not None:
residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
else:
residual_batch = None
# Forward
hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)
mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
weight[:, :, hl:hr, wl:wr] += mask
values /= weight
return values
def forward(
self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
tiled=False, tile_size=64, tile_stride=32,
to_cache=False,
use_gradient_checkpointing=False,
):
# Embeddings
text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])
# Input
height, width = hidden_states.shape[-2], hidden_states.shape[-1]
hidden_states = self.patch_embedder(hidden_states)
# Blocks
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if tiled:
hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
residuals = []
for block_id, block in enumerate(self.blocks):
residual = residuals.pop() if block_id >= self.num_layers_down else None
hidden_states = self.tiled_block_forward(
block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
tile_size=tile_size, tile_stride=tile_stride
)
if block_id < self.num_layers_down - 2:
residuals.append(hidden_states)
hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
else:
residuals = []
for block_id, block in enumerate(self.blocks):
residual = residuals.pop() if block_id >= self.num_layers_down else None
if self.training and use_gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, condition_emb, text_emb, freq_cis_img, residual,
use_reentrant=False,
)
else:
hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
if block_id < self.num_layers_down - 2:
residuals.append(hidden_states)
# Output
hidden_states = self.final_layer(hidden_states, condition_emb)
hidden_states = self.unpatchify(hidden_states, height//2, width//2)
hidden_states, _ = hidden_states.chunk(2, dim=1)
return hidden_states
def state_dict_converter(self):
return HunyuanDiTStateDictConverter()
class HunyuanDiTStateDictConverter():
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {}
for name, param in state_dict.items():
name_ = name
name_ = name_.replace(".default_modulation.", ".modulation.")
name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
name_ = name_.replace(".q_proj.", ".to_q.")
name_ = name_.replace(".out_proj.", ".to_out.")
name_ = name_.replace("text_embedding_padding", "text_emb_padding")
name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
name_ = name_.replace("pooler.", "t5_pooler.")
name_ = name_.replace("x_embedder.", "patch_embedder.")
name_ = name_.replace("t_embedder.", "timestep_embedder.")
name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
name_ = name_.replace("style_embedder.weight", "style_embedder")
if ".kv_proj." in name_:
param_k = param[:param.shape[0]//2]
param_v = param[param.shape[0]//2:]
state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
elif ".Wqkv." in name_:
param_q = param[:param.shape[0]//3]
param_k = param[param.shape[0]//3:param.shape[0]//3*2]
param_v = param[param.shape[0]//3*2:]
state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
elif "style_embedder" in name_:
state_dict_[name_] = param.squeeze()
else:
state_dict_[name_] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -1,161 +0,0 @@
from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
import torch
class HunyuanDiTCLIPTextEncoder(BertModel):
def __init__(self):
config = BertConfig(
_name_or_path = "",
architectures = ["BertModel"],
attention_probs_dropout_prob = 0.1,
bos_token_id = 0,
classifier_dropout = None,
directionality = "bidi",
eos_token_id = 2,
hidden_act = "gelu",
hidden_dropout_prob = 0.1,
hidden_size = 1024,
initializer_range = 0.02,
intermediate_size = 4096,
layer_norm_eps = 1e-12,
max_position_embeddings = 512,
model_type = "bert",
num_attention_heads = 16,
num_hidden_layers = 24,
output_past = True,
pad_token_id = 0,
pooler_fc_size = 768,
pooler_num_attention_heads = 12,
pooler_num_fc_layers = 3,
pooler_size_per_head = 128,
pooler_type = "first_token_transform",
position_embedding_type = "absolute",
torch_dtype = "float32",
transformers_version = "4.37.2",
type_vocab_size = 2,
use_cache = True,
vocab_size = 47020
)
super().__init__(config, add_pooling_layer=False)
self.eval()
def forward(self, input_ids, attention_mask, clip_skip=1):
input_shape = input_ids.size()
batch_size, seq_length = input_shape
device = input_ids.device
past_key_values_length = 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=None,
token_type_ids=None,
inputs_embeds=None,
past_key_values_length=0,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=False,
output_attentions=False,
output_hidden_states=True,
return_dict=True,
)
all_hidden_states = encoder_outputs.hidden_states
prompt_emb = all_hidden_states[-clip_skip]
if clip_skip > 1:
mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
return prompt_emb
def state_dict_converter(self):
return HunyuanDiTCLIPTextEncoderStateDictConverter()
class HunyuanDiTT5TextEncoder(T5EncoderModel):
def __init__(self):
config = T5Config(
_name_or_path = "../HunyuanDiT/t2i/mt5",
architectures = ["MT5ForConditionalGeneration"],
classifier_dropout = 0.0,
d_ff = 5120,
d_kv = 64,
d_model = 2048,
decoder_start_token_id = 0,
dense_act_fn = "gelu_new",
dropout_rate = 0.1,
eos_token_id = 1,
feed_forward_proj = "gated-gelu",
initializer_factor = 1.0,
is_encoder_decoder = True,
is_gated_act = True,
layer_norm_epsilon = 1e-06,
model_type = "t5",
num_decoder_layers = 24,
num_heads = 32,
num_layers = 24,
output_past = True,
pad_token_id = 0,
relative_attention_max_distance = 128,
relative_attention_num_buckets = 32,
tie_word_embeddings = False,
tokenizer_class = "T5Tokenizer",
transformers_version = "4.37.2",
use_cache = True,
vocab_size = 250112
)
super().__init__(config)
self.eval()
def forward(self, input_ids, attention_mask, clip_skip=1):
outputs = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
)
prompt_emb = outputs.hidden_states[-clip_skip]
if clip_skip > 1:
mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
return prompt_emb
def state_dict_converter(self):
return HunyuanDiTT5TextEncoderStateDictConverter()
class HunyuanDiTCLIPTextEncoderStateDictConverter():
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
class HunyuanDiTT5TextEncoderStateDictConverter():
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
state_dict_["shared.weight"] = state_dict["shared.weight"]
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -0,0 +1,902 @@
from typing import List, Optional, Tuple
import math
import torch
import torch.nn as nn
import torch.amp as amp
import numpy as np
import torch.nn.functional as F
from einops import rearrange, repeat
from .wan_video_dit import flash_attention
from ..core.device.npu_compatible_device import get_device_type
from ..core.gradient import gradient_checkpoint_forward
class RMSNorm_FP32(torch.nn.Module):
def __init__(self, dim: int, eps: float):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
def broadcat(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
), "invalid dimensions for broadcastable concatentation"
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim=dim)
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
class RotaryPositionalEmbedding(nn.Module):
def __init__(self,
head_dim,
cp_split_hw=None
):
"""Rotary positional embedding for 3D
Reference : https://blog.eleuther.ai/rotary-embeddings/
Paper: https://arxiv.org/pdf/2104.09864.pdf
Args:
dim: Dimension of embedding
base: Base value for exponential
"""
super().__init__()
self.head_dim = head_dim
assert self.head_dim % 8 == 0, 'Dim must be a multiply of 8 for 3D RoPE.'
self.cp_split_hw = cp_split_hw
# We take the assumption that the longest side of grid will not larger than 512, i.e, 512 * 8 = 4098 input pixels
self.base = 10000
self.freqs_dict = {}
def register_grid_size(self, grid_size):
if grid_size not in self.freqs_dict:
self.freqs_dict.update({
grid_size: self.precompute_freqs_cis_3d(grid_size)
})
def precompute_freqs_cis_3d(self, grid_size):
num_frames, height, width = grid_size
dim_t = self.head_dim - 4 * (self.head_dim // 6)
dim_h = 2 * (self.head_dim // 6)
dim_w = 2 * (self.head_dim // 6)
freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32)
grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32)
grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32)
grid_t = torch.from_numpy(grid_t).float()
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
# (T H W D)
freqs = rearrange(freqs, "T H W D -> (T H W) D")
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
# with torch.no_grad():
# freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width)
# freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw)
# freqs = rearrange(freqs, "T H W D -> (T H W) D")
return freqs
def forward(self, q, k, grid_size):
"""3D RoPE.
Args:
query: [B, head, seq, head_dim]
key: [B, head, seq, head_dim]
Returns:
query and key with the same shape as input.
"""
if grid_size not in self.freqs_dict:
self.register_grid_size(grid_size)
freqs_cis = self.freqs_dict[grid_size].to(q.device)
q_, k_ = q.float(), k.float()
freqs_cis = freqs_cis.float().to(q.device)
cos, sin = freqs_cis.cos(), freqs_cis.sin()
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
q_ = (q_ * cos) + (rotate_half(q_) * sin)
k_ = (k_ * cos) + (rotate_half(k_) * sin)
return q_.type_as(q), k_.type_as(k)
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
enable_flashattn3: bool = False,
enable_flashattn2: bool = False,
enable_xformers: bool = False,
enable_bsa: bool = False,
bsa_params: dict = None,
cp_split_hw: Optional[List[int]] = None
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.enable_flashattn3 = enable_flashattn3
self.enable_flashattn2 = enable_flashattn2
self.enable_xformers = enable_xformers
self.enable_bsa = enable_bsa
self.bsa_params = bsa_params
self.cp_split_hw = cp_split_hw
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.proj = nn.Linear(dim, dim)
self.rope_3d = RotaryPositionalEmbedding(
self.head_dim,
cp_split_hw=cp_split_hw
)
def _process_attn(self, q, k, v, shape):
q = rearrange(q, "B H S D -> B S (H D)")
k = rearrange(k, "B H S D -> B S (H D)")
v = rearrange(v, "B H S D -> B S (H D)")
x = flash_attention(q, k, v, num_heads=self.num_heads)
x = rearrange(x, "B S (H D) -> B H S D", H=self.num_heads)
return x
def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
"""
"""
B, N, C = x.shape
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if return_kv:
k_cache, v_cache = k.clone(), v.clone()
q, k = self.rope_3d(q, k, shape)
# cond mode
if num_cond_latents is not None and num_cond_latents > 0:
num_cond_latents_thw = num_cond_latents * (N // shape[0])
# process the condition tokens
q_cond = q[:, :, :num_cond_latents_thw].contiguous()
k_cond = k[:, :, :num_cond_latents_thw].contiguous()
v_cond = v[:, :, :num_cond_latents_thw].contiguous()
x_cond = self._process_attn(q_cond, k_cond, v_cond, shape)
# process the noise tokens
q_noise = q[:, :, num_cond_latents_thw:].contiguous()
x_noise = self._process_attn(q_noise, k, v, shape)
# merge x_cond and x_noise
x = torch.cat([x_cond, x_noise], dim=2).contiguous()
else:
x = self._process_attn(q, k, v, shape)
x_output_shape = (B, N, C)
x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
x = self.proj(x)
if return_kv:
return x, (k_cache, v_cache)
else:
return x
def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
"""
"""
B, N, C = x.shape
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
T, H, W = shape
k_cache, v_cache = kv_cache
assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B]
if k_cache.shape[0] == 1:
k_cache = k_cache.repeat(B, 1, 1, 1)
v_cache = v_cache.repeat(B, 1, 1, 1)
if num_cond_latents is not None and num_cond_latents > 0:
k_full = torch.cat([k_cache, k], dim=2).contiguous()
v_full = torch.cat([v_cache, v], dim=2).contiguous()
q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous()
q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
q = q_padding[:, :, -N:].contiguous()
x = self._process_attn(q, k_full, v_full, shape)
x_output_shape = (B, N, C)
x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
x = self.proj(x)
return x
class MultiHeadCrossAttention(nn.Module):
def __init__(
self,
dim,
num_heads,
enable_flashattn3=False,
enable_flashattn2=False,
enable_xformers=False,
):
super(MultiHeadCrossAttention, self).__init__()
assert dim % num_heads == 0, "d_model must be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_linear = nn.Linear(dim, dim)
self.kv_linear = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.enable_flashattn3 = enable_flashattn3
self.enable_flashattn2 = enable_flashattn2
self.enable_xformers = enable_xformers
def _process_cross_attn(self, x, cond, kv_seqlen):
B, N, C = x.shape
assert C == self.dim and cond.shape[2] == self.dim
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
q, k = self.q_norm(q), self.k_norm(k)
q = rearrange(q, "B S H D -> B S (H D)")
k = rearrange(k, "B S H D -> B S (H D)")
v = rearrange(v, "B S H D -> B S (H D)")
x = flash_attention(q, k, v, num_heads=self.num_heads)
x = x.view(B, -1, C)
x = self.proj(x)
return x
def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
"""
x: [B, N, C]
cond: [B, M, C]
"""
if num_cond_latents is None or num_cond_latents == 0:
return self._process_cross_attn(x, cond, kv_seqlen)
else:
B, N, C = x.shape
if num_cond_latents is not None and num_cond_latents > 0:
assert shape is not None, "SHOULD pass in the shape"
num_cond_latents_thw = num_cond_latents * (N // shape[0])
x_noise = x[:, num_cond_latents_thw:] # [B, N_noise, C]
output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen) # [B, N_noise, C]
output = torch.cat([
torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device),
output_noise
], dim=1).contiguous()
else:
raise NotImplementedError
return output
class LayerNorm_FP32(nn.LayerNorm):
def __init__(self, dim, eps, elementwise_affine):
super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
out = F.layer_norm(
inputs.float(),
self.normalized_shape,
None if self.weight is None else self.weight.float(),
None if self.bias is None else self.bias.float() ,
self.eps
).to(origin_dtype)
return out
def modulate_fp32(norm_func, x, shift, scale):
# Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D)
# ensure the modulation params be fp32
assert shift.dtype == torch.float32, scale.dtype == torch.float32
dtype = x.dtype
x = norm_func(x.to(torch.float32))
x = x * (scale + 1) + shift
x = x.to(dtype)
return x
class FinalLayer_FP32(nn.Module):
"""
The final layer of DiT.
"""
def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim):
super().__init__()
self.hidden_size = hidden_size
self.num_patch = num_patch
self.out_channels = out_channels
self.adaln_tembed_dim = adaln_tembed_dim
self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True))
def forward(self, x, t, latent_shape):
# timestep shape: [B, T, C]
assert t.dtype == torch.float32
B, N, C = x.shape
T, _, _ = latent_shape
with amp.autocast(get_device_type(), dtype=torch.float32):
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
x = self.linear(x)
return x
class FeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.dim = dim
self.hidden_dim = hidden_dim
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, t_embed_dim, frequency_embedding_size=256):
super().__init__()
self.t_embed_dim = t_embed_dim
self.frequency_embedding_size = frequency_embedding_size
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, t_embed_dim, bias=True),
nn.SiLU(),
nn.Linear(t_embed_dim, t_embed_dim, bias=True),
)
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
freqs = freqs.to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t, dtype):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
if t_freq.dtype != dtype:
t_freq = t_freq.to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
class CaptionEmbedder(nn.Module):
"""
Embeds class labels into vector representations.
"""
def __init__(self, in_channels, hidden_size):
super().__init__()
self.in_channels = in_channels
self.hidden_size = hidden_size
self.y_proj = nn.Sequential(
nn.Linear(in_channels, hidden_size, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(hidden_size, hidden_size, bias=True),
)
def forward(self, caption):
B, _, N, C = caption.shape
caption = self.y_proj(caption)
return caption
class PatchEmbed3D(nn.Module):
"""Video to Patch Embedding.
Args:
patch_size (int): Patch token size. Default: (2,4,4).
in_chans (int): Number of input video channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self,
patch_size=(2, 4, 4),
in_chans=3,
embed_dim=96,
norm_layer=None,
flatten=True,
):
super().__init__()
self.patch_size = patch_size
self.flatten = flatten
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, D, H, W = x.size()
if W % self.patch_size[2] != 0:
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
if H % self.patch_size[1] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
if D % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
B, C, T, H, W = x.shape
x = self.proj(x) # (B C T H W)
if self.norm is not None:
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
return x
class LongCatSingleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: int,
adaln_tembed_dim: int,
enable_flashattn3: bool = False,
enable_flashattn2: bool = False,
enable_xformers: bool = False,
enable_bsa: bool = False,
bsa_params=None,
cp_split_hw=None
):
super().__init__()
self.hidden_size = hidden_size
# scale and gate modulation
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True)
)
self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True)
self.attn = Attention(
dim=hidden_size,
num_heads=num_heads,
enable_flashattn3=enable_flashattn3,
enable_flashattn2=enable_flashattn2,
enable_xformers=enable_xformers,
enable_bsa=enable_bsa,
bsa_params=bsa_params,
cp_split_hw=cp_split_hw
)
self.cross_attn = MultiHeadCrossAttention(
dim=hidden_size,
num_heads=num_heads,
enable_flashattn3=enable_flashattn3,
enable_flashattn2=enable_flashattn2,
enable_xformers=enable_xformers,
)
self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio))
def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False):
"""
x: [B, N, C]
y: [1, N_valid_tokens, C]
t: [B, T, C_t]
y_seqlen: [B]; type of a list
latent_shape: latent shape of a single item
"""
x_dtype = x.dtype
B, N, C = x.shape
T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.
# compute modulation params in fp32
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
shift_msa, scale_msa, gate_msa, \
shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
# self attn with modulation
x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C)
if kv_cache is not None:
kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device))
attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache)
else:
attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv)
if return_kv:
x_s, kv_cache = attn_outputs
else:
x_s = attn_outputs
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
# cross attn
if not skip_crs_attn:
if kv_cache is not None:
num_cond_latents = None
x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape)
# ffn with modulation
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
x_s = self.ffn(x_m)
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
if return_kv:
return x, kv_cache
else:
return x
class LongCatVideoTransformer3DModel(torch.nn.Module):
def __init__(
self,
in_channels: int = 16,
out_channels: int = 16,
hidden_size: int = 4096,
depth: int = 48,
num_heads: int = 32,
caption_channels: int = 4096,
mlp_ratio: int = 4,
adaln_tembed_dim: int = 512,
frequency_embedding_size: int = 256,
# default params
patch_size: Tuple[int] = (1, 2, 2),
# attention config
enable_flashattn3: bool = False,
enable_flashattn2: bool = True,
enable_xformers: bool = False,
enable_bsa: bool = False,
bsa_params: dict = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]},
cp_split_hw: Optional[List[int]] = [1, 1],
text_tokens_zero_pad: bool = True,
) -> None:
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = out_channels
self.cp_split_hw = cp_split_hw
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels,
hidden_size=hidden_size,
)
self.blocks = nn.ModuleList(
[
LongCatSingleStreamBlock(
hidden_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
adaln_tembed_dim=adaln_tembed_dim,
enable_flashattn3=enable_flashattn3,
enable_flashattn2=enable_flashattn2,
enable_xformers=enable_xformers,
enable_bsa=enable_bsa,
bsa_params=bsa_params,
cp_split_hw=cp_split_hw
)
for i in range(depth)
]
)
self.final_layer = FinalLayer_FP32(
hidden_size,
np.prod(self.patch_size),
out_channels,
adaln_tembed_dim,
)
self.gradient_checkpointing = False
self.text_tokens_zero_pad = text_tokens_zero_pad
self.lora_dict = {}
self.active_loras = []
def enable_loras(self, lora_key_list=[]):
self.disable_all_loras()
module_loras = {} # {module_name: [lora1, lora2, ...]}
model_device = next(self.parameters()).device
model_dtype = next(self.parameters()).dtype
for lora_key in lora_key_list:
if lora_key in self.lora_dict:
for lora in self.lora_dict[lora_key].loras:
lora.to(model_device, dtype=model_dtype, non_blocking=True)
module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".")
if module_name not in module_loras:
module_loras[module_name] = []
module_loras[module_name].append(lora)
self.active_loras.append(lora_key)
for module_name, loras in module_loras.items():
module = self._get_module_by_name(module_name)
if not hasattr(module, 'org_forward'):
module.org_forward = module.forward
module.forward = self._create_multi_lora_forward(module, loras)
def _create_multi_lora_forward(self, module, loras):
def multi_lora_forward(x, *args, **kwargs):
weight_dtype = x.dtype
org_output = module.org_forward(x, *args, **kwargs)
total_lora_output = 0
for lora in loras:
if lora.use_lora:
lx = lora.lora_down(x.to(lora.lora_down.weight.dtype))
lx = lora.lora_up(lx)
lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale
total_lora_output += lora_output
return org_output + total_lora_output
return multi_lora_forward
def _get_module_by_name(self, module_name):
try:
module = self
for part in module_name.split('.'):
module = getattr(module, part)
return module
except AttributeError as e:
raise ValueError(f"Cannot find module: {module_name}, error: {e}")
def disable_all_loras(self):
for name, module in self.named_modules():
if hasattr(module, 'org_forward'):
module.forward = module.org_forward
delattr(module, 'org_forward')
for lora_key, lora_network in self.lora_dict.items():
for lora in lora_network.loras:
lora.to("cpu")
self.active_loras.clear()
def enable_bsa(self,):
for block in self.blocks:
block.attn.enable_bsa = True
def disable_bsa(self,):
for block in self.blocks:
block.attn.enable_bsa = False
def forward(
self,
hidden_states,
timestep,
encoder_hidden_states,
encoder_attention_mask=None,
num_cond_latents=0,
return_kv=False,
kv_cache_dict={},
skip_crs_attn=False,
offload_kv_cache=False,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
B, _, T, H, W = hidden_states.shape
N_t = T // self.patch_size[0]
N_h = H // self.patch_size[1]
N_w = W // self.patch_size[2]
assert self.patch_size[0]==1, "Currently, 3D x_embedder should not compress the temporal dimension."
# expand the shape of timestep from [B] to [B, T]
if len(timestep.shape) == 1:
timestep = timestep.unsqueeze(1).expand(-1, N_t).clone() # [B, T]
timestep[:, :num_cond_latents] = 0
dtype = hidden_states.dtype
hidden_states = hidden_states.to(dtype)
timestep = timestep.to(dtype)
encoder_hidden_states = encoder_hidden_states.to(dtype)
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
if self.text_tokens_zero_pad and encoder_attention_mask is not None:
encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None]
encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype)
if encoder_attention_mask is not None:
encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1)
encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1]) # [1, N_valid_tokens, C]
y_seqlens = encoder_attention_mask.sum(dim=1).tolist() # [B]
else:
y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0]
encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1])
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
# hidden_states = rearrange(hidden_states, "B (T H W) C -> B T H W C", T=N_t, H=N_h, W=N_w)
# hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw)
# hidden_states = rearrange(hidden_states, "B T H W C -> B (T H W) C")
# blocks
kv_cache_dict_ret = {}
for i, block in enumerate(self.blocks):
block_outputs = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=hidden_states,
y=encoder_hidden_states,
t=t,
y_seqlen=y_seqlens,
latent_shape=(N_t, N_h, N_w),
num_cond_latents=num_cond_latents,
return_kv=return_kv,
kv_cache=kv_cache_dict.get(i, None),
skip_crs_attn=skip_crs_attn,
)
if return_kv:
hidden_states, kv_cache = block_outputs
if offload_kv_cache:
kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu())
else:
kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous())
else:
hidden_states = block_outputs
hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w)) # [B, N, C=T_p*H_p*W_p*C_out]
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
# hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw)
hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w) # [B, C_out, H, W]
# cast to float32 for better accuracy
hidden_states = hidden_states.to(torch.float32)
if return_kv:
return hidden_states, kv_cache_dict_ret
else:
return hidden_states
def unpatchify(self, x, N_t, N_h, N_w):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
return x
@staticmethod
def state_dict_converter():
return LongCatVideoTransformer3DModelDictConverter()
class LongCatVideoTransformer3DModelDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict

View File

@@ -0,0 +1,111 @@
from ..core.loader import load_model, hash_model_file
from ..core.vram import AutoWrappedModule
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS
import importlib, json, torch
class ModelPool:
def __init__(self):
self.model = []
self.model_name = []
self.model_path = []
def import_model_class(self, model_class):
split = model_class.rfind(".")
model_resource, model_class = model_class[:split], model_class[split+1:]
model_class = importlib.import_module(model_resource).__getattribute__(model_class)
return model_class
def need_to_enable_vram_management(self, vram_config):
return vram_config["offload_dtype"] is not None and vram_config["offload_device"] is not None
def fetch_module_map(self, model_class, vram_config):
if self.need_to_enable_vram_management(vram_config):
if model_class in VRAM_MANAGEMENT_MODULE_MAPS:
module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in VRAM_MANAGEMENT_MODULE_MAPS[model_class].items()}
else:
module_map = {self.import_model_class(model_class): AutoWrappedModule}
else:
module_map = None
return module_map
def load_model_file(self, config, path, vram_config, vram_limit=None):
model_class = self.import_model_class(config["model_class"])
model_config = config.get("extra_kwargs", {})
if "state_dict_converter" in config:
state_dict_converter = self.import_model_class(config["state_dict_converter"])
else:
state_dict_converter = None
module_map = self.fetch_module_map(config["model_class"], vram_config)
model = load_model(
model_class, path, model_config,
vram_config["computation_dtype"], vram_config["computation_device"],
state_dict_converter,
use_disk_map=True,
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
)
return model
def default_vram_config(self):
vram_config = {
"offload_dtype": None,
"offload_device": None,
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cpu",
"computation_dtype": torch.bfloat16,
"computation_device": "cpu",
}
return vram_config
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False):
print(f"Loading models from: {json.dumps(path, indent=4)}")
if vram_config is None:
vram_config = self.default_vram_config()
model_hash = hash_model_file(path)
loaded = False
for config in MODEL_CONFIGS:
if config["model_hash"] == model_hash:
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit)
if clear_parameters: self.clear_parameters(model)
self.model.append(model)
model_name = config["model_name"]
self.model_name.append(model_name)
self.model_path.append(path)
model_info = {"model_name": model_name, "model_class": config["model_class"], "extra_kwargs": config.get("extra_kwargs")}
print(f"Loaded model: {json.dumps(model_info, indent=4)}")
loaded = True
if not loaded:
raise ValueError(f"Cannot detect the model type. File: {path}. Model hash: {model_hash}")
def fetch_model(self, model_name, index=None):
fetched_models = []
fetched_model_paths = []
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
if model_name == model_name_:
fetched_models.append(model)
fetched_model_paths.append(model_path)
if len(fetched_models) == 0:
print(f"No {model_name} models available. This is not an error.")
model = None
elif len(fetched_models) == 1:
print(f"Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.")
model = fetched_models[0]
else:
if index is None:
model = fetched_models[0]
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.")
elif isinstance(index, int):
model = fetched_models[:index]
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[:index], indent=4)}.")
else:
model = fetched_models
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths, indent=4)}.")
return model
def clear_parameters(self, model: torch.nn.Module):
for name, module in model.named_children():
self.clear_parameters(module)
for name, param in model.named_parameters(recurse=False):
setattr(model, name, None)

View File

@@ -0,0 +1,161 @@
import torch
from PIL import Image
class NexusGenAutoregressiveModel(torch.nn.Module):
def __init__(self, max_length=1024, max_pixels=262640):
super(NexusGenAutoregressiveModel, self).__init__()
from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration
from transformers import Qwen2_5_VLConfig
self.max_length = max_length
self.max_pixels = max_pixels
model_config = Qwen2_5_VLConfig(**{
"_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
"architectures": [
"Qwen2_5_VLForConditionalGeneration"
],
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
"AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
"AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration"
},
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_act": "silu",
"hidden_size": 3584,
"image_token_id": 151655,
"initializer_range": 0.02,
"intermediate_size": 18944,
"max_position_embeddings": 128000,
"max_window_layers": 28,
"model_type": "qwen2_5_vl",
"num_attention_heads": 28,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"pad_token_id": 151643,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"mrope_section": [
16,
24,
24
],
"rope_type": "default",
"type": "default"
},
"rope_theta": 1000000.0,
"sliding_window": 32768,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"transformers_version": "4.49.0",
"use_cache": False,
"use_sliding_window": False,
"video_token_id": 151656,
"vision_config": {
"hidden_size": 1280,
"in_chans": 3,
"model_type": "qwen2_5_vl",
"spatial_patch_size": 14,
"tokens_per_second": 2,
"torch_dtype": "bfloat16"
},
"vision_end_token_id": 151653,
"vision_start_token_id": 151652,
"vision_token_id": 151654,
"vocab_size": 152064
})
self.model = Qwen2_5_VLForConditionalGeneration(model_config)
self.processor = None
def load_processor(self, path):
from .nexus_gen_ar_model import Qwen2_5_VLProcessor
self.processor = Qwen2_5_VLProcessor.from_pretrained(path)
@staticmethod
def state_dict_converter():
return NexusGenAutoregressiveModelStateDictConverter()
def bound_image(self, image, max_pixels=262640):
from qwen_vl_utils import smart_resize
resized_height, resized_width = smart_resize(
image.height,
image.width,
max_pixels=max_pixels,
)
return image.resize((resized_width, resized_height))
def get_editing_msg(self, instruction):
if '<image>' not in instruction:
instruction = '<image> ' + instruction
messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is the image: <image>"}]
return messages
def get_generation_msg(self, instruction):
instruction = "Generate an image according to the following description: {}".format(instruction)
messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is an image based on the description: <image>"}]
return messages
def forward(self, instruction, ref_image=None, num_img_tokens=81):
"""
Generate target embeddings for the given instruction and reference image.
"""
if ref_image is not None:
messages = self.get_editing_msg(instruction)
images = [self.bound_image(ref_image)] + [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
else:
messages = self.get_generation_msg(instruction)
images = [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
return output_image_embeddings
def get_target_embeddings(self, images, messages, processor, model, num_img_tokens=81):
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
text = text.replace('<image>', '<|vision_start|><|image_pad|><|vision_end|>')
inputs = processor(
text=[text],
images=images,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(model.device)
input_embeds = model.model.embed_tokens(inputs['input_ids'])
image_embeds = model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
ground_truth_image_embeds = image_embeds[-num_img_tokens:]
input_image_embeds = image_embeds[:-num_img_tokens]
image_mask = inputs['input_ids'] == model.config.image_token_id
indices = image_mask.cumsum(dim=1)
input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
image_prefill_embeds = model.image_prefill_embeds(
torch.arange(81, device=model.device).long()
)
input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds)
position_ids, _ = model.get_rope_index(
inputs['input_ids'],
inputs['image_grid_thw'],
attention_mask=inputs['attention_mask'])
position_ids = position_ids.contiguous()
outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
output_image_embeddings = outputs.image_embeddings[:, :-1, :]
output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
return output_image_embeddings, input_image_embeds, inputs['image_grid_thw']
class NexusGenAutoregressiveModelStateDictConverter:
def __init__(self):
pass
def from_civitai(self, state_dict):
state_dict = {"model." + key: value for key, value in state_dict.items()}
return state_dict

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,417 @@
import math
import torch
import torch.nn as nn
from typing import Optional, Tuple
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
mrope_section = mrope_section * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class Qwen2_5_VLRotaryEmbedding(nn.Module):
def __init__(self, config, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
from transformers.modeling_rope_utils import _compute_default_rope_parameters
self.rope_init_fn = _compute_default_rope_parameters
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids
# So we expand the inv_freq to shape (3, ...)
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class Qwen2_5_VLAttention(nn.Module):
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.is_causal = True
self.attention_dropout = config.attention_dropout
self.rope_scaling = config.rope_scaling
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
# Fix precision issues in Qwen2-VL float16 inference
# Replace inf values with zeros in attention weights to prevent NaN propagation
if query_states.dtype == torch.float16:
attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output
class Qwen2MLP(nn.Module):
def __init__(self, config):
super().__init__()
from transformers.activations import ACT2FN
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class Qwen2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Qwen2RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Qwen2_5_VLDecoderLayer(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Qwen2_5_VLAttention(config, layer_idx)
self.mlp = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class NexusGenImageEmbeddingMerger(nn.Module):
def __init__(self, num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'):
super().__init__()
from transformers import Qwen2_5_VLConfig
from transformers.activations import ACT2FN
config = Qwen2_5_VLConfig(**{
"_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
"architectures": [
"Qwen2_5_VLForConditionalGeneration"
],
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
"AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
"AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration"
},
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_act": "silu",
"hidden_size": 3584,
"image_token_id": 151655,
"initializer_range": 0.02,
"intermediate_size": 18944,
"max_position_embeddings": 128000,
"max_window_layers": 28,
"model_type": "qwen2_5_vl",
"num_attention_heads": 28,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"pad_token_id": 151643,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"mrope_section": [
16,
24,
24
],
"rope_type": "default",
"type": "default"
},
"rope_theta": 1000000.0,
"sliding_window": 32768,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"transformers_version": "4.49.0",
"use_cache": False,
"use_sliding_window": False,
"video_token_id": 151656,
"vision_config": {
"hidden_size": 1280,
"in_chans": 3,
"model_type": "qwen2_5_vl",
"spatial_patch_size": 14,
"tokens_per_second": 2,
"torch_dtype": "bfloat16"
},
"vision_end_token_id": 151653,
"vision_start_token_id": 151652,
"vision_token_id": 151654,
"vocab_size": 152064
})
self.config = config
self.num_layers = num_layers
self.layers = nn.ModuleList([Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(num_layers)])
self.projector = nn.Sequential(Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps),
nn.Linear(config.hidden_size, out_channel * expand_ratio),
Qwen2RMSNorm(out_channel * expand_ratio, eps=config.rms_norm_eps),
ACT2FN[config.hidden_act], nn.Linear(out_channel * expand_ratio, out_channel),
Qwen2RMSNorm(out_channel, eps=config.rms_norm_eps))
self.base_grid = torch.tensor([[1, 72, 72]], device=device)
self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config, device=device)
def get_position_ids(self, image_grid_thw):
"""
Generates position ids for the input embeddings grid.
modified from the qwen2_vl mrope.
"""
batch_size = image_grid_thw.shape[0]
spatial_merge_size = self.config.vision_config.spatial_merge_size
t, h, w = (
image_grid_thw[0][0],
image_grid_thw[0][1],
image_grid_thw[0][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
scale_h = self.base_grid[0][1].item() / h.item()
scale_w = self.base_grid[0][2].item() / w.item()
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
time_tensor = expanded_range * self.config.vision_config.tokens_per_second
t_index = time_tensor.long().flatten().to(image_grid_thw.device)
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten().to(image_grid_thw.device) * scale_h
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten().to(image_grid_thw.device) * scale_w
# 3, B, L
position_ids = torch.stack([t_index, h_index, w_index]).unsqueeze(0).repeat(batch_size, 1, 1).permute(1, 0, 2)
return position_ids
def forward(self, embeds, embeds_grid, ref_embeds=None, ref_embeds_grid=None):
position_ids = self.get_position_ids(embeds_grid)
hidden_states = embeds
if ref_embeds is not None:
position_ids_ref_embeds = self.get_position_ids(ref_embeds_grid)
position_ids = torch.cat((position_ids, position_ids_ref_embeds), dim=-1)
hidden_states = torch.cat((embeds, ref_embeds), dim=1)
position_embeddings = self.rotary_emb(hidden_states, position_ids)
for layer in self.layers:
hidden_states = layer(hidden_states, position_embeddings)
hidden_states = self.projector(hidden_states)
return hidden_states
@staticmethod
def state_dict_converter():
return NexusGenMergerStateDictConverter()
class NexusGenMergerStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
merger_state_dict = {key.replace("embedding_merger.", ""): value for key, value in state_dict.items() if key.startswith('embedding_merger.')}
return merger_state_dict
class NexusGenAdapter(nn.Module):
"""
Adapter for Nexus-Gen generation decoder.
"""
def __init__(self, input_dim=3584, output_dim=4096):
super(NexusGenAdapter, self).__init__()
self.adapter = nn.Sequential(nn.Linear(input_dim, output_dim),
nn.LayerNorm(output_dim), nn.ReLU(),
nn.Linear(output_dim, output_dim),
nn.LayerNorm(output_dim))
def forward(self, x):
return self.adapter(x)
@staticmethod
def state_dict_converter():
return NexusGenAdapterStateDictConverter()
class NexusGenAdapterStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
adapter_state_dict = {key: value for key, value in state_dict.items() if key.startswith('adapter.')}
return adapter_state_dict

View File

@@ -0,0 +1,56 @@
import torch
import torch.nn as nn
from .general_modules import RMSNorm
class BlockWiseControlBlock(torch.nn.Module):
# [linear, gelu, linear]
def __init__(self, dim: int = 3072):
super().__init__()
self.x_rms = RMSNorm(dim, eps=1e-6)
self.y_rms = RMSNorm(dim, eps=1e-6)
self.input_proj = nn.Linear(dim, dim)
self.act = nn.GELU()
self.output_proj = nn.Linear(dim, dim)
def forward(self, x, y):
x, y = self.x_rms(x), self.y_rms(y)
x = self.input_proj(x + y)
x = self.act(x)
x = self.output_proj(x)
return x
def init_weights(self):
# zero initialize output_proj
nn.init.zeros_(self.output_proj.weight)
nn.init.zeros_(self.output_proj.bias)
class QwenImageBlockWiseControlNet(torch.nn.Module):
def __init__(
self,
num_layers: int = 60,
in_dim: int = 64,
additional_in_dim: int = 0,
dim: int = 3072,
):
super().__init__()
self.img_in = nn.Linear(in_dim + additional_in_dim, dim)
self.controlnet_blocks = nn.ModuleList(
[
BlockWiseControlBlock(dim)
for _ in range(num_layers)
]
)
def init_weight(self):
nn.init.zeros_(self.img_in.weight)
nn.init.zeros_(self.img_in.bias)
for block in self.controlnet_blocks:
block.init_weights()
def process_controlnet_conditioning(self, controlnet_conditioning):
return self.img_in(controlnet_conditioning)
def blockwise_forward(self, img, controlnet_conditioning, block_id):
return self.controlnet_blocks[block_id](img, controlnet_conditioning)

View File

@@ -0,0 +1,685 @@
import torch, math, functools
import torch.nn as nn
from typing import Tuple, Optional, Union, List
from einops import rearrange
from .general_modules import TimestepEmbeddings, RMSNorm, AdaLayerNorm
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask = None, enable_fp8_attention: bool = False):
if FLASH_ATTN_3_AVAILABLE and attention_mask is None:
if not enable_fp8_attention:
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v)
if isinstance(x, tuple):
x = x[0]
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
else:
origin_dtype = q.dtype
q_std, k_std, v_std = q.std(), k.std(), v.std()
q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))
if isinstance(x, tuple):
x = x[0]
x = x.to(origin_dtype) * v_std
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
else:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
return x
class ApproximateGELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
return x * torch.sigmoid(1.702 * x)
def apply_rotary_emb_qwen(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
):
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
return x_out.type_as(x)
class QwenEmbedRope(nn.Module):
def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
self.pos_freqs = torch.cat([
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
], dim=1)
self.neg_freqs = torch.cat([
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
], dim=1)
self.rope_cache = {}
self.scale_rope = scale_rope
def rope_params(self, index, dim, theta=10000):
"""
Args:
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
"""
assert dim % 2 == 0
freqs = torch.outer(
index,
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
)
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
if isinstance(video_fhw, list):
video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3))
_, height, width = video_fhw
if self.scale_rope:
max_vid_index = max(height // 2, width // 2)
else:
max_vid_index = max(height, width)
required_len = max_vid_index + max(txt_seq_lens)
cur_max_len = self.pos_freqs.shape[0]
if required_len <= cur_max_len:
return
new_max_len = math.ceil(required_len / 512) * 512
pos_index = torch.arange(new_max_len)
neg_index = torch.arange(new_max_len).flip(0) * -1 - 1
self.pos_freqs = torch.cat([
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
], dim=1)
self.neg_freqs = torch.cat([
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
], dim=1)
return
def forward(self, video_fhw, txt_seq_lens, device):
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
vid_freqs = []
max_vid_index = 0
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
rope_key = f"{idx}_{height}_{width}"
if rope_key not in self.rope_cache:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat(
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
self.rope_cache[rope_key] = freqs.clone().contiguous()
vid_freqs.append(self.rope_cache[rope_key])
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
def forward_sampling(self, video_fhw, txt_seq_lens, device):
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
vid_freqs = []
max_vid_index = 0
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
rope_key = f"{idx}_{height}_{width}"
if idx > 0 and f"{0}_{height}_{width}" not in self.rope_cache:
frame_0, height_0, width_0 = video_fhw[0]
rope_key_0 = f"0_{height_0}_{width_0}"
spatial_freqs_0 = self.rope_cache[rope_key_0].reshape(frame_0, height_0, width_0, -1)
h_indices = torch.linspace(0, height_0 - 1, height).long()
w_indices = torch.linspace(0, width_0 - 1, width).long()
h_grid, w_grid = torch.meshgrid(h_indices, w_indices, indexing='ij')
sampled_rope = spatial_freqs_0[:, h_grid, w_grid, :]
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
sampled_rope[:, :, :, :freqs_frame.shape[-1]] = freqs_frame
seq_lens = frame * height * width
self.rope_cache[rope_key] = sampled_rope.reshape(seq_lens, -1).clone()
if rope_key not in self.rope_cache:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat(
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
self.rope_cache[rope_key] = freqs.clone()
vid_freqs.append(self.rope_cache[rope_key].contiguous())
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
class QwenEmbedLayer3DRope(nn.Module):
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
self.pos_freqs = torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.neg_freqs = torch.cat(
[
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.scale_rope = scale_rope
def rope_params(self, index, dim, theta=10000):
"""
Args:
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
"""
assert dim % 2 == 0
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def forward(self, video_fhw, txt_seq_lens, device):
"""
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
txt_length: [bs] a list of 1 integers representing the length of the text
"""
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
video_fhw = [video_fhw]
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
if not isinstance(video_fhw, list):
video_fhw = [video_fhw]
vid_freqs = []
max_vid_index = 0
layer_num = len(video_fhw) - 1
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
if idx != layer_num:
video_freq = self._compute_video_freqs(frame, height, width, idx)
else:
### For the condition image, we set the layer index to -1
video_freq = self._compute_condition_freqs(frame, height, width)
video_freq = video_freq.to(device)
vid_freqs.append(video_freq)
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
max_vid_index = max(max_vid_index, layer_num)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
@functools.lru_cache(maxsize=None)
def _compute_condition_freqs(self, frame, height, width):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
class QwenFeedForward(nn.Module):
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
dropout: float = 0.0,
):
super().__init__()
inner_dim = int(dim * 4)
self.net = nn.ModuleList([])
self.net.append(ApproximateGELU(dim, inner_dim))
self.net.append(nn.Dropout(dropout))
self.net.append(nn.Linear(inner_dim, dim_out))
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class QwenDoubleStreamAttention(nn.Module):
def __init__(
self,
dim_a,
dim_b,
num_heads,
head_dim,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = nn.Linear(dim_a, dim_a)
self.to_k = nn.Linear(dim_a, dim_a)
self.to_v = nn.Linear(dim_a, dim_a)
self.norm_q = RMSNorm(head_dim, eps=1e-6)
self.norm_k = RMSNorm(head_dim, eps=1e-6)
self.add_q_proj = nn.Linear(dim_b, dim_b)
self.add_k_proj = nn.Linear(dim_b, dim_b)
self.add_v_proj = nn.Linear(dim_b, dim_b)
self.norm_added_q = RMSNorm(head_dim, eps=1e-6)
self.norm_added_k = RMSNorm(head_dim, eps=1e-6)
self.to_out = torch.nn.Sequential(nn.Linear(dim_a, dim_a))
self.to_add_out = nn.Linear(dim_b, dim_b)
def forward(
self,
image: torch.FloatTensor,
text: torch.FloatTensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
enable_fp8_attention: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
seq_txt = txt_q.shape[1]
img_q = rearrange(img_q, 'b s (h d) -> b h s d', h=self.num_heads)
img_k = rearrange(img_k, 'b s (h d) -> b h s d', h=self.num_heads)
img_v = rearrange(img_v, 'b s (h d) -> b h s d', h=self.num_heads)
txt_q = rearrange(txt_q, 'b s (h d) -> b h s d', h=self.num_heads)
txt_k = rearrange(txt_k, 'b s (h d) -> b h s d', h=self.num_heads)
txt_v = rearrange(txt_v, 'b s (h d) -> b h s d', h=self.num_heads)
img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
if image_rotary_emb is not None:
img_freqs, txt_freqs = image_rotary_emb
img_q = apply_rotary_emb_qwen(img_q, img_freqs)
img_k = apply_rotary_emb_qwen(img_k, img_freqs)
txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)
joint_q = torch.cat([txt_q, img_q], dim=2)
joint_k = torch.cat([txt_k, img_k], dim=2)
joint_v = torch.cat([txt_v, img_v], dim=2)
joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
txt_attn_output = joint_attn_out[:, :seq_txt, :]
img_attn_output = joint_attn_out[:, seq_txt:, :]
img_attn_output = self.to_out(img_attn_output)
txt_attn_output = self.to_add_out(txt_attn_output)
return img_attn_output, txt_attn_output
class QwenImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
eps: float = 1e-6,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.img_mod = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim),
)
self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.attn = QwenDoubleStreamAttention(
dim_a=dim,
dim_b=dim,
num_heads=num_attention_heads,
head_dim=attention_head_dim,
)
self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.img_mlp = QwenFeedForward(dim=dim, dim_out=dim)
self.txt_mod = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim, bias=True),
)
self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)
def _modulate(self, x, mod_params, index=None):
shift, scale, gate = mod_params.chunk(3, dim=-1)
if index is not None:
# Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
# So shift, scale, gate have shape [2*actual_batch, d]
actual_batch = shift.size(0) // 2
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
# index: [b, l] where b is actual batch size
# Expand to [b, l, 1] to match feature dimension
index_expanded = index.unsqueeze(-1) # [b, l, 1]
# Expand chunks to [b, 1, d] then broadcast to [b, l, d]
shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
scale_0_exp = scale_0.unsqueeze(1)
scale_1_exp = scale_1.unsqueeze(1)
gate_0_exp = gate_0.unsqueeze(1)
gate_1_exp = gate_1.unsqueeze(1)
# Use torch.where to select based on index
shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
else:
shift_result = shift.unsqueeze(1)
scale_result = scale.unsqueeze(1)
gate_result = gate.unsqueeze(1)
return x * (1 + scale_result) + shift_result, gate_result
def forward(
self,
image: torch.Tensor,
text: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
enable_fp8_attention = False,
modulate_index: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
if modulate_index is not None:
temb = torch.chunk(temb, 2, dim=0)[0]
txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
img_normed = self.img_norm1(image)
img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, index=modulate_index)
txt_normed = self.txt_norm1(text)
txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
img_attn_out, txt_attn_out = self.attn(
image=img_modulated,
text=txt_modulated,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
enable_fp8_attention=enable_fp8_attention,
)
image = image + img_gate * img_attn_out
text = text + txt_gate * txt_attn_out
img_normed_2 = self.img_norm2(image)
img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, index=modulate_index)
txt_normed_2 = self.txt_norm2(text)
txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
img_mlp_out = self.img_mlp(img_modulated_2)
txt_mlp_out = self.txt_mlp(txt_modulated_2)
image = image + img_gate_2 * img_mlp_out
text = text + txt_gate_2 * txt_mlp_out
return text, image
class QwenImageDiT(torch.nn.Module):
def __init__(
self,
num_layers: int = 60,
use_layer3d_rope: bool = False,
use_additional_t_cond: bool = False,
):
super().__init__()
if not use_layer3d_rope:
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
else:
self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=False, use_additional_t_cond=use_additional_t_cond)
self.txt_norm = RMSNorm(3584, eps=1e-6)
self.img_in = nn.Linear(64, 3072)
self.txt_in = nn.Linear(3584, 3072)
self.transformer_blocks = nn.ModuleList(
[
QwenImageTransformerBlock(
dim=3072,
num_attention_heads=24,
attention_head_dim=128,
)
for _ in range(num_layers)
]
)
self.norm_out = AdaLayerNorm(3072, single=True)
self.proj_out = nn.Linear(3072, 64)
def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes):
# prompt_emb
all_prompt_emb = entity_prompt_emb + [prompt_emb]
all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
# image_rotary_emb
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
entity_seq_lens = [emb_mask.sum(dim=1).tolist() for emb_mask in entity_prompt_emb_mask]
entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens]
txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
# attention_mask
repeat_dim = latents.shape[1]
max_masks = entity_masks.shape[1]
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
global_mask = torch.ones_like(entity_masks[0]).to(device=latents.device, dtype=latents.dtype)
entity_masks = entity_masks + [global_mask]
N = len(entity_masks)
batch_size = entity_masks[0].shape[0]
seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]
total_seq_len = sum(seq_lens) + image.shape[1]
patched_masks = []
for i in range(N):
patched_mask = rearrange(entity_masks[i], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
patched_masks.append(patched_mask)
attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
# prompt-image attention mask
image_start = sum(seq_lens)
image_end = total_seq_len
cumsum = [0]
single_image_seq = image_end - image_start
for length in seq_lens:
cumsum.append(cumsum[-1] + length)
for i in range(N):
prompt_start = cumsum[i]
prompt_end = cumsum[i+1]
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
# repeat image mask to match the single image sequence length
repeat_time = single_image_seq // image_mask.shape[-1]
image_mask = image_mask.repeat(1, 1, repeat_time)
# prompt update with image
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
# image update with prompt
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
# prompt-prompt attention mask, let the prompt tokens not attend to each other
for i in range(N):
for j in range(N):
if i == j:
continue
start_i, end_i = cumsum[i], cumsum[i+1]
start_j, end_j = cumsum[j], cumsum[j+1]
attention_mask[:, start_i:end_i, start_j:end_j] = False
attention_mask = attention_mask.float()
attention_mask[attention_mask == 0] = float('-inf')
attention_mask[attention_mask == 1] = 0
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
return all_prompt_emb, image_rotary_emb, attention_mask
def forward(
self,
latents=None,
timestep=None,
prompt_emb=None,
prompt_emb_mask=None,
height=None,
width=None,
):
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
image = self.img_in(image)
text = self.txt_in(self.txt_norm(prompt_emb))
conditioning = self.time_text_embed(timestep, image.dtype)
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
for block in self.transformer_blocks:
text, image = block(
image=image,
text=text,
temb=conditioning,
image_rotary_emb=image_rotary_emb,
)
image = self.norm_out(image, conditioning)
image = self.proj_out(image)
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
return image

View File

@@ -0,0 +1,128 @@
import torch
class CompressedMLP(torch.nn.Module):
def __init__(self, in_dim, mid_dim, out_dim, bias=False):
super().__init__()
self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)
self.proj_out = torch.nn.Linear(mid_dim, out_dim, bias=bias)
def forward(self, x, residual=None):
x = self.proj_in(x)
if residual is not None: x = x + residual
x = self.proj_out(x)
return x
class ImageEmbeddingToLoraMatrix(torch.nn.Module):
def __init__(self, in_dim, compress_dim, lora_a_dim, lora_b_dim, rank):
super().__init__()
self.proj_a = CompressedMLP(in_dim, compress_dim, lora_a_dim * rank)
self.proj_b = CompressedMLP(in_dim, compress_dim, lora_b_dim * rank)
self.lora_a_dim = lora_a_dim
self.lora_b_dim = lora_b_dim
self.rank = rank
def forward(self, x, residual=None):
lora_a = self.proj_a(x, residual).view(self.rank, self.lora_a_dim)
lora_b = self.proj_b(x, residual).view(self.lora_b_dim, self.rank)
return lora_a, lora_b
class SequencialMLP(torch.nn.Module):
def __init__(self, length, in_dim, mid_dim, out_dim, bias=False):
super().__init__()
self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)
self.proj_out = torch.nn.Linear(length * mid_dim, out_dim, bias=bias)
self.length = length
self.in_dim = in_dim
self.mid_dim = mid_dim
def forward(self, x):
x = x.view(self.length, self.in_dim)
x = self.proj_in(x)
x = x.view(1, self.length * self.mid_dim)
x = self.proj_out(x)
return x
class LoRATrainerBlock(torch.nn.Module):
def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024):
super().__init__()
self.lora_patterns = lora_patterns
self.block_id = block_id
self.layers = []
for name, lora_a_dim, lora_b_dim in self.lora_patterns:
self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank))
self.layers = torch.nn.ModuleList(self.layers)
if use_residual:
self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)
else:
self.proj_residual = None
def forward(self, x, residual=None):
lora = {}
if self.proj_residual is not None: residual = self.proj_residual(residual)
for lora_pattern, layer in zip(self.lora_patterns, self.layers):
name = lora_pattern[0]
lora_a, lora_b = layer(x, residual=residual)
lora[f"transformer_blocks.{self.block_id}.{name}.lora_A.default.weight"] = lora_a
lora[f"transformer_blocks.{self.block_id}.{name}.lora_B.default.weight"] = lora_b
return lora
class QwenImageImage2LoRAModel(torch.nn.Module):
def __init__(self, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):
super().__init__()
self.lora_patterns = [
[
("attn.to_q", 3072, 3072),
("attn.to_k", 3072, 3072),
("attn.to_v", 3072, 3072),
("attn.to_out.0", 3072, 3072),
],
[
("img_mlp.net.2", 3072*4, 3072),
("img_mod.1", 3072, 3072*6),
],
[
("attn.add_q_proj", 3072, 3072),
("attn.add_k_proj", 3072, 3072),
("attn.add_v_proj", 3072, 3072),
("attn.to_add_out", 3072, 3072),
],
[
("txt_mlp.net.2", 3072*4, 3072),
("txt_mod.1", 3072, 3072*6),
],
]
self.num_blocks = num_blocks
self.blocks = []
for lora_patterns in self.lora_patterns:
for block_id in range(self.num_blocks):
self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim))
self.blocks = torch.nn.ModuleList(self.blocks)
self.residual_scale = 0.05
self.use_residual = use_residual
def forward(self, x, residual=None):
if residual is not None:
if self.use_residual:
residual = residual * self.residual_scale
else:
residual = None
lora = {}
for block in self.blocks:
lora.update(block(x, residual))
return lora
def initialize_weights(self):
state_dict = self.state_dict()
for name in state_dict:
if ".proj_a." in name:
state_dict[name] = state_dict[name] * 0.3
elif ".proj_b.proj_out." in name:
state_dict[name] = state_dict[name] * 0
elif ".proj_residual.proj_out." in name:
state_dict[name] = state_dict[name] * 0.3
self.load_state_dict(state_dict)

View File

@@ -0,0 +1,190 @@
import torch
from typing import Optional, Union
class QwenImageTextEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel
config = Qwen2_5_VLConfig(**{
"architectures": [
"Qwen2_5_VLForConditionalGeneration"
],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_act": "silu",
"hidden_size": 3584,
"image_token_id": 151655,
"initializer_range": 0.02,
"intermediate_size": 18944,
"max_position_embeddings": 128000,
"max_window_layers": 28,
"model_type": "qwen2_5_vl",
"num_attention_heads": 28,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"mrope_section": [
16,
24,
24
],
"rope_type": "default",
"type": "default"
},
"rope_theta": 1000000.0,
"sliding_window": 32768,
"text_config": {
"architectures": [
"Qwen2_5_VLForConditionalGeneration"
],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_act": "silu",
"hidden_size": 3584,
"image_token_id": None,
"initializer_range": 0.02,
"intermediate_size": 18944,
"layer_types": [
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention"
],
"max_position_embeddings": 128000,
"max_window_layers": 28,
"model_type": "qwen2_5_vl_text",
"num_attention_heads": 28,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"mrope_section": [
16,
24,
24
],
"rope_type": "default",
"type": "default"
},
"rope_theta": 1000000.0,
"sliding_window": None,
"torch_dtype": "float32",
"use_cache": True,
"use_sliding_window": False,
"video_token_id": None,
"vision_end_token_id": 151653,
"vision_start_token_id": 151652,
"vision_token_id": 151654,
"vocab_size": 152064
},
"tie_word_embeddings": False,
"torch_dtype": "float32",
"transformers_version": "4.54.0",
"use_cache": True,
"use_sliding_window": False,
"video_token_id": 151656,
"vision_config": {
"depth": 32,
"fullatt_block_indexes": [
7,
15,
23,
31
],
"hidden_act": "silu",
"hidden_size": 1280,
"in_channels": 3,
"in_chans": 3,
"initializer_range": 0.02,
"intermediate_size": 3420,
"model_type": "qwen2_5_vl",
"num_heads": 16,
"out_hidden_size": 3584,
"patch_size": 14,
"spatial_merge_size": 2,
"spatial_patch_size": 14,
"temporal_patch_size": 2,
"tokens_per_second": 2,
"torch_dtype": "float32",
"window_size": 112
},
"vision_end_token_id": 151653,
"vision_start_token_id": 151652,
"vision_token_id": 151654,
"vocab_size": 152064
})
self.model = Qwen2_5_VLModel(config)
self.lm_head = torch.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.config = config
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
):
output_attentions = False
output_hidden_states = True
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
return outputs.hidden_states

View File

@@ -0,0 +1,726 @@
import torch
from typing import List, Optional, Tuple, Union
from torch import nn
CACHE_T = 2
class QwenImageCausalConv3d(torch.nn.Conv3d):
r"""
A custom 3D causal convolution layer with feature caching support.
This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
caching for efficient inference.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
# Set up causal padding
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = torch.nn.functional.pad(x, padding)
return super().forward(x)
class QwenImageRMS_norm(nn.Module):
r"""
A custom RMS normalization layer.
Args:
dim (int): The number of dimensions to normalize over.
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
Default is True.
images (bool, optional): Whether the input represents image data. Default is True.
bias (bool, optional): Whether to include a learnable bias term. Default is False.
"""
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return torch.nn.functional.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class QwenImageResidualBlock(nn.Module):
r"""
A custom residual block module.
Args:
in_dim (int): Number of input channels.
out_dim (int): Number of output channels.
dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
"""
def __init__(
self,
in_dim: int,
out_dim: int,
dropout: float = 0.0,
non_linearity: str = "silu",
) -> None:
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.nonlinearity = torch.nn.SiLU()
# layers
self.norm1 = QwenImageRMS_norm(in_dim, images=False)
self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
self.norm2 = QwenImageRMS_norm(out_dim, images=False)
self.dropout = nn.Dropout(dropout)
self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
# Apply shortcut connection
h = self.conv_shortcut(x)
# First normalization and activation
x = self.norm1(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
# Second normalization and activation
x = self.norm2(x)
x = self.nonlinearity(x)
# Dropout
x = self.dropout(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv2(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv2(x)
# Add residual connection
return x + h
class QwenImageAttentionBlock(nn.Module):
r"""
Causal self-attention with a single head.
Args:
dim (int): The number of channels in the input tensor.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = QwenImageRMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
def forward(self, x):
identity = x
batch_size, channels, time, height, width = x.size()
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
x = self.norm(x)
# compute query, key, value
qkv = self.to_qkv(x)
qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
qkv = qkv.permute(0, 1, 3, 2).contiguous()
q, k, v = qkv.chunk(3, dim=-1)
# apply attention
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
# output projection
x = self.proj(x)
# Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
x = x.view(batch_size, time, channels, height, width)
x = x.permute(0, 2, 1, 3, 4)
return x + identity
class QwenImageUpsample(nn.Upsample):
r"""
Perform upsampling while ensuring the output tensor has the same data type as the input.
Args:
x (torch.Tensor): Input tensor to be upsampled.
Returns:
torch.Tensor: Upsampled tensor with the same data type as the input.
"""
def forward(self, x):
return super().forward(x.float()).type_as(x)
class QwenImageResample(nn.Module):
r"""
A custom resampling module for 2D and 3D data.
Args:
dim (int): The number of input/output channels.
mode (str): The resampling mode. Must be one of:
- 'none': No resampling (identity operation).
- 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
- 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
- 'downsample2d': 2D downsampling with zero-padding and convolution.
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
"""
def __init__(self, dim: int, mode: str) -> None:
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
)
self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
# cache last frame of last two chunk
cache_x = torch.cat(
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
)
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = self.resample(x)
x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
class QwenImageMidBlock(nn.Module):
"""
Middle block for WanVAE encoder and decoder.
Args:
dim (int): Number of input/output channels.
dropout (float): Dropout rate.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
super().__init__()
self.dim = dim
# Create the components
resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]
attentions = []
for _ in range(num_layers):
attentions.append(QwenImageAttentionBlock(dim))
resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
# First residual block
x = self.resnets[0](x, feat_cache, feat_idx)
# Process through attention and residual blocks
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
x = attn(x)
x = resnet(x, feat_cache, feat_idx)
return x
class QwenImageEncoder3d(nn.Module):
r"""
A 3D encoder module.
Args:
dim (int): The base number of channels in the first layer.
z_dim (int): The dimensionality of the latent space.
dim_mult (list of int): Multipliers for the number of channels in each block.
num_res_blocks (int): Number of residual blocks in each block.
attn_scales (list of float): Scales at which to apply attention mechanisms.
temperal_downsample (list of bool): Whether to downsample temporally in each block.
dropout (float): Dropout rate for the dropout layers.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
non_linearity: str = "silu",
image_channels=3
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.nonlinearity = torch.nn.SiLU()
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv_in = QwenImageCausalConv3d(image_channels, dims[0], 3, padding=1)
# downsample blocks
self.down_blocks = torch.nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
self.down_blocks.append(QwenImageAttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
scale /= 2.0
# middle blocks
self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1)
# output blocks
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_in(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_in(x)
## downsamples
for layer in self.down_blocks:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
x = self.mid_block(x, feat_cache, feat_idx)
## head
x = self.norm_out(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_out(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_out(x)
return x
class QwenImageUpBlock(nn.Module):
"""
A block that handles upsampling for the WanVAE decoder.
Args:
in_dim (int): Input dimension
out_dim (int): Output dimension
num_res_blocks (int): Number of residual blocks
dropout (float): Dropout rate
upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
non_linearity (str): Type of non-linearity to use
"""
def __init__(
self,
in_dim: int,
out_dim: int,
num_res_blocks: int,
dropout: float = 0.0,
upsample_mode: Optional[str] = None,
non_linearity: str = "silu",
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# Create layers list
resnets = []
# Add residual blocks and attention if needed
current_dim = in_dim
for _ in range(num_res_blocks + 1):
resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity))
current_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add upsampling layer if needed
self.upsamplers = None
if upsample_mode is not None:
self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)])
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
"""
Forward pass through the upsampling block.
Args:
x (torch.Tensor): Input tensor
feat_cache (list, optional): Feature cache for causal convolutions
feat_idx (list, optional): Feature index for cache management
Returns:
torch.Tensor: Output tensor
"""
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache, feat_idx)
else:
x = resnet(x)
if self.upsamplers is not None:
if feat_cache is not None:
x = self.upsamplers[0](x, feat_cache, feat_idx)
else:
x = self.upsamplers[0](x)
return x
class QwenImageDecoder3d(nn.Module):
r"""
A 3D decoder module.
Args:
dim (int): The base number of channels in the first layer.
z_dim (int): The dimensionality of the latent space.
dim_mult (list of int): Multipliers for the number of channels in each block.
num_res_blocks (int): Number of residual blocks in each block.
attn_scales (list of float): Scales at which to apply attention mechanisms.
temperal_upsample (list of bool): Whether to upsample temporally in each block.
dropout (float): Dropout rate for the dropout layers.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
non_linearity: str = "silu",
image_channels=3,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
self.nonlinearity = torch.nn.SiLU()
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1)
# upsample blocks
self.up_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i > 0:
in_dim = in_dim // 2
# Determine if we need upsampling
upsample_mode = None
if i != len(dim_mult) - 1:
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
# Create and add the upsampling block
up_block = QwenImageUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
upsample_mode=upsample_mode,
non_linearity=non_linearity,
)
self.up_blocks.append(up_block)
# Update scale for next iteration
if upsample_mode is not None:
scale *= 2.0
# output blocks
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
self.conv_out = QwenImageCausalConv3d(out_dim, image_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_in(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_in(x)
## middle
x = self.mid_block(x, feat_cache, feat_idx)
## upsamples
for up_block in self.up_blocks:
x = up_block(x, feat_cache, feat_idx)
## head
x = self.norm_out(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_out(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_out(x)
return x
class QwenImageVAE(torch.nn.Module):
def __init__(
self,
base_dim: int = 96,
z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
dropout: float = 0.0,
image_channels: int = 3,
) -> None:
super().__init__()
self.z_dim = z_dim
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
self.encoder = QwenImageEncoder3d(
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, image_channels=image_channels,
)
self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
self.decoder = QwenImageDecoder3d(
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, image_channels=image_channels,
)
mean = [
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
]
std = [
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
]
self.mean = torch.tensor(mean).view(1, 16, 1, 1, 1)
self.std = 1 / torch.tensor(std).view(1, 16, 1, 1, 1)
def encode(self, x, **kwargs):
x = x.unsqueeze(2)
x = self.encoder(x)
x = self.quant_conv(x)
x = x[:, :16]
mean, std = self.mean.to(dtype=x.dtype, device=x.device), self.std.to(dtype=x.dtype, device=x.device)
x = (x - mean) * std
x = x.squeeze(2)
return x
def decode(self, x, **kwargs):
x = x.unsqueeze(2)
mean, std = self.mean.to(dtype=x.dtype, device=x.device), self.std.to(dtype=x.dtype, device=x.device)
x = x / std + mean
x = self.post_quant_conv(x)
x = self.decoder(x)
x = x.squeeze(2)
return x

View File

@@ -1,587 +0,0 @@
import torch
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
from .tiler import TileWorker
class ControlNetConditioningLayer(torch.nn.Module):
def __init__(self, channels = (3, 16, 32, 96, 256, 320)):
super().__init__()
self.blocks = torch.nn.ModuleList([])
self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
self.blocks.append(torch.nn.SiLU())
for i in range(1, len(channels) - 2):
self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1))
self.blocks.append(torch.nn.SiLU())
self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2))
self.blocks.append(torch.nn.SiLU())
self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1))
def forward(self, conditioning):
for block in self.blocks:
conditioning = block(conditioning)
return conditioning
class SDControlNet(torch.nn.Module):
def __init__(self, global_pool=False):
super().__init__()
self.time_proj = Timesteps(320)
self.time_embedding = torch.nn.Sequential(
torch.nn.Linear(320, 1280),
torch.nn.SiLU(),
torch.nn.Linear(1280, 1280)
)
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
self.blocks = torch.nn.ModuleList([
# CrossAttnDownBlock2D
ResnetBlock(320, 320, 1280),
AttentionBlock(8, 40, 320, 1, 768),
PushBlock(),
ResnetBlock(320, 320, 1280),
AttentionBlock(8, 40, 320, 1, 768),
PushBlock(),
DownSampler(320),
PushBlock(),
# CrossAttnDownBlock2D
ResnetBlock(320, 640, 1280),
AttentionBlock(8, 80, 640, 1, 768),
PushBlock(),
ResnetBlock(640, 640, 1280),
AttentionBlock(8, 80, 640, 1, 768),
PushBlock(),
DownSampler(640),
PushBlock(),
# CrossAttnDownBlock2D
ResnetBlock(640, 1280, 1280),
AttentionBlock(8, 160, 1280, 1, 768),
PushBlock(),
ResnetBlock(1280, 1280, 1280),
AttentionBlock(8, 160, 1280, 1, 768),
PushBlock(),
DownSampler(1280),
PushBlock(),
# DownBlock2D
ResnetBlock(1280, 1280, 1280),
PushBlock(),
ResnetBlock(1280, 1280, 1280),
PushBlock(),
# UNetMidBlock2DCrossAttn
ResnetBlock(1280, 1280, 1280),
AttentionBlock(8, 160, 1280, 1, 768),
ResnetBlock(1280, 1280, 1280),
PushBlock()
])
self.controlnet_blocks = torch.nn.ModuleList([
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
])
self.global_pool = global_pool
def forward(
self,
sample, timestep, encoder_hidden_states, conditioning,
tiled=False, tile_size=64, tile_stride=32,
):
# 1. time
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
time_emb = self.time_embedding(time_emb)
time_emb = time_emb.repeat(sample.shape[0], 1)
# 2. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
text_emb = encoder_hidden_states
res_stack = [hidden_states]
# 3. blocks
for i, block in enumerate(self.blocks):
if tiled and not isinstance(block, PushBlock):
_, _, inter_height, _ = hidden_states.shape
resize_scale = inter_height / height
hidden_states = TileWorker().tiled_forward(
lambda x: block(x, time_emb, text_emb, res_stack)[0],
hidden_states,
int(tile_size * resize_scale),
int(tile_stride * resize_scale),
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
else:
hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
# 4. ControlNet blocks
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
# pool
if self.global_pool:
controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
return controlnet_res_stack
def state_dict_converter(self):
return SDControlNetStateDictConverter()
class SDControlNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
# architecture
block_types = [
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
]
# controlnet_rename_dict
controlnet_rename_dict = {
"controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
"controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
"controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
"controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
"controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
"controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
"controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
"controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
"controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
"controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
"controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
"controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
"controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
"controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
"controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
"controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
}
# Rename each parameter
name_list = sorted([name for name in state_dict])
rename_dict = {}
block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
for name in name_list:
names = name.split(".")
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
pass
elif name in controlnet_rename_dict:
names = controlnet_rename_dict[name].split(".")
elif names[0] == "controlnet_down_blocks":
names[0] = "controlnet_blocks"
elif names[0] == "controlnet_mid_block":
names = ["controlnet_blocks", "12", names[-1]]
elif names[0] in ["time_embedding", "add_embedding"]:
if names[0] == "add_embedding":
names[0] = "add_time_embedding"
names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
if names[0] == "mid_block":
names.insert(1, "0")
block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
block_type_with_id = ".".join(names[:4])
if block_type_with_id != last_block_type_with_id[block_type]:
block_id[block_type] += 1
last_block_type_with_id[block_type] = block_type_with_id
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
block_id[block_type] += 1
block_type_with_id = ".".join(names[:4])
names = ["blocks", str(block_id[block_type])] + names[4:]
if "ff" in names:
ff_index = names.index("ff")
component = ".".join(names[ff_index:ff_index+3])
component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
names = names[:ff_index] + [component] + names[ff_index+3:]
if "to_out" in names:
names.pop(names.index("to_out") + 1)
else:
raise ValueError(f"Unknown parameters: {name}")
rename_dict[name] = ".".join(names)
# Convert state_dict
state_dict_ = {}
for name, param in state_dict.items():
if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze()
if rename_dict[name] in [
"controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
"controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
]:
continue
state_dict_[rename_dict[name]] = param
return state_dict_
def from_civitai(self, state_dict):
if "mid_block.resnets.1.time_emb_proj.weight" in state_dict:
# For controlnets in diffusers format
return self.from_diffusers(state_dict)
rename_dict = {
"control_model.time_embed.0.weight": "time_embedding.0.weight",
"control_model.time_embed.0.bias": "time_embedding.0.bias",
"control_model.time_embed.2.weight": "time_embedding.2.weight",
"control_model.time_embed.2.bias": "time_embedding.2.bias",
"control_model.input_blocks.0.0.weight": "conv_in.weight",
"control_model.input_blocks.0.0.bias": "conv_in.bias",
"control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
"control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
"control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
"control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
"control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
"control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
"control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
"control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
"control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
"control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
"control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight",
"control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias",
"control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
"control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight",
"control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias",
"control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight",
"control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias",
"control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight",
"control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias",
"control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight",
"control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias",
"control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight",
"control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias",
"control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight",
"control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias",
"control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight",
"control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias",
"control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight",
"control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight",
"control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias",
"control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight",
"control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias",
"control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight",
"control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias",
"control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight",
"control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias",
"control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight",
"control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias",
"control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight",
"control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias",
"control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight",
"control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias",
"control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight",
"control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias",
"control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight",
"control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias",
"control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight",
"control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight",
"control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias",
"control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight",
"control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias",
"control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight",
"control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias",
"control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight",
"control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias",
"control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight",
"control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias",
"control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight",
"control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias",
"control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight",
"control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias",
"control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight",
"control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight",
"control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias",
"control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight",
"control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias",
"control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight",
"control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias",
"control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight",
"control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias",
"control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight",
"control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias",
"control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight",
"control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias",
"control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight",
"control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias",
"control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight",
"control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias",
"control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight",
"control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias",
"control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight",
"control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight",
"control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias",
"control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight",
"control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias",
"control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight",
"control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias",
"control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight",
"control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias",
"control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight",
"control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias",
"control_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight",
"control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias",
"control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight",
"control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias",
"control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight",
"control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias",
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight",
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias",
"control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight",
"control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias",
"control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight",
"control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias",
"control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight",
"control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias",
"control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight",
"control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias",
"control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
"control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
"control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight",
"control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias",
"control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight",
"control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias",
"control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight",
"control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias",
"control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight",
"control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias",
"control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight",
"control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias",
"control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight",
"control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias",
"control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight",
"control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias",
"control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight",
"control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias",
"control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight",
"control_model.zero_convs.1.0.bias": "controlnet_blocks.0.bias",
"control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight",
"control_model.zero_convs.2.0.bias": "controlnet_blocks.0.bias",
"control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight",
"control_model.zero_convs.3.0.bias": "controlnet_blocks.0.bias",
"control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight",
"control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias",
"control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight",
"control_model.zero_convs.5.0.bias": "controlnet_blocks.4.bias",
"control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight",
"control_model.zero_convs.6.0.bias": "controlnet_blocks.4.bias",
"control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight",
"control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias",
"control_model.zero_convs.8.0.weight": "controlnet_blocks.8.weight",
"control_model.zero_convs.8.0.bias": "controlnet_blocks.7.bias",
"control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight",
"control_model.zero_convs.9.0.bias": "controlnet_blocks.7.bias",
"control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight",
"control_model.zero_convs.10.0.bias": "controlnet_blocks.7.bias",
"control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight",
"control_model.zero_convs.11.0.bias": "controlnet_blocks.7.bias",
"control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight",
"control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias",
"control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight",
"control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias",
"control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight",
"control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias",
"control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight",
"control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias",
"control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight",
"control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias",
"control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight",
"control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias",
"control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight",
"control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias",
"control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight",
"control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias",
"control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight",
"control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias",
"control_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight",
"control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias",
"control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight",
"control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias",
"control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight",
"control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias",
"control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight",
"control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias",
"control_model.middle_block.1.norm.weight": "blocks.29.norm.weight",
"control_model.middle_block.1.norm.bias": "blocks.29.norm.bias",
"control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
"control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
"control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight",
"control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight",
"control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight",
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight",
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias",
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight",
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias",
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight",
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias",
"control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight",
"control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight",
"control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight",
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight",
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias",
"control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight",
"control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias",
"control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight",
"control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias",
"control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight",
"control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias",
"control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
"control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias",
"control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight",
"control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias",
"control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight",
"control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias",
"control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
"control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
"control_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight",
"control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias",
"control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight",
"control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias",
"control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight",
"control_model.middle_block_out.0.bias": "controlnet_blocks.7.bias",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_

View File

@@ -1,56 +0,0 @@
from .svd_image_encoder import SVDImageEncoder
from .sdxl_ipadapter import IpAdapterImageProjModel, IpAdapterModule, SDXLIpAdapterStateDictConverter
from transformers import CLIPImageProcessor
import torch
class IpAdapterCLIPImageEmbedder(SVDImageEncoder):
def __init__(self):
super().__init__()
self.image_processor = CLIPImageProcessor()
def forward(self, image):
pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
return super().forward(pixel_values)
class SDIpAdapter(torch.nn.Module):
def __init__(self):
super().__init__()
shape_list = [(768, 320)] * 2 + [(768, 640)] * 2 + [(768, 1280)] * 5 + [(768, 640)] * 3 + [(768, 320)] * 3 + [(768, 1280)] * 1
self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
self.image_proj = IpAdapterImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
self.set_full_adapter()
def set_full_adapter(self):
block_ids = [1, 4, 9, 12, 17, 20, 40, 43, 46, 50, 53, 56, 60, 63, 66, 29]
self.call_block_id = {(i, 0): j for j, i in enumerate(block_ids)}
def set_less_adapter(self):
# IP-Adapter for SD v1.5 doesn't support this feature.
self.set_full_adapter(self)
def forward(self, hidden_states, scale=1.0):
hidden_states = self.image_proj(hidden_states)
hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
ip_kv_dict = {}
for (block_id, transformer_id) in self.call_block_id:
ipadapter_id = self.call_block_id[(block_id, transformer_id)]
ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
if block_id not in ip_kv_dict:
ip_kv_dict[block_id] = {}
ip_kv_dict[block_id][transformer_id] = {
"ip_k": ip_k,
"ip_v": ip_v,
"scale": scale
}
return ip_kv_dict
def state_dict_converter(self):
return SDIpAdapterStateDictConverter()
class SDIpAdapterStateDictConverter(SDXLIpAdapterStateDictConverter):
def __init__(self):
pass

View File

@@ -1,60 +0,0 @@
import torch
from .sd_unet import SDUNetStateDictConverter, SDUNet
from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder
class SDLoRA:
def __init__(self):
pass
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
special_keys = {
"down.blocks": "down_blocks",
"up.blocks": "up_blocks",
"mid.block": "mid_block",
"proj.in": "proj_in",
"proj.out": "proj_out",
"transformer.blocks": "transformer_blocks",
"to.q": "to_q",
"to.k": "to_k",
"to.v": "to_v",
"to.out": "to_out",
}
state_dict_ = {}
for key in state_dict:
if ".lora_up" not in key:
continue
if not key.startswith(lora_prefix):
continue
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
for special_key in special_keys:
target_name = target_name.replace(special_key, special_keys[special_key])
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
state_dict_unet = unet.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
if len(state_dict_lora) > 0:
for name in state_dict_lora:
state_dict_unet[name] += state_dict_lora[name].to(device=device)
unet.load_state_dict(state_dict_unet)
def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"):
state_dict_text_encoder = text_encoder.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
if len(state_dict_lora) > 0:
for name in state_dict_lora:
state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
text_encoder.load_state_dict(state_dict_text_encoder)

Some files were not shown because too many files have changed in this diff Show More