Compare commits

...

173 Commits

Author SHA1 Message Date
Artiprocher
add6f88324 bugfix 2026-03-03 15:33:42 +08:00
Zhongjie Duan
430b495100 Merge pull request #1321 from mi804/bugfix
fix qwen_text_encoder bug in transformers>=5.2.0
2026-03-03 13:02:45 +08:00
mi804
62ba8a3f2e fix qwen_text_encoder bug in transformers>=5.2.0 2026-03-03 12:44:36 +08:00
Zhongjie Duan
237d178733 Fix LoRA compatibility issues. (#1320) 2026-03-03 11:08:31 +08:00
Zhongjie Duan
b3ef224042 support Anima gradient checkpointing (#1319) 2026-03-02 19:06:55 +08:00
Zhongjie Duan
f43b18ec21 Update docs (#1318)
* update docs
2026-03-02 18:59:13 +08:00
Zhongjie Duan
6d671db5d2 Support Anima (#1317)
* support Anima

Co-authored-by: mi804 <1576993271@qq.com>
2026-03-02 18:49:02 +08:00
Zhongjie Duan
880231b4be Merge pull request #1315 from modelscope/docs2.0
update ltx-2 docs
2026-03-02 11:02:20 +08:00
mi804
b3f6c3275f update ltx-2 2026-03-02 10:58:02 +08:00
Zhongjie Duan
29cd5c7612 Merge pull request #1275 from Mr-Neutr0n/fix-dit-none-check
Fix AttributeError when pipe.dit is None during split training
2026-03-02 10:25:11 +08:00
Zhongjie Duan
ff4be1c7c7 Merge pull request #1293 from Mr-Neutr0n/fix/trajectory-loss-div-by-zero
fix: prevent division by zero in TrajectoryImitationLoss at final denoising step
2026-03-02 10:21:39 +08:00
Zhongjie Duan
6b0fb1601f Merge pull request #1296 from Explorer-Dong/fix/wan_vae
fix: WanVAE2.2 encode and decode error
2026-03-02 10:19:36 +08:00
Zhongjie Duan
4b400c07eb Merge pull request #1297 from Feng0w0/npu_fused
[doc][NPU]Documentation on modifications, NPU environment installation, and additional parameter
2026-03-02 10:16:01 +08:00
Zhongjie Duan
6a6ae6d791 Merge pull request #1312 from mi804/ltx2-iclora
Ltx2 iclora
2026-02-28 12:45:16 +08:00
mi804
1a380a6b62 minor fix 2026-02-28 11:09:10 +08:00
mi804
5ca74923e8 add readme 2026-02-28 10:56:08 +08:00
mi804
8b9a094c1b ltx iclora train 2026-02-27 18:43:53 +08:00
mi804
5996c2b068 support inference 2026-02-27 16:48:16 +08:00
Zhongjie Duan
8fc7e005a6 Merge pull request #1309 from mi804/ltx2-train
support ltx2 gradient_checkpointing
2026-02-26 19:31:04 +08:00
mi804
a18966c300 support ltx2 gradient_checkpointing 2026-02-26 19:19:59 +08:00
Zhongjie Duan
a87910bc65 Merge pull request #1307 from mi804/ltx2-train
Support LTX-2 training.
2026-02-26 11:39:09 +08:00
mi804
f48662e863 update docs 2026-02-26 11:10:00 +08:00
mi804
8d8bfc7f54 minor fix 2026-02-25 19:04:10 +08:00
mi804
8e15dcd289 support ltx2 train -2 2026-02-25 18:06:02 +08:00
mi804
586ac9d8a6 support ltx-2 training 2026-02-25 17:19:57 +08:00
Zhongjie Duan
288bbc7128 Merge pull request #1299 from modelscope/firered
support FireRed
2026-02-15 14:18:13 +08:00
Zhongjie Duan
5002ac74dc Update Qwen-Image.md 2026-02-15 14:15:44 +08:00
Zhongjie Duan
863a6ba597 Merge branch 'main' into firered 2026-02-15 14:12:44 +08:00
Artiprocher
b08bc1470d support firered 2026-02-15 14:02:50 +08:00
feng0w0
96143aa26b Merge branch 'npu_fused' of https://github.com/Feng0w0/DiffSynth-Studio into npu_fused 2026-02-13 10:06:39 +08:00
feng0w0
71cea4371c [doc][NPU]Documentation on modifications, NPU environment installation, and additional parameter 2026-02-13 09:58:27 +08:00
Mr_Dwj
fc11fd4297 chore: remove invalid comment code 2026-02-13 09:38:14 +08:00
Mr_Dwj
bd3c5822a1 fix: WanVAE2.2 decode error 2026-02-13 01:13:08 +08:00
Mr_Dwj
96fb0f3afe fix: unpack Resample38 output 2026-02-12 23:51:56 +08:00
Mr-Neutr0n
b68663426f fix: preserve sign of denominator in clamp to avoid inverting gradient direction
The previous .clamp(min=1e-6) on (sigma_ - sigma) flips the sign when
the denominator is negative (which is the typical case since sigmas
decrease monotonically). This would invert the target and cause
training divergence.

Use torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6) instead,
which prevents division by zero while preserving the correct sign.
2026-02-11 21:04:55 +05:30
Mr-Neutr0n
0e6976a0ae fix: prevent division by zero in trajectory imitation loss at last step 2026-02-11 19:51:25 +05:30
Hong Zhang
94b57e9677 Fix readthedocs rendering (#1290)
* test latex

* test latex

* fix conf
2026-02-11 11:32:27 +08:00
Hong Zhang
3fb037d33a Correct hyperlinks for docs 2026-02-10 20:59:47 +08:00
Hong Zhang
b3b63fef3e Add readthedocs for diffsynth-studio
* add conf docs

* add conf docs

* add index

* add index

* update ref

* test root

* add en

* test relative

* redirect relative

* add document

* test_document

* test_document
2026-02-10 19:51:04 +08:00
Zhongjie Duan
f6d85f3c2e Merge pull request #1282 from mi804/ltx-2
add inference script for ltx-2 lora
2026-02-10 15:13:06 +08:00
mi804
2f22e598b7 fix load lora 2026-02-10 15:06:04 +08:00
Hong Zhang
888caf8b88 Update README_zh.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-10 14:48:19 +08:00
mi804
b6e39c97af add inference script for ltx-2 lora 2026-02-10 14:32:30 +08:00
Zhongjie Duan
02124c4034 Merge pull request #1280 from modelscope/issue-fix
fix mix-precision issues in low-version torch
2026-02-10 11:14:12 +08:00
Artiprocher
fddc98ff16 fix mix-precision issues in low-version torch 2026-02-10 11:12:50 +08:00
Zhongjie Duan
0dfcd25cf3 Merge pull request #1278 from modelscope/issue-fix
update lora loading in docs
2026-02-10 10:50:18 +08:00
Artiprocher
ff10fde47f update lora loading in docs 2026-02-10 10:48:44 +08:00
Zhongjie Duan
dc94614c80 Merge pull request #1256 from Feng0w0/npu_fused
[model][NPU]:Add NPU fusion operator patch to Zimage model to improve performance
2026-02-09 20:08:44 +08:00
feng0w0
e56a4d5730 [model][NPU]:Add NPU fusion operator patch to Zimage model to improve performance 2026-02-09 12:31:34 +08:00
feng0w0
3f8468893a [model][NPU]:Add NPU fusion operator patch to Zimage model to improve performance 2026-02-09 09:51:06 +08:00
Mr-Neutr0n
6383ec358c Fix AttributeError when pipe.dit is None
When using split training with 'sft:data_process' task, the DiT model
is not loaded but the attribute 'dit' exists with value None. The
existing hasattr check returns True but then accessing siglip_embedder
fails.

Add an explicit None check before accessing pipe.dit.siglip_embedder.

Fixes #1246
2026-02-07 05:23:11 +05:30
Zhongjie Duan
1b47e1dc22 Merge pull request #1272 from modelscope/zero3-fix
Support DeepSpeed ZeRO 3
2026-02-06 16:33:12 +08:00
Artiprocher
b0bf78e915 refine code & doc 2026-02-06 16:27:23 +08:00
Zhongjie Duan
abdf66d09e Merge pull request #1265 from lzws/main
fix wanS2V bug and update readme
2026-02-06 10:22:48 +08:00
lzws
27b1fe240b add examples 2026-02-05 17:17:10 +08:00
lzws
1635897516 update readme 2026-02-05 16:56:39 +08:00
lzws
8d172127cd fix wans2v bug and update readme 2026-02-05 16:52:38 +08:00
feng0w0
fccb1ecdd7 Initialize qwen-image on CPU 2026-02-05 11:54:36 +08:00
Zhongjie Duan
c0f7e1db7c Merge pull request #1261 from modelscope/examples-update
update examples
2026-02-05 11:11:35 +08:00
Artiprocher
53890bafa4 update examples 2026-02-05 11:10:55 +08:00
feng0w0
6886f7ba35 fix wan decoder bug 2026-02-05 10:31:41 +08:00
Zhongjie Duan
afd48cd706 Merge pull request #1259 from mi804/multi_controlnet
add example for multiple controlnet
2026-02-04 17:04:11 +08:00
mi804
24b68c2392 add example for multiple controlnet 2026-02-04 16:52:39 +08:00
Zhongjie Duan
280ff7cca6 Merge pull request #1229 from Feng0w0/wan_rope
[bugfix][NPU]:Fix bug that correctly obtains device type
2026-02-04 13:26:00 +08:00
Zhongjie Duan
b4b62e2f7c Merge pull request #1221 from Feng0w0/usp_npu
[NPU]:Support USP feature in NPU
2026-02-04 13:25:24 +08:00
feng0w0
051b957adb [model][NPU] Add NPU fusion operator patch to Zimage model to improve performance 2026-02-03 19:50:21 +08:00
feng0w0
ca9b5e64ea [feature]:Add adaptation of all models to zero3 2026-02-03 15:44:53 +08:00
Zhongjie Duan
6d1be405b9 Merge pull request #1242 from mi804/ltx-2
LTX-2
2026-02-03 13:07:41 +08:00
Zhongjie Duan
25c3a3d3e2 Merge branch 'main' into ltx-2 2026-02-03 13:06:44 +08:00
mi804
49bc84f78e add comment for tuple noise_pred 2026-02-03 10:43:25 +08:00
mi804
25a9e75030 final fix for ltx-2 2026-02-03 10:39:35 +08:00
mi804
2a7ac73eb5 minor fix 2026-02-02 20:07:08 +08:00
mi804
f4f991d409 support ltx-2 t2v and i2v 2026-02-02 19:53:07 +08:00
Zhongjie Duan
a781138413 Merge pull request #1245 from modelscope/docs-update
update docs
2026-02-02 17:00:04 +08:00
Artiprocher
91a5623976 update docs 2026-02-02 16:52:12 +08:00
Zhongjie Duan
28cd355aba Merge pull request #1232 from huarzone/main
fix wan i2v/ti2v train bug
2026-02-02 15:26:01 +08:00
Zhongjie Duan
005389fca7 Merge pull request #1244 from modelscope/qwen-image-edit-lightning
Qwen image edit lightning
2026-02-02 15:20:11 +08:00
Artiprocher
a6282056eb fix typo 2026-02-02 15:19:19 +08:00
Zhongjie Duan
21a6eb8e2f Merge pull request #1243 from modelscope/research_tutorial_1
add research tutorial sec 1
2026-02-02 14:29:39 +08:00
Artiprocher
98ab238340 add research tutorial sec 1 2026-02-02 14:28:26 +08:00
feng0w0
2070bbd925 [feature]:Add adaptation of all models to zero3 2026-01-31 16:50:18 +08:00
mi804
1c8a0f8317 refactor patchify 2026-01-31 13:55:52 +08:00
mi804
9f07d65ebb support ltx2 distilled pipeline 2026-01-30 17:40:30 +08:00
lzws
5f1d5adfce qwen-image-edit-2511-lightning 2026-01-30 17:26:26 +08:00
mi804
4f23caa55f support ltx2 two stage pipeline & vram 2026-01-30 16:55:40 +08:00
Zhongjie Duan
b4f6a4de6c Merge pull request #1240 from modelscope/loader-update
Loader update
2026-01-30 13:51:17 +08:00
Artiprocher
53fe42af1b update version 2026-01-30 13:49:27 +08:00
Artiprocher
ee9a3b4405 support loading models from state dict 2026-01-30 13:47:36 +08:00
mi804
b1a2782ad7 support ltx2 one-stage pipeline 2026-01-29 16:30:15 +08:00
mi804
8d303b47e9 add audio_vae, audio_vocoder, text_encoder, connector and upsampler for ltx2 2026-01-28 16:09:22 +08:00
mi804
00da4b6c4f add video_vae and dit for ltx-2 2026-01-27 19:34:09 +08:00
Zhongjie Duan
22695e9be0 Merge pull request #1233 from modelscope/z-image-release
Z-Image and Z-Image-i2L
2026-01-27 18:41:28 +08:00
feng0w0
3140199c96 [feature]:Add adaptation of all models to zero3 2026-01-27 15:33:42 +08:00
Artiprocher
98290190ec update z-image-i2L demo 2026-01-27 13:42:48 +08:00
Artiprocher
3f4de2cc7f update z-image-i2L examples 2026-01-27 12:16:48 +08:00
Kared
8d0df403ca fix wan i2v train bug 2026-01-27 03:55:36 +00:00
feng0w0
4e9db263b0 [feature]:Add adaptation of all models to zero3 2026-01-27 11:24:43 +08:00
Artiprocher
d12bf71bcc support z-image and z-image-i2L 2026-01-27 10:56:15 +08:00
feng0w0
35e0776022 [bugfix][NPU]:Fix bug that correctly obtains device type 2026-01-23 10:45:03 +08:00
Zhongjie Duan
ffb7a138f7 Merge pull request #1228 from modelscope/klein-bugfix
change klein image resize to crop
2026-01-22 10:34:17 +08:00
Artiprocher
548304667f change klein image resize to crop 2026-01-22 10:33:29 +08:00
Zhongjie Duan
273143136c Merge pull request #1227 from modelscope/modelscope-service-patch
update to 2.0.3
2026-01-21 20:23:13 +08:00
Artiprocher
030ebe649a update to 2.0.3 2026-01-21 20:22:43 +08:00
Zhongjie Duan
90921d2293 Merge pull request #1226 from modelscope/klein-train-fix
improve flux2 training performance
2026-01-21 15:44:52 +08:00
Artiprocher
b61131c693 improve flux2 training performance 2026-01-21 15:44:15 +08:00
Zhongjie Duan
37fbb3248a Merge pull request #1222 from modelscope/trainer-update
support auto detact lora target modules
2026-01-21 11:06:19 +08:00
Artiprocher
d13f533f42 support auto detact lora target modules 2026-01-21 11:05:05 +08:00
feng0w0
b3cc652dea [NPU]:Support USP feature in NPU 2026-01-21 10:38:27 +08:00
feng0w0
d879d66c62 [NPU]:Support USP feature in NPU 2026-01-21 10:34:09 +08:00
feng0w0
848bfd6993 [NPU]:Support USP feature in NPU 2026-01-21 10:25:31 +08:00
feng0w0
269da09f6e Merge branch 'main' of https://github.com/modelscope/DiffSynth-Studio into usp_npu 2026-01-21 10:00:08 +08:00
feng0w0
e30514a00c Merge branch 'main' of https://github.com/Feng0w0/DiffSynth-Studio into usp_npu 2026-01-21 09:59:18 +08:00
Zhongjie Duan
3743b1307c Merge pull request #1219 from modelscope/klein-edit
support klein edit
2026-01-20 12:59:12 +08:00
Artiprocher
a835df984c support klein edit 2026-01-20 12:58:18 +08:00
Zhongjie Duan
3e4b47e424 Merge pull request #1207 from Feng0w0/cuda_replace
[NPU]:Replace 'cuda' in the project with abstract interfaces
2026-01-20 10:13:04 +08:00
Zhongjie Duan
dd8d902624 Merge branch 'main' into cuda_replace 2026-01-20 10:12:31 +08:00
Zhongjie Duan
a8b340c098 Merge pull request #1191 from Feng0w0/wan_rope
[model][NPU]:Wan model rope use torch.complex64 in NPU
2026-01-20 10:05:22 +08:00
Zhongjie Duan
88497b5c13 Merge pull request #1217 from modelscope/klein-update
support klein base models
2026-01-19 21:14:47 +08:00
Artiprocher
1e90c72d94 support klein base models 2026-01-19 21:11:58 +08:00
Zhongjie Duan
3dd82a738e Merge pull request #1215 from lzws/main
updata learning rate in wan-vace training scripts
2026-01-19 17:48:42 +08:00
Artiprocher
8ad2d9884b update lr in wan-vace training scripts 2026-01-19 17:43:07 +08:00
Artiprocher
70f531b724 update wan-vace training scripts 2026-01-19 17:37:30 +08:00
Zhongjie Duan
37c2868b61 Merge pull request #1214 from modelscope/klein
Support FLUX.2-klein
2026-01-19 17:36:39 +08:00
Artiprocher
a18e6233b5 updata wan-vace training scripts 2026-01-19 17:35:08 +08:00
Artiprocher
2336d5f6b3 update doc 2026-01-19 17:27:32 +08:00
Artiprocher
b6ccb362b9 support flux.2 klein 2026-01-19 16:56:14 +08:00
Artiprocher
ae52d93694 support klein 4b models 2026-01-16 13:09:41 +08:00
feng0w0
ad91d41601 [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-16 10:28:24 +08:00
feng0w0
dce77ec4d1 [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-15 20:35:41 +08:00
feng0w0
5c0b07d939 [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-15 20:34:52 +08:00
feng0w0
19e429d889 Merge remote-tracking branch 'origin/cuda_replace' into cuda_replace 2026-01-15 20:33:21 +08:00
feng0w0
209a350c0f [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-15 20:33:01 +08:00
feng0w0
a3c2744a43 [NPU]:Replace 'cuda' in the project with abstract interfaces 2026-01-15 20:04:54 +08:00
Zhongjie Duan
55e8346da3 Blog link (#1202)
* update README
2026-01-15 12:31:55 +08:00
Zhongjie Duan
b7979b2633 Merge pull request #1200 from modelscope/flux-compatibility-fix
fix flux compatibility issues
2026-01-14 20:50:18 +08:00
Artiprocher
c90aaa2798 fix flux compatibility issues 2026-01-14 20:49:36 +08:00
Zhongjie Duan
0c617d5d9e Merge pull request #1194 from lzws/main
wan usp bug fix
2026-01-14 16:34:06 +08:00
lzws
fd87b72754 wan usp bug fix 2026-01-14 16:33:02 +08:00
Zhongjie Duan
db75508ba0 Merge pull request #1199 from modelscope/z-image-bugfix
fix RMSNorm precision
2026-01-14 16:32:33 +08:00
Artiprocher
acba342a63 fix RMSNorm precision 2026-01-14 16:29:43 +08:00
feng0w0
d16877e695 [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-13 11:17:51 +08:00
lzws
e99cdcf3b8 wan usp bug fix 2026-01-12 22:08:48 +08:00
Zhongjie Duan
a236a17f17 Merge pull request #1193 from modelscope/qwen-image-layered-control
support qwen-image-layered-control
2026-01-12 17:24:06 +08:00
Artiprocher
03e530dc39 support qwen-image-layered-control 2026-01-12 17:20:01 +08:00
feng0w0
6be244233a [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-12 11:34:41 +08:00
feng0w0
544c391936 [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-12 11:24:11 +08:00
Feng
f4d06ce3fc Merge branch 'modelscope:main' into wan_rope 2026-01-12 11:21:09 +08:00
Zhongjie Duan
ffedb9eb52 Merge pull request #1187 from jiaqixuac/patch-1
Update package inclusion pattern in pyproject.toml
2026-01-12 10:12:20 +08:00
Zhongjie Duan
381067515c Merge pull request #1176 from Feng0w0/z-image-rope
[model][NPU]: Z-image model support NPU
2026-01-12 10:11:22 +08:00
Zhongjie Duan
00f2d1aa5d Merge pull request #1169 from Feng0w0/sample_add
Docs:Supplement NPU training script samples and documentation instruction
2026-01-12 10:08:38 +08:00
Zhongjie Duan
8cc3bece6d Merge pull request #1167 from Feng0w0/install_env
Docs:Supplement NPU environment installation document
2026-01-12 10:07:30 +08:00
Jiaqi Xu
f4bf592064 Update pyproject.toml
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-10 09:32:35 +08:00
Jiaqi Xu
3235393fb5 Update package inclusion pattern in pyproject.toml
Update to install all the sub-packages inside diffsynth. Otherwise, the installed packages only contain __init__.py
2026-01-10 09:28:45 +08:00
feng0w0
3b662da31e [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-09 18:11:40 +08:00
feng0w0
19ce3048c1 [model][NPU]:Wan model rope use torch.complex64 in NPU 2026-01-09 18:06:41 +08:00
Zhongjie Duan
de0aa946f7 Merge pull request #1184 from modelscope/z-image-omni-base-dev
update package version
2026-01-08 17:27:33 +08:00
Zhongjie Duan
a13ecfc46b Merge pull request #1183 from modelscope/z-image-omni-base-dev
fix unused parameters in z-image-omni-base
2026-01-08 17:03:20 +08:00
Zhongjie Duan
0efab85674 Support Z-Image-Omni-Base and its related models
Support Z-Image-Omni-Base and its related models.
2026-01-08 13:43:59 +08:00
feng0w0
c1c9a4853b [model][NPU]:Z-image model support NPU 2026-01-07 11:42:19 +08:00
feng0w0
3ee5f53a36 [model][NPU]:Z-image model support NPU 2026-01-07 11:31:22 +08:00
Zhongjie Duan
a6884f6b3a Merge pull request #1171 from YZBPXX/main
Fix issue where LoRa loads on a device different from Dit
2026-01-05 16:39:02 +08:00
Zhongjie Duan
b078666640 Merge pull request #1173 from modelscope/flux-compatibility-patch
flux compatibility patch
2026-01-05 16:20:25 +08:00
Artiprocher
7604ca1e52 flux compatibility patch 2026-01-05 16:04:20 +08:00
feng0w0
62c3d406d9 Docs:Supplement NPU training script samples and documentation instruction 2026-01-05 15:42:55 +08:00
feng0w0
86829120c2 Docs:Supplement NPU training script samples and documentation instruction 2026-01-05 09:59:11 +08:00
yaozhengbing
60ac96525b Fix issue where LoRa loads on a device different from Dit 2025-12-31 21:31:01 +08:00
feng0w0
07b1f5702f Docs:Supplement NPU training script samples and documentation instruction 2025-12-31 10:01:21 +08:00
feng0w0
507e7e5d36 Docs:Supplement NPU training script samples and documentation instruction 2025-12-30 19:58:47 +08:00
feng0w0
9cc1697d4d Docs:Supplement NPU environment installation document 2025-12-30 15:57:13 +08:00
feng0w0
c758769a02 训练快速上手 2025-12-29 09:25:46 +08:00
feng0w0
a5935e973a 训练快速上手 2025-12-29 09:23:59 +08:00
feng0w0
9834d72e4d 文档环境安装上手 2025-12-27 16:11:27 +08:00
feng0w0
01234e59c0 文档环境安装上手 2025-12-27 15:01:10 +08:00
263 changed files with 19155 additions and 932 deletions

View File

@@ -22,7 +22,7 @@ jobs:
- name: Install wheel
run: pip install wheel==0.44.0 && pip install -r requirements.txt
- name: Build DiffSynth
run: python setup.py sdist bdist_wheel
run: python -m build
- name: Publish package to PyPI
run: |
pip install twine

262
README.md
View File

@@ -12,6 +12,8 @@
## Introduction
> DiffSynth-Studio Documentation: [中文版](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/)、[English version](https://diffsynth-studio-doc.readthedocs.io/en/latest/)
Welcome to the magical world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by the [ModelScope Community](https://www.modelscope.cn/). We hope to foster technological innovation through framework construction, aggregate the power of the open-source community, and explore the boundaries of generative model technology!
DiffSynth currently includes two open-source projects:
@@ -23,8 +25,6 @@ DiffSynth currently includes two open-source projects:
* ModelScope AIGC Zone (for Chinese users): https://modelscope.cn/aigc/home
* ModelScope Civision (for global users): https://modelscope.ai/civision/home
> DiffSynth-Studio Documentation: [中文版](/docs/zh/README.md)、[English version](/docs/en/README.md)
We believe that a well-developed open-source code framework can lower the threshold for technical exploration. We have achieved many [interesting technologies](#innovative-achievements) based on this codebase. Perhaps you also have many wild ideas, and with DiffSynth-Studio, you can quickly realize these ideas. For this reason, we have prepared detailed documentation for developers. We hope that through these documents, developers can understand the principles of Diffusion models, and we look forward to expanding the boundaries of technology together with you.
## Update History
@@ -33,6 +33,22 @@ We believe that a well-developed open-source code framework can lower the thresh
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
- **March 2, 2026** Added support for [Anima](https://modelscope.cn/models/circlestone-labs/Anima). For details, please refer to the [documentation](docs/en/Model_Details/Anima.md). This is an interesting anime-style image generation model. We look forward to its future updates.
- **February 26, 2026** Added full and lora training support for the LTX-2 audio-video generation model. See the [documentation](/docs/en/Model_Details/LTX-2.md) for details.
- **February 10, 2026** Added inference support for the LTX-2 audio-video generation model. See the [documentation](/docs/en/Model_Details/LTX-2.md) for details. Support for model training will be implemented in the future.
- **February 2, 2026** The first document of the Research Tutorial series is now available, guiding you through training a small 0.1B text-to-image model from scratch. For details, see the [documentation](/docs/en/Research_Tutorial/train_from_scratch.md) and [model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). We hope DiffSynth-Studio can evolve into a more powerful training framework for Diffusion models.
- **January 27, 2026**: [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) is released, and our [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) model is released concurrently. You can use it in [ModelScope Studios](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L). For details, see the [documentation](/docs/zh/Model_Details/Z-Image.md).
- **January 19, 2026**: Added support for [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.
- **January 12, 2026**: We trained and open-sourced a text-guided image layer separation model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a textual description, the model isolates the image layer corresponding to the described content. For more details, please refer to our blog post ([Chinese version](https://modelscope.cn/learn/4938), [English version](https://huggingface.co/blog/kelseye/qwen-image-layered-control)).
- **December 24, 2025**: Based on Qwen-Image-Edit-2511, we trained an In-Context Editing LoRA model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)). This model takes three images as input (Image A, Image B, and Image C), and automatically analyzes the transformation from Image A to Image B, then applies the same transformation to Image C to generate Image D. For more details, please refer to our blog post ([Chinese version](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g), [English version](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)).
- **December 9, 2025** We release a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in terms of generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research. For more details, please refer to our [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l).
- **December 4, 2025** DiffSynth-Studio 2.0 released! Many new features online
@@ -263,9 +279,14 @@ image.save("image.jpg")
Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image/)
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
</details>
@@ -315,9 +336,67 @@ image.save("image.jpg")
Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)
| Model ID | Inference | Low-VRAM Inference | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
</details>
#### Anima: [/docs/en/Model_Details/Anima.md](/docs/en/Model_Details/Anima.md)
<details>
<summary>Quick Start</summary>
Run the following code to quickly load the [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 8GB VRAM.
```python
from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": "disk",
"onload_device": "disk",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AnimaImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config),
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config),
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
image = pipe(prompt, seed=0, num_inference_steps=50)
image.save("image.jpg")
```
</details>
<details>
<summary>Examples</summary>
Example code for Anima is located at: [/examples/anima/](/examples/anima/)
| Model ID | Inference | Low VRAM Inference | Full Training | Validation after Full Training | LoRA Training | Validation after LoRA Training |
|-|-|-|-|-|-|-|
|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)|
</details>
@@ -400,7 +479,10 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[FireRedTeam/FireRed-Image-Edit-1.0](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.0)|[code](/examples/qwen_image/model_inference/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.0.py)|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
@@ -511,6 +593,132 @@ Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/)
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
#### LTX-2: [/docs/en/Model_Details/LTX-2.md](/docs/en/Model_Details/LTX-2.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8GB of VRAM.
```python
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
vram_config = {
"offload_dtype": torch.float8_e5m2,
"offload_device": "cpu",
"onload_dtype": torch.float8_e5m2,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e5m2,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
"""
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
and avoid redundant memory usage when users only want to use part of the model.
"""
# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading
pipe = LTX2AudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2"
# pipe = LTX2AudioVideoPipeline.from_pretrained(
# torch_dtype=torch.bfloat16,
# device="cuda",
# model_configs=[
# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
# ],
# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
# )
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
negative_prompt = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
height, width, num_frames = 512 * 2, 768 * 2, 121
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=43,
height=height,
width=width,
num_frames=num_frames,
tiled=True,
use_two_stage_pipeline=True,
)
write_video_audio_ltx2(
video=video,
audio=audio,
output_path='ltx2_twostage.mp4',
fps=24,
audio_sample_rate=24000,
)
```
</details>
<details>
<summary>Examples</summary>
Example code for LTX-2 is available at: [/examples/ltx2/](/examples/ltx2/)
| Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|
|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-|
</details>
#### Wan: [/docs/en/Model_Details/Wan.md](/docs/en/Model_Details/Wan.md)
<details>
@@ -650,6 +858,37 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.
<details>
<summary>Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation</summary>
- Paper: [Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation
](https://arxiv.org/abs/2602.03208)
- Sample Code: coming soon
|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES|
|-|-|-|-|
|![Image](https://github.com/user-attachments/assets/5be15dc6-2805-4822-b04c-2573fc0f45f0)|![Image](https://github.com/user-attachments/assets/e71b8c20-1629-41d9-b0ff-185805c1da4e)|![Image](https://github.com/user-attachments/assets/7a73c968-133a-4545-9aa2-205533861cd4)|![Image](https://github.com/user-attachments/assets/c8390b22-14fe-48a0-a6e6-d6556d31235e)|
</details>
<details>
<summary>VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers</summary>
- Paper: [VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers
](https://arxiv.org/abs/2602.03210)
- Sample code: [/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)
|Example 1|Example 2|Query|Output|
|-|-|-|-|
|![Image](https://github.com/user-attachments/assets/380d2670-47bf-41cd-b5c9-37110cc4a943)|![Image](https://github.com/user-attachments/assets/7ceaf345-0992-46e6-b38f-394c2065b165)|![Image](https://github.com/user-attachments/assets/f7c26c21-6894-4d9e-b570-f1d44ca7c1de)|![Image](https://github.com/user-attachments/assets/c2bebe3b-5984-41ba-94bf-9509f6a8a990)|
</details>
<details>
<summary>AttriCtrl: Attribute Intensity Control for Image Generation Models</summary>
@@ -660,7 +899,7 @@ DiffSynth-Studio is not just an engineered model framework, but also an incubato
|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
|-|-|-|-|-|
|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.5.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.7.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.9.jpg)|
|![Image](https://github.com/user-attachments/assets/e74b32a5-5b2e-4c87-9df8-487c0f8366b7)|![Image](https://github.com/user-attachments/assets/bfe8bec2-9e55-493d-9a26-7e9cce28e03d)|![Image](https://github.com/user-attachments/assets/b099dfe3-ff1f-4b96-894c-d48bbe92db7a)|![Image](https://github.com/user-attachments/assets/0a6b2982-deab-4b0d-91ad-888782de01c9)|![Image](https://github.com/user-attachments/assets/fcecb755-7d03-4020-b83a-13ad2b38705c)|
</details>
@@ -675,10 +914,10 @@ DiffSynth-Studio is not just an engineered model framework, but also an incubato
||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
|-|-|-|-|-|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_0.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_3_3.jpg)|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![Image](https://github.com/user-attachments/assets/01c54d5a-4f00-4c2e-982a-4ec0a4c6a6e3)|![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/43720a9f-aa27-4918-947d-545389375d46)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/041a3f9a-c7b4-4311-8582-cb71a7226d80)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|![Image](https://github.com/user-attachments/assets/a640fd54-3192-49a0-9281-b43d9ba64f09)|
</details>
@@ -769,4 +1008,3 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
</details>

View File

@@ -12,6 +12,8 @@
## 简介
> DiffSynth-Studio 文档:[中文版](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/)、[English version](https://diffsynth-studio-doc.readthedocs.io/en/latest/)
欢迎来到 Diffusion 模型的魔法世界DiffSynth-Studio 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望以框架建设孵化技术创新,凝聚开源社区的力量,探索生成式模型技术的边界!
DiffSynth 目前包括两个开源项目:
@@ -23,8 +25,6 @@ DiffSynth 目前包括两个开源项目:
* 魔搭社区 AIGC 专区 (面向中国用户): https://modelscope.cn/aigc/home
* ModelScope Civision (for global users): https://modelscope.ai/civision/home
> DiffSynth-Studio 文档:[中文版](/docs/zh/README.md)、[English version](/docs/en/README.md)
我们相信,一个完善的开源代码框架能够降低技术探索的门槛,我们基于这个代码库搞出了不少[有意思的技术](#创新成果)。或许你也有许多天马行空的构想,借助 DiffSynth-Studio你可以快速实现这些想法。为此我们为开发者准备了详细的文档我们希望通过这些文档帮助开发者理解 Diffusion 模型的原理,更期待与你一同拓展技术的边界。
## 更新历史
@@ -33,6 +33,22 @@ DiffSynth 目前包括两个开源项目:
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责因此新功能的开发进展会比较缓慢issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
- **2026年3月2日** 新增对[Anima](https://modelscope.cn/models/circlestone-labs/Anima)的支持,详见[文档](docs/zh/Model_Details/Anima.md)。这是一个有趣的动漫风格图像生成模型,我们期待其后续的模型更新。
- **2026年2月26日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型全量微调与LoRA训练支持详见[文档](docs/zh/Model_Details/LTX-2.md)。
- **2026年2月10日** 新增对[LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2)音视频生成模型的推理支持,详见[文档](docs/zh/Model_Details/LTX-2.md),后续将推进模型训练的支持。
- **2026年2月2日** Research Tutorial 的第一篇文档上线,带你从零开始训练一个 0.1B 的小型文生图模型,详见[文档](/docs/zh/Research_Tutorial/train_from_scratch.md)、[模型](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel),我们希望 DiffSynth-Studio 能够成为一个更强大的 Diffusion 模型训练框架。
- **2026年1月27日** [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) 发布,我们的 [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) 模型同步发布,在[魔搭创空间](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L)可直接体验,详见[文档](/docs/zh/Model_Details/Z-Image.md)。
- **2026年1月19日** 新增对 [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 和 [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/FLUX2.md)和[示例代码](/examples/flux2/)现已可用。
- **2026年1月12日** 我们训练并开源了一个文本引导的图层拆分模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)),这一模型输入一张图与一段文本描述,模型会将图像中与文本描述相关的图层拆分出来。更多细节请阅读我们的 blog[中文版](https://modelscope.cn/learn/4938)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-layered-control))。
- **2025年12月24日** 我们基于 Qwen-Image-Edit-2511 训练了一个 In-Context Editing LoRA 模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)这个模型可以输入三张图图A、图B、图C模型会自行分析图A到图B的变化并将这样的变化应用到图C生成图D。更多细节请阅读我们的 blog[中文版](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora))。
- **2025年12月9日** 我们基于 DiffSynth-Studio 2.0 训练了一个疯狂的模型:[Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)Image to LoRA。这一模型以图像为输入以 LoRA 为输出。尽管这个版本的模型在泛化能力、细节保持能力等方面还有很大改进空间,我们将这些模型开源,以启发更多创新性的研究工作。更多细节,请参考我们的 [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l)。
- **2025年12月4日** DiffSynth-Studio 2.0 发布!众多新功能上线
@@ -265,7 +281,12 @@ Z-Image 的示例代码位于:[/examples/z_image/](/examples/z_image/)
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
</details>
@@ -315,9 +336,67 @@ image.save("image.jpg")
FLUX.2 的示例代码位于:[/examples/flux2/](/examples/flux2/)
|模型 ID|推理|低显存推理|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
</details>
#### Anima: [/docs/zh/Model_Details/Anima.md](/docs/zh/Model_Details/Anima.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
```python
from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": "disk",
"onload_device": "disk",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AnimaImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config),
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config),
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
image = pipe(prompt, seed=0, num_inference_steps=50)
image.save("image.jpg")
```
</details>
<details>
<summary>示例代码</summary>
Anima 的示例代码位于:[/examples/anima/](/examples/anima/)
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)|
</details>
@@ -400,7 +479,10 @@ Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[FireRedTeam/FireRed-Image-Edit-1.0](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.0)|[code](/examples/qwen_image/model_inference/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.0.py)|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
@@ -511,6 +593,132 @@ FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/)
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
#### LTX-2: [/docs/zh/Model_Details/LTX-2.md](/docs/zh/Model_Details/LTX-2.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。
```python
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
vram_config = {
"offload_dtype": torch.float8_e5m2,
"offload_device": "cpu",
"onload_dtype": torch.float8_e5m2,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e5m2,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
"""
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
and avoid redundant memory usage when users only want to use part of the model.
"""
# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading
pipe = LTX2AudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2"
# pipe = LTX2AudioVideoPipeline.from_pretrained(
# torch_dtype=torch.bfloat16,
# device="cuda",
# model_configs=[
# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
# ],
# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
# )
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
negative_prompt = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
height, width, num_frames = 512 * 2, 768 * 2, 121
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=43,
height=height,
width=width,
num_frames=num_frames,
tiled=True,
use_two_stage_pipeline=True,
)
write_video_audio_ltx2(
video=video,
audio=audio,
output_path='ltx2_twostage.mp4',
fps=24,
audio_sample_rate=24000,
)
```
</details>
<details>
<summary>示例代码</summary>
LTX-2 的示例代码位于:[/examples/ltx2/](/examples/ltx2/)
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|
|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-|
</details>
#### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md)
<details>
@@ -650,6 +858,37 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。
<details>
<summary>Spectral Evolution Search: 用于奖励对齐图像生成的高效推理阶段缩放</summary>
- 论文:[Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation
](https://arxiv.org/abs/2602.03208)
- 代码样例coming soon
|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES|
|-|-|-|-|
|![Image](https://github.com/user-attachments/assets/5be15dc6-2805-4822-b04c-2573fc0f45f0)|![Image](https://github.com/user-attachments/assets/e71b8c20-1629-41d9-b0ff-185805c1da4e)|![Image](https://github.com/user-attachments/assets/7a73c968-133a-4545-9aa2-205533861cd4)|![Image](https://github.com/user-attachments/assets/c8390b22-14fe-48a0-a6e6-d6556d31235e)|
</details>
<details>
<summary>VIRAL基于DiT模型的类比视觉上下文推理</summary>
- 论文:[VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers
](https://arxiv.org/abs/2602.03210)
- 代码样例:[/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py)
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)
|Example 1|Example 2|Query|Output|
|-|-|-|-|
|![Image](https://github.com/user-attachments/assets/380d2670-47bf-41cd-b5c9-37110cc4a943)|![Image](https://github.com/user-attachments/assets/7ceaf345-0992-46e6-b38f-394c2065b165)|![Image](https://github.com/user-attachments/assets/f7c26c21-6894-4d9e-b570-f1d44ca7c1de)|![Image](https://github.com/user-attachments/assets/c2bebe3b-5984-41ba-94bf-9509f6a8a990)|
</details>
<details>
<summary>AttriCtrl: 图像生成模型的属性强度控制</summary>
@@ -661,7 +900,7 @@ DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果
|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
|-|-|-|-|-|
|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.5.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.7.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.9.jpg)|
|![Image](https://github.com/user-attachments/assets/e74b32a5-5b2e-4c87-9df8-487c0f8366b7)|![Image](https://github.com/user-attachments/assets/bfe8bec2-9e55-493d-9a26-7e9cce28e03d)|![Image](https://github.com/user-attachments/assets/b099dfe3-ff1f-4b96-894c-d48bbe92db7a)|![Image](https://github.com/user-attachments/assets/0a6b2982-deab-4b0d-91ad-888782de01c9)|![Image](https://github.com/user-attachments/assets/fcecb755-7d03-4020-b83a-13ad2b38705c)|
</details>
@@ -677,10 +916,10 @@ DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果
||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
|-|-|-|-|-|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_0.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_3_3.jpg)|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![Image](https://github.com/user-attachments/assets/01c54d5a-4f00-4c2e-982a-4ec0a4c6a6e3)|![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![Image](https://github.com/user-attachments/assets/e6621457-b9f1-437c-bcc8-3e12e41646de)|![Image](https://github.com/user-attachments/assets/43720a9f-aa27-4918-947d-545389375d46)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![Image](https://github.com/user-attachments/assets/4b7f721f-a2e5-416c-af2c-b53ef236c321)|![Image](https://github.com/user-attachments/assets/418c725b-6d35-41f4-b18f-c7e3867cc142)|![Image](https://github.com/user-attachments/assets/041a3f9a-c7b4-4311-8582-cb71a7226d80)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![Image](https://github.com/user-attachments/assets/802d554e-0402-482c-9f28-87605f8fe318)|![Image](https://github.com/user-attachments/assets/8c8f22fa-9643-4019-b6d7-396d8b7fed9a)|![Image](https://github.com/user-attachments/assets/b54ebaa4-31a7-4536-a2c1-496adba0c013)|![Image](https://github.com/user-attachments/assets/a640fd54-3192-49a0-9281-b43d9ba64f09)|
</details>

View File

@@ -1,2 +1,2 @@
from .model_configs import MODEL_CONFIGS
from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS
from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS

View File

@@ -317,6 +317,13 @@ flux_series = [
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
{
# Supported due to historical reasons.
"model_hash": "605c56eab23e9e2af863ad8f0813a25d",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
@@ -474,6 +481,13 @@ flux_series = [
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
"extra_kwargs": {"disable_guidance_embedder": True},
},
{
# Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors")
"model_hash": "3394f306c4cbf04334b712bf5aaed95f",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
]
flux2_series = [
@@ -496,6 +510,28 @@ flux2_series = [
"model_name": "flux2_vae",
"model_class": "diffsynth.models.flux2_vae.Flux2VAE",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
"model_hash": "3bde7b817fec8143028b6825a63180df",
"model_name": "flux2_dit",
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors")
"model_hash": "9195f3ea256fcd0ae6d929c203470754",
"model_name": "z_image_text_encoder",
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
"extra_kwargs": {"model_size": "8B"},
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors")
"model_hash": "39c6fc48f07bebecedbbaa971ff466c8",
"model_name": "flux2_dit",
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
},
]
z_image_series = [
@@ -553,6 +589,150 @@ z_image_series = [
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
"extra_kwargs": {"compress_dim": 128},
},
{
# Example: ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors")
"model_hash": "1392adecee344136041e70553f875f31",
"model_name": "z_image_text_encoder",
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
"extra_kwargs": {"model_size": "0.6B"},
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
},
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
"""
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
and avoid redundant memory usage when users only want to use part of the model.
"""
ltx2_series = [
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_dit",
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors")
"model_hash": "c567aaa37d5ed7454c73aa6024458661",
"model_name": "ltx2_dit",
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_video_vae_encoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors")
"model_hash": "7f7e904a53260ec0351b05f32153754b",
"model_name": "ltx2_video_vae_encoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_video_vae_decoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors")
"model_hash": "dc6029ca2825147872b45e35a2dc3a97",
"model_name": "ltx2_video_vae_decoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_audio_vae_decoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors")
"model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
"model_name": "ltx2_audio_vae_decoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_audio_vocoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors")
"model_hash": "f471360f6b24bef702ab73133d9f8bb9",
"model_name": "ltx2_audio_vocoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_audio_vae_encoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
"model_hash": "29338f3b95e7e312a3460a482e4f4554",
"model_name": "ltx2_audio_vae_encoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_text_encoder_post_modules",
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors")
"model_hash": "981629689c8be92a712ab3c5eb4fc3f6",
"model_name": "ltx2_text_encoder_post_modules",
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
},
{
# Example: ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors")
"model_hash": "33917f31c4a79196171154cca39f165e",
"model_name": "ltx2_text_encoder",
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "c79c458c6e99e0e14d47e676761732d2",
"model_name": "ltx2_latent_upsampler",
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
},
]
anima_series = [
{
# Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors")
"model_hash": "a9995952c2d8e63cf82e115005eb61b9",
"model_name": "z_image_text_encoder",
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
"extra_kwargs": {"model_size": "0.6B"},
},
{
# Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors")
"model_hash": "417673936471e79e31ed4d186d7a3f4a",
"model_name": "anima_dit",
"model_class": "diffsynth.models.anima_dit.AnimaDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.anima_dit.AnimaDiTStateDictConverter",
}
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series

View File

@@ -210,4 +210,57 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.ltx2_dit.LTXModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": {
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": {
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": {
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": {
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_audio_vae.LTX2Vocoder": {
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.anima_dit.AnimaDiT": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
}
def QwenImageTextEncoder_Module_Map_Updater():
current = VRAM_MANAGEMENT_MODULE_MAPS["diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder"]
from packaging import version
import transformers
if version.parse(transformers.__version__) >= version.parse("5.2.0"):
# The Qwen2RMSNorm in transformers 5.2.0+ has been renamed to Qwen2_5_VLRMSNorm, so we need to update the module map accordingly
current.pop("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm", None)
current["transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRMSNorm"] = "diffsynth.core.vram.layers.AutoWrappedModule"
return current
VERSION_CHECKER_MAPS = {
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": QwenImageTextEncoder_Module_Map_Updater,
}

View File

@@ -52,7 +52,7 @@ def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="
if k_pattern != required_in_pattern:
k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
if v_pattern != required_in_pattern:
v = rearrange(v, f"{q_pattern} -> {required_in_pattern}", **dims)
v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
return q, k, v

View File

@@ -218,3 +218,20 @@ class LoadAudio(DataProcessingOperator):
import librosa
input_audio, sample_rate = librosa.load(data, sr=self.sr)
return input_audio
class LoadAudioWithTorchaudio(DataProcessingOperator):
def __init__(self, duration=5):
self.duration = duration
def __call__(self, data: str):
import torchaudio
waveform, sample_rate = torchaudio.load(data)
target_samples = int(self.duration * sample_rate)
current_samples = waveform.shape[-1]
if current_samples > target_samples:
waveform = waveform[..., :target_samples]
elif current_samples < target_samples:
padding = target_samples - current_samples
waveform = torch.nn.functional.pad(waveform, (0, padding))
return waveform, sample_rate

View File

@@ -10,6 +10,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
data_file_keys=tuple(),
main_data_operator=lambda x: x,
special_operator_map=None,
max_data_items=None,
):
self.base_path = base_path
self.metadata_path = metadata_path
@@ -18,6 +19,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
self.main_data_operator = main_data_operator
self.cached_data_operator = LoadTorchPickle()
self.special_operator_map = {} if special_operator_map is None else special_operator_map
self.max_data_items = max_data_items
self.data = []
self.cached_data = []
self.load_from_cache = metadata_path is None
@@ -97,7 +99,9 @@ class UnifiedDataset(torch.utils.data.Dataset):
return data
def __len__(self):
if self.load_from_cache:
if self.max_data_items is not None:
return self.max_data_items
elif self.load_from_cache:
return len(self.cached_data) * self.repeat
else:
return len(self.data) * self.repeat

View File

@@ -1 +1,2 @@
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE

View File

@@ -1,5 +1,5 @@
import torch, glob, os
from typing import Optional, Union
from typing import Optional, Union, Dict
from dataclasses import dataclass
from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download
@@ -23,13 +23,14 @@ class ModelConfig:
computation_device: Optional[Union[str, torch.device]] = None
computation_dtype: Optional[torch.dtype] = None
clear_parameters: bool = False
state_dict: Dict[str, torch.Tensor] = None
def check_input(self):
if self.path is None and self.model_id is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
def parse_original_file_pattern(self):
if self.origin_file_pattern is None or self.origin_file_pattern == "":
if self.origin_file_pattern in [None, "", "./"]:
return "*"
elif self.origin_file_pattern.endswith("/"):
return self.origin_file_pattern + "*"
@@ -98,7 +99,7 @@ class ModelConfig:
if self.require_downloading():
self.download()
if self.path is None:
if self.origin_file_pattern is None or self.origin_file_pattern == "":
if self.origin_file_pattern in [None, "", "./"]:
self.path = os.path.join(self.local_model_path, self.model_id)
else:
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))

View File

@@ -2,16 +2,25 @@ from safetensors import safe_open
import torch, hashlib
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
if isinstance(file_path, list):
state_dict = {}
for file_path_ in file_path:
state_dict.update(load_state_dict(file_path_, torch_dtype, device))
return state_dict
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
if verbose >= 1:
print(f"Loading file [started]: {file_path}")
if file_path.endswith(".safetensors"):
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else:
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
# If load state dict in CPU memory, `pin_memory=True` will make `model.to("cuda")` faster.
if pin_memory:
for i in state_dict:
state_dict[i] = state_dict[i].pin_memory()
if verbose >= 1:
print(f"Loading file [done]: {file_path}")
return state_dict
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):

View File

@@ -3,14 +3,14 @@ from ..vram.disk_map import DiskMap
from ..vram.layers import enable_vram_management
from .file import load_state_dict
import torch
from contextlib import contextmanager
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import ContextManagers
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
config = {} if config is None else config
# Why do we use `skip_model_initialization`?
# It skips the random initialization of model parameters,
# thereby speeding up model loading and avoiding excessive memory usage.
with skip_model_initialization():
with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
model = model_class(**config)
# What is `module_map`?
# This is a module mapping table for VRAM management.
@@ -20,7 +20,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
dtype = [d for d in dtypes if d != "disk"][0]
if vram_config["offload_device"] != "disk":
state_dict = DiskMap(path, device, torch_dtype=dtype)
if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
else:
@@ -35,7 +35,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
# Sometimes a model file contains multiple models,
# and DiskMap can load only the parameters of a single model,
# avoiding the need to load all parameters in the file.
if use_disk_map:
if state_dict is not None:
pass
elif use_disk_map:
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
else:
state_dict = load_state_dict(path, torch_dtype, device)
@@ -46,7 +48,14 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
state_dict = state_dict_converter(state_dict)
else:
state_dict = {i: state_dict[i] for i in state_dict}
model.load_state_dict(state_dict, assign=True)
# Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
# Because at this stage, model parameters are partitioned across multiple GPUs.
# Loading them directly could lead to excessive GPU memory consumption.
if is_deepspeed_zero3_enabled():
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
_load_state_dict_into_zero3_model(model, state_dict)
else:
model.load_state_dict(state_dict, assign=True)
# Why do we call `to()`?
# Because some models override the behavior of `to()`,
# especially those from libraries like Transformers.
@@ -77,3 +86,20 @@ def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=tor
}
enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
return model
def get_init_context(torch_dtype, device):
if is_deepspeed_zero3_enabled():
from transformers.modeling_utils import set_zero3_state
import deepspeed
# Why do we use "deepspeed.zero.Init"?
# Weight segmentation of the model can be performed on the CPU side
# and loading the segmented weights onto the computing card
init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
else:
# Why do we use `skip_model_initialization`?
# It skips the random initialization of model parameters,
# thereby speeding up model loading and avoiding excessive memory usage.
init_contexts = [skip_model_initialization()]
return init_contexts

View File

@@ -0,0 +1,30 @@
import torch
from ..device.npu_compatible_device import get_device_type
try:
import torch_npu
except:
pass
def rms_norm_forward_npu(self, hidden_states):
"npu rms fused operator for RMSNorm.forward from diffsynth\models\general_modules.py"
if hidden_states.dtype != self.weight.dtype:
hidden_states = hidden_states.to(self.weight.dtype)
return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0]
def rms_norm_forward_transformers_npu(self, hidden_states):
"npu rms fused operator for transformers"
if hidden_states.dtype != self.weight.dtype:
hidden_states = hidden_states.to(self.weight.dtype)
return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor):
"npu rope fused operator for Zimage"
with torch.amp.autocast(get_device_type(), enabled=False):
freqs_cis = freqs_cis.unsqueeze(2)
cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1)
cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)
sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)
return torch_npu.npu_rotary_mul(x_in, cos, sin, rotary_mode="interleave").to(x_in)

View File

@@ -2,7 +2,7 @@ import torch, copy
from typing import Union
from .initialization import skip_model_initialization
from .disk_map import DiskMap
from ..device import parse_device_type
from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE
class AutoTorchModule(torch.nn.Module):
@@ -63,7 +63,7 @@ class AutoTorchModule(torch.nn.Module):
return r
def check_free_vram(self):
device = self.computation_device if self.computation_device != "npu" else "npu:0"
device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
return used_memory < self.vram_limit

View File

@@ -4,9 +4,11 @@ import numpy as np
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput
from ..core.device import get_device_name, IS_NPU_AVAILABLE
class PipelineUnit:
@@ -60,7 +62,7 @@ class BasePipeline(torch.nn.Module):
def __init__(
self,
device="cuda", torch_dtype=torch.float16,
device=get_device_type(), torch_dtype=torch.float16,
height_division_factor=64, width_division_factor=64,
time_division_factor=None, time_division_remainder=None,
):
@@ -92,20 +94,23 @@ class BasePipeline(torch.nn.Module):
return self
def check_resize_height_width(self, height, width, num_frames=None):
def check_resize_height_width(self, height, width, num_frames=None, verbose=1):
# Shape check
if height % self.height_division_factor != 0:
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
if verbose > 0:
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
if width % self.width_division_factor != 0:
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
if verbose > 0:
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
if num_frames is None:
return height, width
else:
if num_frames % self.time_division_factor != self.time_division_remainder:
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
if verbose > 0:
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
return height, width, num_frames
@@ -177,7 +182,7 @@ class BasePipeline(torch.nn.Module):
def get_vram(self):
device = self.device if self.device != "npu" else "npu:0"
device = self.device if not IS_NPU_AVAILABLE else get_device_name()
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
def get_module(self, model, name):
@@ -294,6 +299,7 @@ class BasePipeline(torch.nn.Module):
vram_config=vram_config,
vram_limit=vram_limit,
clear_parameters=model_config.clear_parameters,
state_dict=model_config.state_dict,
)
return model_pool
@@ -315,7 +321,14 @@ class BasePipeline(torch.nn.Module):
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
if isinstance(noise_pred_posi, tuple):
# Separately handling different output types of latents, eg. video and audio latents.
noise_pred = tuple(
n_nega + cfg_scale * (n_posi - n_nega)
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
)
else:
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
return noise_pred

View File

@@ -4,13 +4,15 @@ from typing_extensions import Literal
class FlowMatchScheduler():
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"):
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
self.set_timesteps_fn = {
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
"Wan": FlowMatchScheduler.set_timesteps_wan,
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
}.get(template, FlowMatchScheduler.set_timesteps_flux)
self.num_train_timesteps = 1000
@@ -70,6 +72,28 @@ class FlowMatchScheduler():
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
sigma_min = 0.0
sigma_max = 1.0
num_train_timesteps = 1000
base_shift = math.log(3)
max_shift = math.log(3)
# Sigmas
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
# Mu
if exponential_shift_mu is not None:
mu = exponential_shift_mu
elif dynamic_shift_len is not None:
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift)
else:
mu = 0.8
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
# Timesteps
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def compute_empirical_mu(image_seq_len, num_steps):
a1, b1 = 8.73809524e-05, 1.89833333
@@ -89,13 +113,18 @@ class FlowMatchScheduler():
return float(mu)
@staticmethod
def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=1024//16*1024//16):
def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):
sigma_min = 1 / num_inference_steps
sigma_max = 1.0
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
if dynamic_shift_len is None:
# If you ask me why I set mu=0.8,
# I can only say that it yields better training results.
mu = 0.8
else:
mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@@ -116,7 +145,35 @@ class FlowMatchScheduler():
timestep_id = torch.argmin((timesteps - timestep).abs())
timesteps[timestep_id] = timestep
return sigmas, timesteps
@staticmethod
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
num_train_timesteps = 1000
if special_case == "stage2":
sigmas = torch.Tensor([0.909375, 0.725, 0.421875])
elif special_case == "ditilled_stage1":
sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875])
else:
dynamic_shift_len = dynamic_shift_len or 4096
sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image(
image_seq_len=dynamic_shift_len,
base_seq_len=1024,
max_seq_len=4096,
base_shift=0.95,
max_shift=2.05,
)
sigma_min = 0.0
sigma_max = 1.0
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1))
# Shift terminal
one_minus_z = 1.0 - sigmas
scale_factor = one_minus_z[-1] / (1 - terminal)
sigmas = 1.0 - (one_minus_z / scale_factor)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
def set_training_weight(self):
steps = 1000
x = self.timesteps

View File

@@ -10,7 +10,7 @@ class ModelLogger:
self.num_steps = 0
def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):
self.num_steps += 1
if save_steps is not None and self.num_steps % save_steps == 0:
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
@@ -18,8 +18,8 @@ class ModelLogger:
def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
accelerator.wait_for_everyone()
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
@@ -34,8 +34,8 @@ class ModelLogger:
def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
accelerator.wait_for_everyone()
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)

View File

@@ -13,14 +13,51 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
if "first_frame_latents" in inputs:
inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
if "first_frame_latents" in inputs:
noise_pred = noise_pred[:, :, 1:]
training_target = training_target[:, :, 1:]
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * pipe.scheduler.training_weight(timestep)
return loss
def FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs):
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
# video
noise = torch.randn_like(inputs["input_latents"])
inputs["video_latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
# audio
if inputs.get("audio_input_latents") is not None:
audio_noise = torch.randn_like(inputs["audio_input_latents"])
inputs["audio_latents"] = pipe.scheduler.add_noise(inputs["audio_input_latents"], audio_noise, timestep)
training_target_audio = pipe.scheduler.training_target(inputs["audio_input_latents"], audio_noise, timestep)
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * pipe.scheduler.training_weight(timestep)
if inputs.get("audio_input_latents") is not None:
loss_audio = torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float())
loss_audio = loss_audio * pipe.scheduler.training_weight(timestep)
loss = loss + loss_audio
return loss
def DirectDistillLoss(pipe: BasePipeline, **inputs):
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
pipe.scheduler.training = True
@@ -84,7 +121,9 @@ class TrajectoryImitationLoss(torch.nn.Module):
progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
latents_ = trajectory_teacher[progress_id_teacher]
target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
denom = sigma_ - sigma
denom = torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6)
target = (latents_ - inputs_shared["latents"]) / denom
loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
return loss

View File

@@ -27,7 +27,7 @@ def launch_training_task(
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
model.to(device=accelerator.device)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
for epoch_id in range(num_epochs):
@@ -40,7 +40,7 @@ def launch_training_task(
loss = model(data)
accelerator.backward(loss)
optimizer.step()
model_logger.on_step_end(accelerator, model, save_steps)
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
scheduler.step()
if save_steps is None:
model_logger.on_epoch_end(accelerator, model, epoch_id)
@@ -59,6 +59,7 @@ def launch_data_process_task(
num_workers = args.dataset_num_workers
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
model.to(device=accelerator.device)
model, dataloader = accelerator.prepare(model, dataloader)
for data_id, data in enumerate(tqdm(dataloader)):

View File

@@ -1,4 +1,4 @@
import torch, json
import torch, json, os
from ..core import ModelConfig, load_state_dict
from ..utils.controlnet import ControlNetInput
from peft import LoraConfig, inject_adapter_in_model
@@ -127,16 +127,67 @@ class DiffusionTrainingModule(torch.nn.Module):
if model_id_with_origin_paths is not None:
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
for model_id_with_origin_path in model_id_with_origin_paths:
model_id, origin_file_pattern = model_id_with_origin_path.split(":")
vram_config = self.parse_vram_config(
fp8=model_id_with_origin_path in fp8_models,
offload=model_id_with_origin_path in offload_models,
device=device
)
model_configs.append(ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern, **vram_config))
config = self.parse_path_or_model_id(model_id_with_origin_path)
model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))
return model_configs
def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):
if model_id_with_origin_path is None:
return default_value
elif os.path.exists(model_id_with_origin_path):
return ModelConfig(path=model_id_with_origin_path)
else:
if ":" not in model_id_with_origin_path:
raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id/origin_file_pattern`.")
split_id = model_id_with_origin_path.rfind(":")
model_id = model_id_with_origin_path[:split_id]
origin_file_pattern = model_id_with_origin_path[split_id + 1:]
return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
def auto_detect_lora_target_modules(
self,
model: torch.nn.Module,
search_for_linear=False,
linear_detector=lambda x: min(x.weight.shape) >= 512,
block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
name_prefix="",
):
lora_target_modules = []
if search_for_linear:
for name, module in model.named_modules():
module_name = name_prefix + ["", "."][name_prefix != ""] + name
if isinstance(module, torch.nn.Linear) and linear_detector(module):
lora_target_modules.append(module_name)
else:
for name, module in model.named_children():
module_name = name_prefix + ["", "."][name_prefix != ""] + name
lora_target_modules += self.auto_detect_lora_target_modules(
module,
search_for_linear=block_list_detector(module),
linear_detector=linear_detector,
block_list_detector=block_list_detector,
name_prefix=module_name,
)
return lora_target_modules
def parse_lora_target_modules(self, model, lora_target_modules):
if lora_target_modules == "":
print("No LoRA target modules specified. The framework will automatically search for them.")
lora_target_modules = self.auto_detect_lora_target_modules(model)
print(f"LoRA will be patched at {lora_target_modules}.")
else:
lora_target_modules = lora_target_modules.split(",")
return lora_target_modules
def switch_pipe_to_training_mode(
self,
pipe,
@@ -166,7 +217,7 @@ class DiffusionTrainingModule(torch.nn.Module):
return
model = self.add_lora_to_model(
getattr(pipe, lora_base_model),
target_modules=lora_target_modules.split(","),
target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
lora_rank=lora_rank,
upcast_dtype=pipe.torch_dtype,
)

File diff suppressed because it is too large Load Diff

View File

@@ -2,6 +2,8 @@ from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch
from ..core.device.npu_compatible_device import get_device_type
class DINOv3ImageEncoder(DINOv3ViTModel):
def __init__(self):
@@ -70,7 +72,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
}
)
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
inputs = self.processor(images=image, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
bool_masked_pos = None

View File

@@ -407,6 +407,7 @@ class Flux2AttnProcessor:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)
hidden_states = attention_forward(
query,
key,
@@ -536,6 +537,7 @@ class Flux2ParallelSelfAttnProcessor:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)
hidden_states = attention_forward(
query,
key,
@@ -823,7 +825,13 @@ class Flux2PosEmbed(nn.Module):
class Flux2TimestepGuidanceEmbeddings(nn.Module):
def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
def __init__(
self,
in_channels: int = 256,
embedding_dim: int = 6144,
bias: bool = False,
guidance_embeds: bool = True,
):
super().__init__()
self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -831,20 +839,24 @@ class Flux2TimestepGuidanceEmbeddings(nn.Module):
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
)
self.guidance_embedder = TimestepEmbedding(
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
)
if guidance_embeds:
self.guidance_embedder = TimestepEmbedding(
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
)
else:
self.guidance_embedder = None
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
time_guidance_emb = timesteps_emb + guidance_emb
return time_guidance_emb
if guidance is not None and self.guidance_embedder is not None:
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
time_guidance_emb = timesteps_emb + guidance_emb
return time_guidance_emb
else:
return timesteps_emb
class Flux2Modulation(nn.Module):
@@ -882,6 +894,7 @@ class Flux2DiT(torch.nn.Module):
axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
rope_theta: int = 2000,
eps: float = 1e-6,
guidance_embeds: bool = True,
):
super().__init__()
self.out_channels = out_channels or in_channels
@@ -892,7 +905,10 @@ class Flux2DiT(torch.nn.Module):
# 2. Combined timestep + guidance embedding
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
in_channels=timestep_guidance_channels,
embedding_dim=self.inner_dim,
bias=False,
guidance_embeds=guidance_embeds,
)
# 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
@@ -953,34 +969,9 @@ class Flux2DiT(torch.nn.Module):
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
) -> Union[torch.Tensor]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
Input `hidden_states`.
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
):
# 0. Handle input arguments
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
@@ -992,7 +983,9 @@ class Flux2DiT(torch.nn.Module):
# 1. Calculate timestep embedding and modulation parameters
timestep = timestep.to(hidden_states.dtype) * 1000
guidance = guidance.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
temb = self.time_guidance_embed(timestep, guidance)

View File

@@ -9,6 +9,7 @@ import numpy as np
import torch.nn.functional as F
from einops import rearrange, repeat
from .wan_video_dit import flash_attention
from ..core.device.npu_compatible_device import get_device_type
from ..core.gradient import gradient_checkpoint_forward
@@ -373,7 +374,7 @@ class FinalLayer_FP32(nn.Module):
B, N, C = x.shape
T, _, _ = latent_shape
with amp.autocast('cuda', dtype=torch.float32):
with amp.autocast(get_device_type(), dtype=torch.float32):
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
x = self.linear(x)
@@ -583,7 +584,7 @@ class LongCatSingleStreamBlock(nn.Module):
T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.
# compute modulation params in fp32
with amp.autocast(device_type='cuda', dtype=torch.float32):
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
shift_msa, scale_msa, gate_msa, \
shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
@@ -602,7 +603,7 @@ class LongCatSingleStreamBlock(nn.Module):
else:
x_s = attn_outputs
with amp.autocast(device_type='cuda', dtype=torch.float32):
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
@@ -615,7 +616,7 @@ class LongCatSingleStreamBlock(nn.Module):
# ffn with modulation
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
x_s = self.ffn(x_m)
with amp.autocast(device_type='cuda', dtype=torch.float32):
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
@@ -797,7 +798,7 @@ class LongCatVideoTransformer3DModel(torch.nn.Module):
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
with amp.autocast(device_type='cuda', dtype=torch.float32):
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,371 @@
from dataclasses import dataclass
from typing import NamedTuple, Protocol, Tuple
import torch
from torch import nn
from enum import Enum
class VideoPixelShape(NamedTuple):
"""
Shape of the tensor representing the video pixel array. Assumes BGR channel format.
"""
batch: int
frames: int
height: int
width: int
fps: float
class SpatioTemporalScaleFactors(NamedTuple):
"""
Describes the spatiotemporal downscaling between decoded video space and
the corresponding VAE latent grid.
"""
time: int
width: int
height: int
@classmethod
def default(cls) -> "SpatioTemporalScaleFactors":
return cls(time=8, width=32, height=32)
VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
class VideoLatentShape(NamedTuple):
"""
Shape of the tensor representing video in VAE latent space.
The latent representation is a 5D tensor with dimensions ordered as
(batch, channels, frames, height, width). Spatial and temporal dimensions
are downscaled relative to pixel space according to the VAE's scale factors.
"""
batch: int
channels: int
frames: int
height: int
width: int
def to_torch_shape(self) -> torch.Size:
return torch.Size([self.batch, self.channels, self.frames, self.height, self.width])
@staticmethod
def from_torch_shape(shape: torch.Size) -> "VideoLatentShape":
return VideoLatentShape(
batch=shape[0],
channels=shape[1],
frames=shape[2],
height=shape[3],
width=shape[4],
)
def mask_shape(self) -> "VideoLatentShape":
return self._replace(channels=1)
@staticmethod
def from_pixel_shape(
shape: VideoPixelShape,
latent_channels: int = 128,
scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS,
) -> "VideoLatentShape":
frames = (shape.frames - 1) // scale_factors[0] + 1
height = shape.height // scale_factors[1]
width = shape.width // scale_factors[2]
return VideoLatentShape(
batch=shape.batch,
channels=latent_channels,
frames=frames,
height=height,
width=width,
)
def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape":
return self._replace(
channels=3,
frames=(self.frames - 1) * scale_factors.time + 1,
height=self.height * scale_factors.height,
width=self.width * scale_factors.width,
)
class AudioLatentShape(NamedTuple):
"""
Shape of audio in VAE latent space: (batch, channels, frames, mel_bins).
mel_bins is the number of frequency bins from the mel-spectrogram encoding.
"""
batch: int
channels: int
frames: int
mel_bins: int
def to_torch_shape(self) -> torch.Size:
return torch.Size([self.batch, self.channels, self.frames, self.mel_bins])
def mask_shape(self) -> "AudioLatentShape":
return self._replace(channels=1, mel_bins=1)
@staticmethod
def from_torch_shape(shape: torch.Size) -> "AudioLatentShape":
return AudioLatentShape(
batch=shape[0],
channels=shape[1],
frames=shape[2],
mel_bins=shape[3],
)
@staticmethod
def from_duration(
batch: int,
duration: float,
channels: int = 8,
mel_bins: int = 16,
sample_rate: int = 16000,
hop_length: int = 160,
audio_latent_downsample_factor: int = 4,
) -> "AudioLatentShape":
latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor)
return AudioLatentShape(
batch=batch,
channels=channels,
frames=round(duration * latents_per_second),
mel_bins=mel_bins,
)
@staticmethod
def from_video_pixel_shape(
shape: VideoPixelShape,
channels: int = 8,
mel_bins: int = 16,
sample_rate: int = 16000,
hop_length: int = 160,
audio_latent_downsample_factor: int = 4,
) -> "AudioLatentShape":
return AudioLatentShape.from_duration(
batch=shape.batch,
duration=float(shape.frames) / float(shape.fps),
channels=channels,
mel_bins=mel_bins,
sample_rate=sample_rate,
hop_length=hop_length,
audio_latent_downsample_factor=audio_latent_downsample_factor,
)
@dataclass(frozen=True)
class LatentState:
"""
State of latents during the diffusion denoising process.
Attributes:
latent: The current noisy latent tensor being denoised.
denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising).
positions: Positional indices for each latent element, used for positional embeddings.
clean_latent: Initial state of the latent before denoising, may include conditioning latents.
"""
latent: torch.Tensor
denoise_mask: torch.Tensor
positions: torch.Tensor
clean_latent: torch.Tensor
def clone(self) -> "LatentState":
return LatentState(
latent=self.latent.clone(),
denoise_mask=self.denoise_mask.clone(),
positions=self.positions.clone(),
clean_latent=self.clean_latent.clone(),
)
class NormType(Enum):
"""Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""
GROUP = "group"
PIXEL = "pixel"
class PixelNorm(nn.Module):
"""
Per-pixel (per-location) RMS normalization layer.
For each element along the chosen dimension, this layer normalizes the tensor
by the root-mean-square of its values across that dimension:
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
"""
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
"""
Args:
dim: Dimension along which to compute the RMS (typically channels).
eps: Small constant added for numerical stability.
"""
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply RMS normalization along the configured dimension.
"""
# Compute mean of squared values along `dim`, keep dimensions for broadcasting.
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
# Normalize by the root-mean-square (RMS).
rms = torch.sqrt(mean_sq + self.eps)
return x / rms
def build_normalization_layer(
in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
) -> nn.Module:
"""
Create a normalization layer based on the normalization type.
Args:
in_channels: Number of input channels
num_groups: Number of groups for group normalization
normtype: Type of normalization: "group" or "pixel"
Returns:
A normalization layer
"""
if normtype == NormType.GROUP:
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if normtype == NormType.PIXEL:
return PixelNorm(dim=1, eps=1e-6)
raise ValueError(f"Invalid normalization type: {normtype}")
def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
"""Root-mean-square (RMS) normalize `x` over its last dimension.
Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
shape and forwards `weight` and `eps`.
"""
return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)
@dataclass(frozen=True)
class Modality:
"""
Input data for a single modality (video or audio) in the transformer.
Bundles the latent tokens, timestep embeddings, positional information,
and text conditioning context for processing by the diffusion transformer.
"""
latent: (
torch.Tensor
) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps
positions: (
torch.Tensor
) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens
context: torch.Tensor
enabled: bool = True
context_mask: torch.Tensor | None = None
def to_denoised(
sample: torch.Tensor,
velocity: torch.Tensor,
sigma: float | torch.Tensor,
calc_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Convert the sample and its denoising velocity to denoised sample.
Returns:
Denoised sample
"""
if isinstance(sigma, torch.Tensor):
sigma = sigma.to(calc_dtype)
return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)
class Patchifier(Protocol):
"""
Protocol for patchifiers that convert latent tensors into patches and assemble them back.
"""
def patchify(
self,
latents: torch.Tensor,
) -> torch.Tensor:
...
"""
Convert latent tensors into flattened patch tokens.
Args:
latents: Latent tensor to patchify.
Returns:
Flattened patch tokens tensor.
"""
def unpatchify(
self,
latents: torch.Tensor,
output_shape: AudioLatentShape | VideoLatentShape,
) -> torch.Tensor:
"""
Converts latent tensors between spatio-temporal formats and flattened sequence representations.
Args:
latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
VideoLatentShape.
Returns:
Dense latent tensor restored from the flattened representation.
"""
@property
def patch_size(self) -> Tuple[int, int, int]:
...
"""
Returns the patch size as a tuple of (temporal, height, width) dimensions
"""
def get_patch_grid_bounds(
self,
output_shape: AudioLatentShape | VideoLatentShape,
device: torch.device | None = None,
) -> torch.Tensor:
...
"""
Compute metadata describing where each latent patch resides within the
grid specified by `output_shape`.
Args:
output_shape: Target grid layout for the patches.
device: Target device for the returned tensor.
Returns:
Tensor containing patch coordinate metadata such as spatial or temporal intervals.
"""
def get_pixel_coords(
latent_coords: torch.Tensor,
scale_factors: SpatioTemporalScaleFactors,
causal_fix: bool = False,
) -> torch.Tensor:
"""
Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
Args:
latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied
per axis.
causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
that treat frame zero differently still yield non-negative timestamps.
"""
# Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
broadcast_shape = [1] * latent_coords.ndim
broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width)
scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)
# Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
pixel_coords = latent_coords * scale_tensor
if causal_fix:
# VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
# Shift and clamp to keep the first-frame timestamps causal and non-negative.
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
return pixel_coords

1452
diffsynth/models/ltx2_dit.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,366 @@
import torch
from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer
from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention,
FeedForward)
from .ltx2_common import rms_norm
class LTX2TextEncoder(Gemma3ForConditionalGeneration):
def __init__(self):
config = Gemma3Config(
**{
"architectures": ["Gemma3ForConditionalGeneration"],
"boi_token_index": 255999,
"dtype": "bfloat16",
"eoi_token_index": 256000,
"eos_token_id": [1, 106],
"image_token_index": 262144,
"initializer_range": 0.02,
"mm_tokens_per_image": 256,
"model_type": "gemma3",
"text_config": {
"_sliding_window_pattern": 6,
"attention_bias": False,
"attention_dropout": 0.0,
"attn_logit_softcapping": None,
"cache_implementation": "hybrid",
"dtype": "bfloat16",
"final_logit_softcapping": None,
"head_dim": 256,
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 3840,
"initializer_range": 0.02,
"intermediate_size": 15360,
"layer_types": [
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention"
],
"max_position_embeddings": 131072,
"model_type": "gemma3_text",
"num_attention_heads": 16,
"num_hidden_layers": 48,
"num_key_value_heads": 8,
"query_pre_attn_scalar": 256,
"rms_norm_eps": 1e-06,
"rope_local_base_freq": 10000,
"rope_scaling": {
"factor": 8.0,
"rope_type": "linear"
},
"rope_theta": 1000000,
"sliding_window": 1024,
"sliding_window_pattern": 6,
"use_bidirectional_attention": False,
"use_cache": True,
"vocab_size": 262208
},
"transformers_version": "4.57.3",
"vision_config": {
"attention_dropout": 0.0,
"dtype": "bfloat16",
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 896,
"intermediate_size": 4304,
"layer_norm_eps": 1e-06,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 27,
"patch_size": 14,
"vision_use_head": False
}
})
super().__init__(config)
class LTXVGemmaTokenizer:
"""
Tokenizer wrapper for Gemma models compatible with LTXV processes.
This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders,
ensuring correct settings and output formatting for downstream consumption.
"""
def __init__(self, tokenizer_path: str, max_length: int = 1024):
"""
Initialize the tokenizer.
Args:
tokenizer_path (str): Path to the pretrained tokenizer files or model directory.
max_length (int, optional): Max sequence length for encoding. Defaults to 256.
"""
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, local_files_only=True, model_max_length=max_length
)
# Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much.
self.tokenizer.padding_side = "left"
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.max_length = max_length
def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]:
"""
Tokenize the given text and return token IDs and attention weights.
Args:
text (str): The input string to tokenize.
return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples.
If False (default), omits the indices.
Returns:
dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]:
A dictionary with a "gemma" key mapping to:
- a list of (token_id, attention_mask) tuples if return_word_ids is False;
- a list of (token_id, attention_mask, index) tuples if return_word_ids is True.
Example:
>>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8)
>>> tokenizer.tokenize_with_weights("hello world")
{'gemma': [(1234, 1), (5678, 1), (2, 0), ...]}
"""
text = text.strip()
encoded = self.tokenizer(
text,
padding="max_length",
max_length=self.max_length,
truncation=True,
return_tensors="pt",
)
input_ids = encoded.input_ids
attention_mask = encoded.attention_mask
tuples = [
(token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0], strict=True))
]
out = {"gemma": tuples}
if not return_word_ids:
# Return only (token_id, attention_mask) pairs, omitting token position
out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()}
return out
class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
"""
Feature extractor module for Gemma models.
This module applies a single linear projection to the input tensor.
It expects a flattened feature tensor of shape (batch_size, 3840*49).
The linear layer maps this to a (batch_size, 3840) embedding.
Attributes:
aggregate_embed (torch.nn.Linear): Linear projection layer.
"""
def __init__(self) -> None:
"""
Initialize the GemmaFeaturesExtractorProjLinear module.
The input dimension is expected to be 3840 * 49, and the output is 3840.
"""
super().__init__()
self.aggregate_embed = torch.nn.Linear(3840 * 49, 3840, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the feature extractor.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, 3840 * 49).
Returns:
torch.Tensor: Output tensor of shape (batch_size, 3840).
"""
return self.aggregate_embed(x)
class _BasicTransformerBlock1D(torch.nn.Module):
def __init__(
self,
dim: int,
heads: int,
dim_head: int,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
):
super().__init__()
self.attn1 = Attention(
query_dim=dim,
heads=heads,
dim_head=dim_head,
rope_type=rope_type,
)
self.ff = FeedForward(
dim,
dim_out=dim,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
pe: torch.Tensor | None = None,
) -> torch.Tensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Normalization Before Self-Attention
norm_hidden_states = rms_norm(hidden_states)
norm_hidden_states = norm_hidden_states.squeeze(1)
# 2. Self-Attention
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 3. Normalization before Feed-Forward
norm_hidden_states = rms_norm(hidden_states)
# 4. Feed-forward
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class Embeddings1DConnector(torch.nn.Module):
"""
Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can
substitute padded positions with learnable registers. The module is highly configurable for head size, number of
layers, and register usage.
Args:
attention_head_dim (int): Dimension of each attention head (default=128).
num_attention_heads (int): Number of attention heads (default=30).
num_layers (int): Number of transformer layers (default=2).
positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0).
positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[1]).
causal_temporal_positioning (bool): If True, uses causal attention (default=False).
num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables
register replacement. (default=128)
rope_type (LTXRopeType): The RoPE variant to use (default=DEFAULT_ROPE_TYPE).
double_precision_rope (bool): Use double precision rope calculation (default=False).
"""
_supports_gradient_checkpointing = True
def __init__(
self,
attention_head_dim: int = 128,
num_attention_heads: int = 30,
num_layers: int = 2,
positional_embedding_theta: float = 10000.0,
positional_embedding_max_pos: list[int] | None = [4096],
causal_temporal_positioning: bool = False,
num_learnable_registers: int | None = 128,
rope_type: LTXRopeType = LTXRopeType.SPLIT,
double_precision_rope: bool = True,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.positional_embedding_theta = positional_embedding_theta
self.positional_embedding_max_pos = (
positional_embedding_max_pos if positional_embedding_max_pos is not None else [1]
)
self.rope_type = rope_type
self.double_precision_rope = double_precision_rope
self.transformer_1d_blocks = torch.nn.ModuleList(
[
_BasicTransformerBlock1D(
dim=self.inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rope_type=rope_type,
)
for _ in range(num_layers)
]
)
self.num_learnable_registers = num_learnable_registers
if self.num_learnable_registers:
self.learnable_registers = torch.nn.Parameter(
torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
)
def _replace_padded_with_learnable_registers(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[1] % self.num_learnable_registers == 0, (
f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers "
f"{self.num_learnable_registers}."
)
num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers
learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1))
attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int()
non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
non_zero_nums = non_zero_hidden_states.shape[1]
pad_length = hidden_states.shape[1] - non_zero_nums
adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
flipped_mask = torch.flip(attention_mask_binary, dims=[1])
hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers
attention_mask = torch.full_like(
attention_mask,
0.0,
dtype=attention_mask.dtype,
device=attention_mask.device,
)
return hidden_states, attention_mask
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass of Embeddings1DConnector.
Args:
hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]).
attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states).
Returns:
tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask.
"""
if self.num_learnable_registers:
hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)
indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device)
indices_grid = indices_grid[None, None, :]
freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
freqs_cis = precompute_freqs_cis(
indices_grid=indices_grid,
dim=self.inner_dim,
out_dtype=hidden_states.dtype,
theta=self.positional_embedding_theta,
max_pos=self.positional_embedding_max_pos,
num_attention_heads=self.num_attention_heads,
rope_type=self.rope_type,
freq_grid_generator=freq_grid_generator,
)
for block in self.transformer_1d_blocks:
hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis)
hidden_states = rms_norm(hidden_states)
return hidden_states, attention_mask
class LTX2TextEncoderPostModules(torch.nn.Module):
def __init__(self,):
super().__init__()
self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
self.embeddings_connector = Embeddings1DConnector()
self.audio_embeddings_connector = Embeddings1DConnector()

View File

@@ -0,0 +1,313 @@
import math
from typing import Optional, Tuple
import torch
from einops import rearrange
import torch.nn.functional as F
from .ltx2_video_vae import LTX2VideoEncoder
class PixelShuffleND(torch.nn.Module):
"""
N-dimensional pixel shuffle operation for upsampling tensors.
Args:
dims (int): Number of dimensions to apply pixel shuffle to.
- 1: Temporal (e.g., frames)
- 2: Spatial (e.g., height and width)
- 3: Spatiotemporal (e.g., depth, height, width)
upscale_factors (tuple[int, int, int], optional): Upscaling factors for each dimension.
For dims=1, only the first value is used.
For dims=2, the first two values are used.
For dims=3, all three values are used.
The input tensor is rearranged so that the channel dimension is split into
smaller channels and upscaling factors, and the upscaling factors are moved
into the corresponding spatial/temporal dimensions.
Note:
This operation is equivalent to the patchifier operation in for the models. Consider
using this class instead.
"""
def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)):
super().__init__()
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
self.dims = dims
self.upscale_factors = upscale_factors
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.dims == 3:
return rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.upscale_factors[0],
p2=self.upscale_factors[1],
p3=self.upscale_factors[2],
)
elif self.dims == 2:
return rearrange(
x,
"b (c p1 p2) h w -> b c (h p1) (w p2)",
p1=self.upscale_factors[0],
p2=self.upscale_factors[1],
)
elif self.dims == 1:
return rearrange(
x,
"b (c p1) f h w -> b c (f p1) h w",
p1=self.upscale_factors[0],
)
else:
raise ValueError(f"Unsupported dims: {self.dims}")
class ResBlock(torch.nn.Module):
"""
Residual block with two convolutional layers, group normalization, and SiLU activation.
Args:
channels (int): Number of input and output channels.
mid_channels (Optional[int]): Number of channels in the intermediate convolution layer. Defaults to `channels`
if not specified.
dims (int): Dimensionality of the convolution (2 for Conv2d, 3 for Conv3d). Defaults to 3.
"""
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
super().__init__()
if mid_channels is None:
mid_channels = channels
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.conv1 = conv(channels, mid_channels, kernel_size=3, padding=1)
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
self.conv2 = conv(mid_channels, channels, kernel_size=3, padding=1)
self.norm2 = torch.nn.GroupNorm(32, channels)
self.activation = torch.nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = self.activation(x)
x = self.conv2(x)
x = self.norm2(x)
x = self.activation(x + residual)
return x
class BlurDownsample(torch.nn.Module):
"""
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
"""
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
super().__init__()
assert dims in (2, 3)
assert isinstance(stride, int)
assert stride >= 1
assert kernel_size >= 3
assert kernel_size % 2 == 1
self.dims = dims
self.stride = stride
self.kernel_size = kernel_size
# 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from
# the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and
# provides a smooth approximation of a Gaussian filter (often called a "binomial filter").
# The 2D kernel is constructed as the outer product and normalized.
k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
k2d = k[:, None] @ k[None, :]
k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size)
self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.stride == 1:
return x
if self.dims == 2:
return self._apply_2d(x)
else:
# dims == 3: apply per-frame on H,W
b, _, f, _, _ = x.shape
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self._apply_2d(x)
h2, w2 = x.shape[-2:]
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
return x
def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor:
c = x2d.shape[1]
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
x2d = F.conv2d(x2d, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
return x2d
def _rational_for_scale(scale: float) -> Tuple[int, int]:
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
if float(scale) not in mapping:
raise ValueError(f"Unsupported scale {scale}. Choose from {list(mapping.keys())}")
return mapping[float(scale)]
class SpatialRationalResampler(torch.nn.Module):
"""
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
downsample by 'den' using fixed blur + stride. Operates on H,W only.
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
Args:
mid_channels (`int`): Number of intermediate channels for the convolution layer
scale (`float`): Spatial scaling factor. Supported values are:
- 0.75: Downsample by 3/4 (reduce spatial size)
- 1.5: Upsample by 3/2 (increase spatial size)
- 2.0: Upsample by 2x (double spatial size)
- 4.0: Upsample by 4x (quadruple spatial size)
Any other value will raise a ValueError.
"""
def __init__(self, mid_channels: int, scale: float):
super().__init__()
self.scale = float(scale)
self.num, self.den = _rational_for_scale(self.scale)
self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
self.blur_down = BlurDownsample(dims=2, stride=self.den)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, _, f, _, _ = x.shape
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.blur_down(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
return x
class LTX2LatentUpsampler(torch.nn.Module):
"""
Model to upsample VAE latents spatially and/or temporally.
Args:
in_channels (`int`): Number of channels in the input latent
mid_channels (`int`): Number of channels in the middle layers
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
dims (`int`): Number of dimensions for convolutions (2 or 3)
spatial_upsample (`bool`): Whether to spatially upsample the latent
temporal_upsample (`bool`): Whether to temporally upsample the latent
spatial_scale (`float`): Scale factor for spatial upsampling
rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling
"""
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 1024,
num_blocks_per_stage: int = 4,
dims: int = 3,
spatial_upsample: bool = True,
temporal_upsample: bool = False,
spatial_scale: float = 2.0,
rational_resampler: bool = True,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.num_blocks_per_stage = num_blocks_per_stage
self.dims = dims
self.spatial_upsample = spatial_upsample
self.temporal_upsample = temporal_upsample
self.spatial_scale = float(spatial_scale)
self.rational_resampler = rational_resampler
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1)
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
self.initial_activation = torch.nn.SiLU()
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
if spatial_upsample and temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(3),
)
elif spatial_upsample:
if rational_resampler:
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale)
else:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(2),
)
elif temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(1),
)
else:
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
self.post_upsample_res_blocks = torch.nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1)
def forward(self, latent: torch.Tensor) -> torch.Tensor:
b, _, f, _, _ = latent.shape
if self.dims == 2:
x = rearrange(latent, "b c f h w -> (b f) c h w")
x = self.initial_conv(x)
x = self.initial_norm(x)
x = self.initial_activation(x)
for block in self.res_blocks:
x = block(x)
x = self.upsampler(x)
for block in self.post_upsample_res_blocks:
x = block(x)
x = self.final_conv(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
else:
x = self.initial_conv(latent)
x = self.initial_norm(x)
x = self.initial_activation(x)
for block in self.res_blocks:
x = block(x)
if self.temporal_upsample:
x = self.upsampler(x)
# remove the first frame after upsampling.
# This is done because the first frame encodes one pixel frame.
x = x[:, :, 1:, :, :]
elif isinstance(self.upsampler, SpatialRationalResampler):
x = self.upsampler(x)
else:
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self.upsampler(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
for block in self.post_upsample_res_blocks:
x = block(x)
x = self.final_conv(x)
return x
def upsample_video(latent: torch.Tensor, video_encoder: LTX2VideoEncoder, upsampler: "LTX2LatentUpsampler") -> torch.Tensor:
"""
Apply upsampling to the latent representation using the provided upsampler,
with normalization and un-normalization based on the video encoder's per-channel statistics.
Args:
latent: Input latent tensor of shape [B, C, F, H, W].
video_encoder: VideoEncoder with per_channel_statistics for normalization.
upsampler: LTX2LatentUpsampler module to perform upsampling.
Returns:
torch.Tensor: Upsampled and re-normalized latent tensor.
"""
latent = video_encoder.per_channel_statistics.un_normalize(latent)
latent = upsampler(latent)
latent = video_encoder.per_channel_statistics.normalize(latent)
return latent

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
from ..core.loader import load_model, hash_model_file
from ..core.vram import AutoWrappedModule
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS
import importlib, json, torch
@@ -22,14 +22,15 @@ class ModelPool:
def fetch_module_map(self, model_class, vram_config):
if self.need_to_enable_vram_management(vram_config):
if model_class in VRAM_MANAGEMENT_MODULE_MAPS:
module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in VRAM_MANAGEMENT_MODULE_MAPS[model_class].items()}
vram_module_map = VRAM_MANAGEMENT_MODULE_MAPS[model_class] if model_class not in VERSION_CHECKER_MAPS else VERSION_CHECKER_MAPS[model_class]()
module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in vram_module_map.items()}
else:
module_map = {self.import_model_class(model_class): AutoWrappedModule}
else:
module_map = None
return module_map
def load_model_file(self, config, path, vram_config, vram_limit=None):
def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
model_class = self.import_model_class(config["model_class"])
model_config = config.get("extra_kwargs", {})
if "state_dict_converter" in config:
@@ -43,6 +44,7 @@ class ModelPool:
state_dict_converter,
use_disk_map=True,
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
state_dict=state_dict,
)
return model
@@ -59,7 +61,7 @@ class ModelPool:
}
return vram_config
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False):
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
print(f"Loading models from: {json.dumps(path, indent=4)}")
if vram_config is None:
vram_config = self.default_vram_config()
@@ -67,7 +69,7 @@ class ModelPool:
loaded = False
for config in MODEL_CONFIGS:
if config["model_hash"] == model_hash:
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit)
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
if clear_parameters: self.clear_parameters(model)
self.model.append(model)
model_name = config["model_name"]

View File

@@ -583,7 +583,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
is_compileable = is_compileable and not self.generation_config.disable_compile
if is_compileable and (
self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices
):
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = self.get_compiled_call(generation_config.compile_config)

View File

@@ -2,6 +2,8 @@ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer,
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
import torch
from diffsynth.core.device.npu_compatible_device import get_device_type
class Siglip2ImageEncoder(SiglipVisionTransformer):
def __init__(self):
@@ -47,7 +49,7 @@ class Siglip2ImageEncoder(SiglipVisionTransformer):
}
)
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
output_attentions = False

View File

@@ -1,10 +1,11 @@
import torch
from typing import Optional, Union
from .qwen_image_text_encoder import QwenImageTextEncoder
from ..core.device.npu_compatible_device import get_device_type, get_torch_device
class Step1xEditEmbedder(torch.nn.Module):
def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"):
def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()):
super().__init__()
self.max_length = max_length
self.dtype = dtype
@@ -77,13 +78,13 @@ User Prompt:'''
self.max_length,
self.model.config.hidden_size,
dtype=torch.bfloat16,
device=torch.cuda.current_device(),
device=get_torch_device().current_device(),
)
masks = torch.zeros(
len(text_list),
self.max_length,
dtype=torch.long,
device=torch.cuda.current_device(),
device=get_torch_device().current_device(),
)
def split_string(s):
@@ -158,7 +159,7 @@ User Prompt:'''
else:
token_list.append(token_each)
new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type())
new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
@@ -167,15 +168,15 @@ User Prompt:'''
inputs.input_ids = (
torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
.unsqueeze(0)
.to("cuda")
.to(get_device_type())
)
inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type())
outputs = self.model_forward(
self.model,
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
pixel_values=inputs.pixel_values.to("cuda"),
image_grid_thw=inputs.image_grid_thw.to("cuda"),
pixel_values=inputs.pixel_values.to(get_device_type()),
image_grid_thw=inputs.image_grid_thw.to(get_device_type()),
output_hidden_states=True,
)
@@ -188,7 +189,7 @@ User Prompt:'''
masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
(min(self.max_length, emb.shape[1] - 217)),
dtype=torch.long,
device=torch.cuda.current_device(),
device=get_torch_device().current_device(),
)
return embs, masks

View File

@@ -5,6 +5,8 @@ import math
from typing import Tuple, Optional
from einops import rearrange
from .wan_video_camera_controller import SimpleAdapter
from ..core.gradient import gradient_checkpoint_forward
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
@@ -92,6 +94,7 @@ def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
x.shape[0], x.shape[1], x.shape[2], -1, 2))
freqs = freqs.to(torch.complex64) if freqs.device.type == "npu" else freqs
x_out = torch.view_as_real(x_out * freqs).flatten(2)
return x_out.to(x.dtype)
@@ -377,27 +380,15 @@ class WanModel(torch.nn.Module):
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.blocks:
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
if self.training:
x = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x, context, t_mod, freqs
)
else:
x = block(x, context, t_mod, freqs)

View File

@@ -4,6 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
from ..core.gradient import gradient_checkpoint_forward
def torch_dfs(model: nn.Module, parent_name='root'):
@@ -545,46 +546,19 @@ class WanS2VModel(torch.nn.Module):
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block_id, block in enumerate(self.blocks):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
context,
t_mod,
seq_len_x,
pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
context,
t_mod,
seq_len_x,
pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
x = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x, context, t_mod, seq_len_x, pre_compute_freqs[0]
)
x = gradient_checkpoint_forward(
lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x
)
x = x[:, :seq_len_x]
x = self.head(x, t[:-1])

View File

@@ -1,6 +1,6 @@
import torch
from .wan_video_dit import DiTBlock
from ..core.gradient import gradient_checkpoint_forward
class VaceWanAttentionBlock(DiTBlock):
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
@@ -62,26 +62,13 @@ class VaceWanModel(torch.nn.Module):
dim=1) for u in c
])
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.vace_blocks:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
c = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
c, x, context, t_mod, freqs,
use_reentrant=False,
)
elif use_gradient_checkpointing:
c = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
c, x, context, t_mod, freqs,
use_reentrant=False,
)
else:
c = block(c, x, context, t_mod, freqs)
c = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
c, x, context, t_mod, freqs
)
hints = torch.unbind(c)[:-1]
return hints

View File

@@ -171,7 +171,7 @@ class Resample(nn.Module):
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
return x, feat_cache, feat_idx
def init_weight(self, conv):
conv_weight = conv.weight
@@ -298,7 +298,7 @@ class ResidualBlock(nn.Module):
feat_idx[0] += 1
else:
x = layer(x)
return x + h
return x + h, feat_cache, feat_idx
class AttentionBlock(nn.Module):
@@ -469,9 +469,9 @@ class Down_ResidualBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for module in self.downsamples:
x = module(x, feat_cache, feat_idx)
x, feat_cache, feat_idx = module(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy)
return x + self.avg_shortcut(x_copy), feat_cache, feat_idx
class Up_ResidualBlock(nn.Module):
@@ -506,12 +506,12 @@ class Up_ResidualBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
x_main = x.clone()
for module in self.upsamples:
x_main = module(x_main, feat_cache, feat_idx)
x_main, feat_cache, feat_idx = module(x_main, feat_cache, feat_idx)
if self.avg_shortcut is not None:
x_shortcut = self.avg_shortcut(x, first_chunk)
return x_main + x_shortcut
return x_main + x_shortcut, feat_cache, feat_idx
else:
return x_main
return x_main, feat_cache, feat_idx
class Encoder3d(nn.Module):
@@ -586,14 +586,14 @@ class Encoder3d(nn.Module):
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
@@ -614,7 +614,7 @@ class Encoder3d(nn.Module):
feat_idx[0] += 1
else:
x = layer(x)
return x
return x, feat_cache, feat_idx
class Encoder3d_38(nn.Module):
@@ -698,14 +698,14 @@ class Encoder3d_38(nn.Module):
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
@@ -730,7 +730,7 @@ class Encoder3d_38(nn.Module):
else:
x = layer(x)
return x
return x, feat_cache, feat_idx
class Decoder3d(nn.Module):
@@ -807,14 +807,14 @@ class Decoder3d(nn.Module):
## middle
for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
@@ -835,7 +835,7 @@ class Decoder3d(nn.Module):
feat_idx[0] += 1
else:
x = layer(x)
return x
return x, feat_cache, feat_idx
@@ -906,14 +906,14 @@ class Decoder3d_38(nn.Module):
for layer in self.middle:
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, first_chunk)
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx, first_chunk)
else:
x = layer(x)
@@ -937,7 +937,7 @@ class Decoder3d_38(nn.Module):
feat_idx[0] += 1
else:
x = layer(x)
return x
return x, feat_cache, feat_idx
def count_conv3d(model):
@@ -990,11 +990,11 @@ class VideoVAE_(nn.Module):
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(x[:, :, :1, :, :],
out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
@@ -1023,11 +1023,11 @@ class VideoVAE_(nn.Module):
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i:i + 1, :, :],
out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(x[:, :, i:i + 1, :, :],
out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2) # may add tensor offload
@@ -1303,11 +1303,11 @@ class VideoVAE38_(VideoVAE_):
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(x[:, :, :1, :, :],
out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
@@ -1337,12 +1337,12 @@ class VideoVAE38_(VideoVAE_):
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i:i + 1, :, :],
out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
first_chunk=True)
else:
out_ = self.decoder(x[:, :, i:i + 1, :, :],
out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)

View File

@@ -6,8 +6,9 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.nn import RMSNorm
from .general_modules import RMSNorm
from ..core.attention import attention_forward
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
from ..core.gradient import gradient_checkpoint_forward
@@ -39,7 +40,7 @@ class TimestepEmbedder(nn.Module):
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
with torch.amp.autocast("cuda", enabled=False):
with torch.amp.autocast(get_device_type(), enabled=False):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
@@ -87,6 +88,14 @@ class Attention(torch.nn.Module):
self.norm_q = RMSNorm(head_dim, eps=1e-5)
self.norm_k = RMSNorm(head_dim, eps=1e-5)
# Apply RoPE
def apply_rotary_emb(self, x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast(get_device_type(), enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in) # todo
def forward(self, hidden_states, freqs_cis, attention_mask):
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
@@ -102,17 +111,9 @@ class Attention(torch.nn.Module):
if self.norm_k is not None:
key = self.norm_k(key)
# Apply RoPE
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast("cuda", enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in) # todo
if freqs_cis is not None:
query = apply_rotary_emb(query, freqs_cis)
key = apply_rotary_emb(key, freqs_cis)
query = self.apply_rotary_emb(query, freqs_cis)
key = self.apply_rotary_emb(key, freqs_cis)
# Cast to correct dtype
dtype = query.dtype
@@ -315,7 +316,10 @@ class RopeEmbedder:
result = []
for i in range(len(self.axes_dims)):
index = ids[:, i]
result.append(self.freqs_cis[i][index])
if IS_NPU_AVAILABLE:
result.append(torch.index_select(self.freqs_cis[i], 0, index))
else:
result.append(self.freqs_cis[i][index])
return torch.cat(result, dim=-1)

View File

@@ -3,38 +3,101 @@ import torch
class ZImageTextEncoder(torch.nn.Module):
def __init__(self):
def __init__(self, model_size="4B"):
super().__init__()
config = Qwen3Config(**{
"architectures": [
"Qwen3ForCausalLM"
],
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2560,
"initializer_range": 0.02,
"intermediate_size": 9728,
"max_position_embeddings": 40960,
"max_window_layers": 36,
"model_type": "qwen3",
"num_attention_heads": 32,
"num_hidden_layers": 36,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 1000000,
"sliding_window": None,
"tie_word_embeddings": True,
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0",
"use_cache": True,
"use_sliding_window": False,
"vocab_size": 151936
})
config_dict = {
"0.6B": Qwen3Config(**{
"architectures": [
"Qwen3ForCausalLM"
],
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 40960,
"max_window_layers": 28,
"model_type": "qwen3",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 1000000,
"sliding_window": None,
"tie_word_embeddings": True,
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0",
"use_cache": True,
"use_sliding_window": False,
"vocab_size": 151936
}),
"4B": Qwen3Config(**{
"architectures": [
"Qwen3ForCausalLM"
],
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2560,
"initializer_range": 0.02,
"intermediate_size": 9728,
"max_position_embeddings": 40960,
"max_window_layers": 36,
"model_type": "qwen3",
"num_attention_heads": 32,
"num_hidden_layers": 36,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 1000000,
"sliding_window": None,
"tie_word_embeddings": True,
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0",
"use_cache": True,
"use_sliding_window": False,
"vocab_size": 151936
}),
"8B": Qwen3Config(**{
"architectures": [
"Qwen3ForCausalLM"
],
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"dtype": "bfloat16",
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 12288,
"max_position_embeddings": 40960,
"max_window_layers": 36,
"model_type": "qwen3",
"num_attention_heads": 32,
"num_hidden_layers": 36,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 1000000,
"sliding_window": None,
"tie_word_embeddings": False,
"transformers_version": "4.56.1",
"use_cache": True,
"use_sliding_window": False,
"vocab_size": 151936
})
}
config = config_dict[model_size]
self.model = Qwen3Model(config)
def forward(self, *args, **kwargs):

View File

@@ -0,0 +1,263 @@
import torch, math
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange
import numpy as np
from math import prod
from transformers import AutoTokenizer
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from ..utils.lora.merge import merge_lora
from ..models.anima_dit import AnimaDiT
from ..models.z_image_text_encoder import ZImageTextEncoder
from ..models.wan_video_vae import WanVideoVAE
class AnimaImagePipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("Z-Image")
self.text_encoder: ZImageTextEncoder = None
self.dit: AnimaDiT = None
self.vae: WanVideoVAE = None
self.tokenizer: AutoTokenizer = None
self.tokenizer_t5xxl: AutoTokenizer = None
self.in_iteration_models = ("dit",)
self.units = [
AnimaUnit_ShapeChecker(),
AnimaUnit_NoiseInitializer(),
AnimaUnit_InputImageEmbedder(),
AnimaUnit_PromptEmbedder(),
]
self.model_fn = model_fn_anima
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
tokenizer_t5xxl_config: ModelConfig = ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
vram_limit: float = None,
):
# Initialize pipeline
pipe = AnimaImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
pipe.dit = model_pool.fetch_model("anima_dit")
pipe.vae = model_pool.fetch_model("wan_video_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
if tokenizer_t5xxl_config is not None:
tokenizer_t5xxl_config.download_if_necessary()
pipe.tokenizer_t5xxl = AutoTokenizer.from_pretrained(tokenizer_t5xxl_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 4.0,
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 30,
sigma_shift: float = None,
# Progress bar
progress_bar_cmd = tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
# Parameters
inputs_posi = {
"prompt": prompt,
}
inputs_nega = {
"negative_prompt": negative_prompt,
}
inputs_shared = {
"cfg_scale": cfg_scale,
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"].unsqueeze(2), device=self.device).squeeze(2)
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class AnimaUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: AnimaImagePipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class AnimaUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: AnimaImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class AnimaUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: AnimaImagePipeline, input_image, noise):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
if isinstance(input_image, list):
input_latents = []
for image in input_image:
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents.append(pipe.vae.encode(image))
input_latents = torch.concat(input_latents, dim=0)
else:
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents = pipe.vae.encode(image.unsqueeze(2), device=pipe.device).squeeze(2)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": input_latents}
class AnimaUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_emb",),
onload_model_names=("text_encoder",)
)
def encode_prompt(
self,
pipe: AnimaImagePipeline,
prompt,
device = None,
max_sequence_length: int = 512,
):
if isinstance(prompt, str):
prompt = [prompt]
text_inputs = pipe.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_masks = text_inputs.attention_mask.to(device).bool()
prompt_embeds = pipe.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-1]
t5xxl_text_inputs = pipe.tokenizer_t5xxl(
prompt,
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
t5xxl_ids = t5xxl_text_inputs.input_ids.to(device)
return prompt_embeds.to(pipe.torch_dtype), t5xxl_ids
def process(self, pipe: AnimaImagePipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
prompt_embeds, t5xxl_ids = self.encode_prompt(pipe, prompt, pipe.device)
return {"prompt_emb": prompt_embeds, "t5xxl_ids": t5xxl_ids}
def model_fn_anima(
dit: AnimaDiT = None,
latents=None,
timestep=None,
prompt_emb=None,
t5xxl_ids=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs
):
latents = latents.unsqueeze(2)
timestep = timestep / 1000
model_output = dit(
x=latents,
timesteps=timestep,
context=prompt_emb,
t5xxl_ids=t5xxl_ids,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
model_output = model_output.squeeze(2)
return model_output

View File

@@ -1,4 +1,4 @@
import torch, math
import torch, math, torchvision
from PIL import Image
from typing import Union
from tqdm import tqdm
@@ -6,25 +6,28 @@ from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from transformers import AutoProcessor
from transformers import AutoProcessor, AutoTokenizer
from ..models.flux2_text_encoder import Flux2TextEncoder
from ..models.flux2_dit import Flux2DiT
from ..models.flux2_vae import Flux2VAE
from ..models.z_image_text_encoder import ZImageTextEncoder
class Flux2ImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("FLUX.2")
self.text_encoder: Flux2TextEncoder = None
self.text_encoder_qwen3: ZImageTextEncoder = None
self.dit: Flux2DiT = None
self.vae: Flux2VAE = None
self.tokenizer: AutoProcessor = None
@@ -32,8 +35,10 @@ class Flux2ImagePipeline(BasePipeline):
self.units = [
Flux2Unit_ShapeChecker(),
Flux2Unit_PromptEmbedder(),
Flux2Unit_Qwen3PromptEmbedder(),
Flux2Unit_NoiseInitializer(),
Flux2Unit_InputImageEmbedder(),
Flux2Unit_EditImageEmbedder(),
Flux2Unit_ImageIDs(),
]
self.model_fn = model_fn_flux2
@@ -42,7 +47,7 @@ class Flux2ImagePipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
vram_limit: float = None,
@@ -53,11 +58,12 @@ class Flux2ImagePipeline(BasePipeline):
# Fetch models
pipe.text_encoder = model_pool.fetch_model("flux2_text_encoder")
pipe.text_encoder_qwen3 = model_pool.fetch_model("z_image_text_encoder")
pipe.dit = model_pool.fetch_model("flux2_dit")
pipe.vae = model_pool.fetch_model("flux2_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoProcessor.from_pretrained(tokenizer_config.path)
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
@@ -75,6 +81,9 @@ class Flux2ImagePipeline(BasePipeline):
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Edit
edit_image: Union[Image.Image, List[Image.Image]] = None,
edit_image_auto_resize: bool = True,
# Shape
height: int = 1024,
width: int = 1024,
@@ -98,6 +107,7 @@ class Flux2ImagePipeline(BasePipeline):
inputs_shared = {
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
"input_image": input_image, "denoising_strength": denoising_strength,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
@@ -275,6 +285,10 @@ class Flux2Unit_PromptEmbedder(PipelineUnit):
return prompt_embeds, text_ids
def process(self, pipe: Flux2ImagePipeline, prompt):
# Skip if Qwen3 text encoder is available (handled by Qwen3PromptEmbedder)
if pipe.text_encoder_qwen3 is not None:
return {}
pipe.load_models_to_device(self.onload_model_names)
prompt_embeds, text_ids = self.encode_prompt(
pipe.text_encoder, pipe.tokenizer, prompt,
@@ -283,6 +297,135 @@ class Flux2Unit_PromptEmbedder(PipelineUnit):
return {"prompt_embeds": prompt_embeds, "text_ids": text_ids}
class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_emb", "prompt_emb_mask"),
onload_model_names=("text_encoder_qwen3",)
)
self.hidden_states_layers = (9, 18, 27) # Qwen3 layers
def get_qwen3_prompt_embeds(
self,
text_encoder: ZImageTextEncoder,
tokenizer: AutoTokenizer,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
all_input_ids = []
all_attention_masks = []
for single_prompt in prompt:
messages = [{"role": "user", "content": single_prompt}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
all_input_ids.append(inputs["input_ids"])
all_attention_masks.append(inputs["attention_mask"])
input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
# Forward pass through the model
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
def prepare_text_ids(
self,
x: torch.Tensor, # (B, L, D) or (L, D)
t_coord: Optional[torch.Tensor] = None,
):
B, L, _ = x.shape
out_ids = []
for i in range(B):
t = torch.arange(1) if t_coord is None else t_coord[i]
h = torch.arange(1)
w = torch.arange(1)
l = torch.arange(L)
coords = torch.cartesian_prod(t, h, w, l)
out_ids.append(coords)
return torch.stack(out_ids)
def encode_prompt(
self,
text_encoder: ZImageTextEncoder,
tokenizer: AutoTokenizer,
prompt: Union[str, List[str]],
dtype = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_embeds = self.get_qwen3_prompt_embeds(
text_encoder=text_encoder,
tokenizer=tokenizer,
prompt=prompt,
dtype=dtype,
device=device,
max_sequence_length=max_sequence_length,
)
batch_size, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
text_ids = self.prepare_text_ids(prompt_embeds)
text_ids = text_ids.to(device)
return prompt_embeds, text_ids
def process(self, pipe: Flux2ImagePipeline, prompt):
# Check if Qwen3 text encoder is available
if pipe.text_encoder_qwen3 is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
prompt_embeds, text_ids = self.encode_prompt(
pipe.text_encoder_qwen3, pipe.tokenizer, prompt,
dtype=pipe.torch_dtype, device=pipe.device,
)
return {"prompt_embeds": prompt_embeds, "text_ids": text_ids}
class Flux2Unit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
@@ -318,6 +461,75 @@ class Flux2Unit_InputImageEmbedder(PipelineUnit):
return {"latents": latents, "input_latents": input_latents}
class Flux2Unit_EditImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "edit_image_auto_resize"),
output_params=("edit_latents", "edit_image_ids"),
onload_model_names=("vae",)
)
def calculate_dimensions(self, target_area, ratio):
import math
width = math.sqrt(target_area * ratio)
height = width / ratio
width = round(width / 32) * 32
height = round(height / 32) * 32
return width, height
def crop_and_resize(self, image, target_height, target_width):
width, height = image.size
scale = max(target_width / width, target_height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
return image
def edit_image_auto_resize(self, edit_image):
calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
return self.crop_and_resize(edit_image, calculated_height, calculated_width)
def process_image_ids(self, image_latents, scale=10):
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
t_coords = [t.view(-1) for t in t_coords]
image_latent_ids = []
for x, t in zip(image_latents, t_coords):
x = x.squeeze(0)
_, height, width = x.shape
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
image_latent_ids.append(x_ids)
image_latent_ids = torch.cat(image_latent_ids, dim=0)
image_latent_ids = image_latent_ids.unsqueeze(0)
return image_latent_ids
def process(self, pipe: Flux2ImagePipeline, edit_image, edit_image_auto_resize):
if edit_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
if isinstance(edit_image, Image.Image):
edit_image = [edit_image]
resized_edit_image, edit_latents = [], []
for image in edit_image:
# Preprocess
if edit_image_auto_resize is None or edit_image_auto_resize:
image = self.edit_image_auto_resize(image)
resized_edit_image.append(image)
# Encode
image = pipe.preprocess_image(image)
latents = pipe.vae.encode(image)
edit_latents.append(latents)
edit_image_ids = self.process_image_ids(edit_latents).to(pipe.device)
edit_latents = torch.concat([rearrange(latents, "B C H W -> B (H W) C") for latents in edit_latents], dim=1)
return {"edit_latents": edit_latents, "edit_image_ids": edit_image_ids}
class Flux2Unit_ImageIDs(PipelineUnit):
def __init__(self):
super().__init__(
@@ -352,10 +564,17 @@ def model_fn_flux2(
prompt_embeds=None,
text_ids=None,
image_ids=None,
edit_latents=None,
edit_image_ids=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
image_seq_len = latents.shape[1]
if edit_latents is not None:
image_seq_len = latents.shape[1]
latents = torch.concat([latents, edit_latents], dim=1)
image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
model_output = dit(
hidden_states=latents,
@@ -367,4 +586,5 @@ def model_fn_flux2(
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
model_output = model_output[:, :image_seq_len]
return model_output

View File

@@ -6,6 +6,7 @@ from einops import rearrange, repeat
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -55,7 +56,7 @@ class MultiControlNet(torch.nn.Module):
class FluxImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -117,7 +118,7 @@ class FluxImagePipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"),
tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"),
@@ -377,7 +378,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit):
text_encoder_2,
prompt,
positive=True,
device="cuda",
device=get_device_type(),
t5_sequence_length=512,
):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -558,7 +559,7 @@ class FluxImageUnit_EntityControl(PipelineUnit):
text_encoder_2,
prompt,
positive=True,
device="cuda",
device=get_device_type(),
t5_sequence_length=512,
):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -793,7 +794,7 @@ class FluxImageUnit_ValueControl(PipelineUnit):
class InfinitYou(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__()
from facexlib.recognition import init_recognition_model
from insightface.app import FaceAnalysis

View File

@@ -0,0 +1,660 @@
import torch, types
import numpy as np
from PIL import Image
from einops import repeat
from typing import Optional, Union
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional
from transformers import AutoImageProcessor, Gemma3Processor
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
from ..models.ltx2_dit import LTXModel
from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier, AudioProcessor
from ..models.ltx2_upsampler import LTX2LatentUpsampler
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
from ..utils.data.media_io_ltx2 import ltx2_preprocess
class LTX2AudioVideoPipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device,
torch_dtype=torch_dtype,
height_division_factor=32,
width_division_factor=32,
time_division_factor=8,
time_division_remainder=1,
)
self.scheduler = FlowMatchScheduler("LTX-2")
self.text_encoder: LTX2TextEncoder = None
self.tokenizer: LTXVGemmaTokenizer = None
self.processor: Gemma3Processor = None
self.text_encoder_post_modules: LTX2TextEncoderPostModules = None
self.dit: LTXModel = None
self.video_vae_encoder: LTX2VideoEncoder = None
self.video_vae_decoder: LTX2VideoDecoder = None
self.audio_vae_encoder: LTX2AudioEncoder = None
self.audio_vae_decoder: LTX2AudioDecoder = None
self.audio_vocoder: LTX2Vocoder = None
self.upsampler: LTX2LatentUpsampler = None
self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)
self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)
self.audio_processor: AudioProcessor = AudioProcessor()
self.in_iteration_models = ("dit",)
self.units = [
LTX2AudioVideoUnit_PipelineChecker(),
LTX2AudioVideoUnit_ShapeChecker(),
LTX2AudioVideoUnit_PromptEmbedder(),
LTX2AudioVideoUnit_NoiseInitializer(),
LTX2AudioVideoUnit_InputAudioEmbedder(),
LTX2AudioVideoUnit_InputVideoEmbedder(),
LTX2AudioVideoUnit_InputImagesEmbedder(),
LTX2AudioVideoUnit_InContextVideoEmbedder(),
]
self.model_fn = model_fn_ltx2
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
stage2_lora_config: Optional[ModelConfig] = None,
vram_limit: float = None,
):
# Initialize pipeline
pipe = LTX2AudioVideoPipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("ltx2_text_encoder")
tokenizer_config.download_if_necessary()
pipe.tokenizer = LTXVGemmaTokenizer(tokenizer_path=tokenizer_config.path)
image_processor = AutoImageProcessor.from_pretrained(tokenizer_config.path, local_files_only=True)
pipe.processor = Gemma3Processor(image_processor=image_processor, tokenizer=pipe.tokenizer.tokenizer)
pipe.text_encoder_post_modules = model_pool.fetch_model("ltx2_text_encoder_post_modules")
pipe.dit = model_pool.fetch_model("ltx2_dit")
pipe.video_vae_encoder = model_pool.fetch_model("ltx2_video_vae_encoder")
pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder")
pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
# Stage 2
if stage2_lora_config is not None:
stage2_lora_config.download_if_necessary()
pipe.stage2_lora_path = stage2_lora_config.path
# Optional, currently not used
pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
if inputs_shared["use_two_stage_pipeline"]:
if inputs_shared.get("clear_lora_before_state_two", False):
self.clear_lora()
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
self.load_models_to_device('upsampler',)
latent = self.upsampler(latent)
latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
self.scheduler.set_timesteps(special_case="stage2")
inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
denoise_mask_video = 1.0
# input image
if inputs_shared.get("input_images", None) is not None:
latent, denoise_mask_video, initial_latents = self.apply_input_images_to_latents(
latent, inputs_shared.pop("input_latents"), inputs_shared["input_images_indexes"],
inputs_shared["input_images_strength"], latent.clone())
inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video})
# remove in-context video control in stage 2
inputs_shared.pop("in_context_video_latents", None)
inputs_shared.pop("in_context_video_positions", None)
# initialize latents for stage 2
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[
"video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latent
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (
1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
self.load_models_to_device(self.in_iteration_models)
if not inputs_shared["use_distilled_pipeline"]:
self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
self.model_fn, 1.0, inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id,
noise_pred=noise_pred_video, inpaint_mask=inputs_shared.get("denoise_mask_video", None),
input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
noise_pred=noise_pred_audio, **inputs_shared)
return inputs_shared
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: Optional[str] = "",
denoising_strength: float = 1.0,
# Image-to-video
input_images: Optional[list[Image.Image]] = None,
input_images_indexes: Optional[list[int]] = None,
input_images_strength: Optional[float] = 1.0,
# In-Context Video Control
in_context_videos: Optional[list[list[Image.Image]]] = None,
in_context_downsample_factor: Optional[int] = 2,
# Randomness
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
# Shape
height: Optional[int] = 512,
width: Optional[int] = 768,
num_frames=121,
frame_rate=24,
# Classifier-free guidance
cfg_scale: Optional[float] = 3.0,
# Scheduler
num_inference_steps: Optional[int] = 40,
# VAE tiling
tiled: Optional[bool] = True,
tile_size_in_pixels: Optional[int] = 512,
tile_overlap_in_pixels: Optional[int] = 128,
tile_size_in_frames: Optional[int] = 128,
tile_overlap_in_frames: Optional[int] = 24,
# Special Pipelines
use_two_stage_pipeline: Optional[bool] = False,
clear_lora_before_state_two: Optional[bool] = False,
use_distilled_pipeline: Optional[bool] = False,
# progress_bar
progress_bar_cmd=tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength,
special_case="ditilled_stage1" if use_distilled_pipeline else None)
# Inputs
inputs_posi = {
"prompt": prompt,
}
inputs_nega = {
"negative_prompt": negative_prompt,
}
inputs_shared = {
"input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
"in_context_videos": in_context_videos, "in_context_downsample_factor": in_context_downsample_factor,
"seed": seed, "rand_device": rand_device,
"height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate,
"cfg_scale": cfg_scale,
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, "clear_lora_before_state_two": clear_lora_before_state_two,
"video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise Stage 1
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video,
inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
noise_pred=noise_pred_audio, **inputs_shared)
# Denoise Stage 2
inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd)
# Decode
self.load_models_to_device(['video_vae_decoder'])
video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels,
tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
video = self.vae_output_to_video(video)
self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
return video, decoded_audio
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None, num_frames=121):
b, _, f, h, w = latents.shape
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device)
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
for idx, input_latent in zip(input_indexes, input_latents):
idx = min(max(1 + (idx-1) // 8, 0), f - 1)
input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
latents = latents * denoise_mask + initial_latents * (1.0 - denoise_mask)
return latents, denoise_mask, initial_latents
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
input_params=("use_distilled_pipeline", "use_two_stage_pipeline"),
output_params=("use_two_stage_pipeline", "cfg_scale")
)
def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
if inputs_shared.get("use_distilled_pipeline", False):
inputs_shared["use_two_stage_pipeline"] = True
inputs_shared["cfg_scale"] = 1.0
print(f"Distilled pipeline requested, setting use_two_stage_pipeline to True, disable CFG by setting cfg_scale to 1.0.")
if inputs_shared.get("use_two_stage_pipeline", False):
# distill pipeline also uses two-stage, but it does not needs lora
if not inputs_shared.get("use_distilled_pipeline", False):
if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
return inputs_shared, inputs_posi, inputs_nega
class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
"""
For two-stage pipelines, the resolution must be divisible by 64.
For one-stage pipelines, the resolution must be divisible by 32.
"""
def __init__(self):
super().__init__(
input_params=("height", "width", "num_frames"),
output_params=("height", "width", "num_frames"),
)
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False):
if use_two_stage_pipeline:
self.width_division_factor = 64
self.height_division_factor = 64
height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
if use_two_stage_pipeline:
self.width_division_factor = 32
self.height_division_factor = 32
return {"height": height, "width": width, "num_frames": num_frames}
class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("video_context", "audio_context"),
onload_model_names=("text_encoder", "text_encoder_post_modules"),
)
def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return (attention_mask - 1).to(dtype).reshape(
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
def _run_connectors(self, pipe, encoded_input: torch.Tensor,
attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype)
encoded, encoded_connector_attention_mask = pipe.text_encoder_post_modules.embeddings_connector(
encoded_input,
connector_attention_mask,
)
# restore the mask values to int64
attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64)
attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
encoded = encoded * attention_mask
encoded_for_audio, _ = pipe.text_encoder_post_modules.audio_embeddings_connector(
encoded_input, connector_attention_mask)
return encoded, encoded_for_audio, attention_mask.squeeze(-1)
def _norm_and_concat_padded_batch(
self,
encoded_text: torch.Tensor,
sequence_lengths: torch.Tensor,
padding_side: str = "right",
) -> torch.Tensor:
"""Normalize and flatten multi-layer hidden states, respecting padding.
Performs per-batch, per-layer normalization using masked mean and range,
then concatenates across the layer dimension.
Args:
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
sequence_lengths: Number of valid (non-padded) tokens per batch item.
padding_side: Whether padding is on "left" or "right".
Returns:
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
with padded positions zeroed out.
"""
b, t, d, l = encoded_text.shape # noqa: E741
device = encoded_text.device
# Build mask: [B, T, 1, 1]
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
if padding_side == "right":
# For right padding, valid tokens are from 0 to sequence_length-1
mask = token_indices < sequence_lengths[:, None] # [B, T]
elif padding_side == "left":
# For left padding, valid tokens are from (T - sequence_length) to T-1
start_indices = t - sequence_lengths[:, None] # [B, 1]
mask = token_indices >= start_indices # [B, T]
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
mask = rearrange(mask, "b t -> b t 1 1")
eps = 1e-6
# Compute masked mean: [B, 1, 1, L]
masked = encoded_text.masked_fill(~mask, 0.0)
denom = (sequence_lengths * d).view(b, 1, 1, 1)
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
# Compute masked min/max: [B, 1, 1, L]
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
range_ = x_max - x_min
# Normalize only the valid tokens
normed = 8 * (encoded_text - mean) / (range_ + eps)
# concat to be [Batch, T, D * L] - this preserves the original structure
normed = normed.reshape(b, t, -1) # [B, T, D * L]
# Apply mask to preserve original padding (set padded positions to 0)
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
normed = normed.masked_fill(~mask_flattened, 0.0)
return normed
def _run_feature_extractor(self,
pipe,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
padding_side: str = "right") -> torch.Tensor:
encoded_text_features = torch.stack(hidden_states, dim=-1)
encoded_text_features_dtype = encoded_text_features.dtype
sequence_lengths = attention_mask.sum(dim=-1)
normed_concated_encoded_text_features = self._norm_and_concat_padded_batch(encoded_text_features,
sequence_lengths,
padding_side=padding_side)
return pipe.text_encoder_post_modules.feature_extractor_linear(
normed_concated_encoded_text_features.to(encoded_text_features_dtype))
def _preprocess_text(
self,
pipe,
text: str,
padding_side: str = "left",
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""
Encode a given string into feature tensors suitable for downstream tasks.
Args:
text (str): Input string to encode.
Returns:
tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask.
"""
token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device)
attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device)
outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
projected = self._run_feature_extractor(pipe,
hidden_states=outputs.hidden_states,
attention_mask=attention_mask,
padding_side=padding_side)
return projected, attention_mask
def encode_prompt(self, pipe, text, padding_side="left"):
encoded_inputs, attention_mask = self._preprocess_text(pipe, text, padding_side)
video_encoding, audio_encoding, attention_mask = self._run_connectors(pipe, encoded_inputs, attention_mask)
return video_encoding, audio_encoding, attention_mask
def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
pipe.load_models_to_device(self.onload_model_names)
video_context, audio_context, _ = self.encode_prompt(pipe, prompt)
return {"video_context": video_context, "audio_context": audio_context}
class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate", "use_two_stage_pipeline"),
output_params=("video_noise", "audio_noise", "video_positions", "audio_positions", "video_latent_shape", "audio_latent_shape")
)
def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=128)
video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
video_positions = video_positions.to(pipe.torch_dtype)
audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape)
audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
return {
"video_noise": video_noise,
"audio_noise": audio_noise,
"video_positions": video_positions,
"audio_positions": audio_positions,
"video_latent_shape": video_latent_shape,
"audio_latent_shape": audio_latent_shape
}
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0, use_two_stage_pipeline=False):
if use_two_stage_pipeline:
stage1_dict = self.process_stage(pipe, height // 2, width // 2, num_frames, seed, rand_device, frame_rate)
stage2_dict = self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
initial_dict = stage1_dict
initial_dict.update({"stage2_" + k: v for k, v in stage2_dict.items()})
return initial_dict
else:
return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_video", "video_noise", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels"),
output_params=("video_latents", "input_latents"),
onload_model_names=("video_vae_encoder")
)
def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
if input_video is None:
return {"video_latents": video_noise}
else:
pipe.load_models_to_device(self.onload_model_names)
input_video = pipe.preprocess_video(input_video)
input_latents = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
if pipe.scheduler.training:
return {"video_latents": input_latents, "input_latents": input_latents}
else:
raise NotImplementedError("Video-to-video not implemented yet.")
class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_audio", "audio_noise"),
output_params=("audio_latents", "audio_input_latents", "audio_positions", "audio_latent_shape"),
onload_model_names=("audio_vae_encoder",)
)
def process(self, pipe: LTX2AudioVideoPipeline, input_audio, audio_noise):
if input_audio is None:
return {"audio_latents": audio_noise}
else:
input_audio, sample_rate = input_audio
pipe.load_models_to_device(self.onload_model_names)
input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype)
audio_input_latents = pipe.audio_vae_encoder(input_audio)
audio_latent_shape = AudioLatentShape.from_torch_shape(audio_input_latents.shape)
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
if pipe.scheduler.training:
return {"audio_latents": audio_input_latents, "audio_input_latents": audio_input_latents, "audio_positions": audio_positions, "audio_latent_shape": audio_latent_shape}
else:
raise NotImplementedError("Audio-to-video not supported.")
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "num_frames", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
output_params=("video_latents", "denoise_mask_video", "input_latents_video", "stage2_input_latents"),
onload_model_names=("video_vae_encoder")
)
def get_image_latent(self, pipe, input_image, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
image = ltx2_preprocess(np.array(input_image.resize((width, height))))
image = torch.Tensor(np.array(image, dtype=np.float32)).to(dtype=pipe.torch_dtype, device=pipe.device)
image = image / 127.5 - 1.0
image = repeat(image, f"H W C -> B C F H W", B=1, F=1)
latent = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
return latent
def process(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, video_latents, height, width, num_frames, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=False):
if input_images is None or len(input_images) == 0:
return {"video_latents": video_latents}
else:
pipe.load_models_to_device(self.onload_model_names)
output_dicts = {}
stage1_height = height // 2 if use_two_stage_pipeline else height
stage1_width = width // 2 if use_two_stage_pipeline else width
stage1_latents = [
self.get_image_latent(pipe, img, stage1_height, stage1_width, tiled, tile_size_in_pixels,
tile_overlap_in_pixels) for img in input_images
]
video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength, num_frames=num_frames)
output_dicts.update({"video_latents": video_latents, "denoise_mask_video": denoise_mask_video, "input_latents_video": initial_latents})
if use_two_stage_pipeline:
stage2_latents = [
self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels,
tile_overlap_in_pixels) for img in input_images
]
output_dicts.update({"stage2_input_latents": stage2_latents})
return output_dicts
class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("in_context_videos", "height", "width", "num_frames", "frame_rate", "in_context_downsample_factor", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
output_params=("in_context_video_latents", "in_context_video_positions"),
onload_model_names=("video_vae_encoder")
)
def check_in_context_video(self, pipe, in_context_video, height, width, num_frames, in_context_downsample_factor, use_two_stage_pipeline=True):
if in_context_video is None or len(in_context_video) == 0:
raise ValueError("In-context video is None or empty.")
in_context_video = in_context_video[:num_frames]
expected_height = height // in_context_downsample_factor // 2 if use_two_stage_pipeline else height // in_context_downsample_factor
expected_width = width // in_context_downsample_factor // 2 if use_two_stage_pipeline else width // in_context_downsample_factor
current_h, current_w, current_f = in_context_video[0].size[1], in_context_video[0].size[0], len(in_context_video)
h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f, verbose=0)
if current_h != h or current_w != w:
in_context_video = [img.resize((w, h)) for img in in_context_video]
if current_f != f:
# pad black frames at the end
in_context_video = in_context_video + [Image.new("RGB", (w, h), (0, 0, 0))] * (f - current_f)
return in_context_video
def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height, width, num_frames, frame_rate, in_context_downsample_factor, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=True):
if in_context_videos is None or len(in_context_videos) == 0:
return {}
else:
pipe.load_models_to_device(self.onload_model_names)
latents, positions = [], []
for in_context_video in in_context_videos:
in_context_video = self.check_in_context_video(pipe, in_context_video, height, width, num_frames, in_context_downsample_factor, use_two_stage_pipeline)
in_context_video = pipe.preprocess_video(in_context_video)
in_context_latents = pipe.video_vae_encoder.encode(in_context_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(in_context_latents.shape), device=pipe.device)
video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
video_positions[:, 1, ...] *= in_context_downsample_factor # height axis
video_positions[:, 2, ...] *= in_context_downsample_factor # width axis
video_positions = video_positions.to(pipe.torch_dtype)
latents.append(in_context_latents)
positions.append(video_positions)
latents = torch.cat(latents, dim=1)
positions = torch.cat(positions, dim=1)
return {"in_context_video_latents": latents, "in_context_video_positions": positions}
def model_fn_ltx2(
dit: LTXModel,
video_latents=None,
video_context=None,
video_positions=None,
video_patchifier=None,
audio_latents=None,
audio_context=None,
audio_positions=None,
audio_patchifier=None,
timestep=None,
denoise_mask_video=None,
in_context_video_latents=None,
in_context_video_positions=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
timestep = timestep.float() / 1000.
# patchify
b, c_v, f, h, w = video_latents.shape
video_latents = video_patchifier.patchify(video_latents)
seq_len_video = video_latents.shape[1]
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
if denoise_mask_video is not None:
video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps
if in_context_video_latents is not None:
in_context_video_latents = video_patchifier.patchify(in_context_video_latents)
in_context_video_timesteps = timestep.repeat(1, in_context_video_latents.shape[1], 1) * 0.
video_latents = torch.cat([video_latents, in_context_video_latents], dim=1)
video_positions = torch.cat([video_positions, in_context_video_positions], dim=2)
video_timesteps = torch.cat([video_timesteps, in_context_video_timesteps], dim=1)
if audio_latents is not None:
_, c_a, _, mel_bins = audio_latents.shape
audio_latents = audio_patchifier.patchify(audio_latents)
audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
else:
audio_timesteps = None
vx, ax = dit(
video_latents=video_latents,
video_positions=video_positions,
video_context=video_context,
video_timesteps=video_timesteps,
audio_latents=audio_latents,
audio_positions=audio_positions,
audio_context=audio_context,
audio_timesteps=audio_timesteps,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
vx = vx[:, :seq_len_video, ...]
# unpatchify
vx = video_patchifier.unpatchify_video(vx, f, h, w)
ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins) if ax is not None else None
return vx, ax

View File

@@ -6,6 +6,7 @@ from einops import rearrange
import numpy as np
from math import prod
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -22,7 +23,7 @@ from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel
class QwenImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -60,7 +61,7 @@ class QwenImagePipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
processor_config: ModelConfig = None,

View File

@@ -11,6 +11,7 @@ from typing import Optional
from typing_extensions import Literal
from transformers import Wav2Vec2Processor
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
@@ -30,7 +31,7 @@ from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
class WanVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
@@ -98,7 +99,7 @@ class WanVideoPipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
audio_processor_config: ModelConfig = None,
@@ -122,11 +123,15 @@ class WanVideoPipeline(BasePipeline):
model_config.model_id = redirect_dict[model_config.origin_file_pattern][0]
model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1]
# Initialize pipeline
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
if use_usp:
from ..utils.xfuser import initialize_usp
initialize_usp(device)
import torch.distributed as dist
from ..core.device.npu_compatible_device import get_device_name
if dist.is_available() and dist.is_initialized():
device = get_device_name()
# Initialize pipeline
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
@@ -960,7 +965,7 @@ class WanVideoUnit_AnimateInpaint(PipelineUnit):
onload_model_names=("vae",)
)
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()):
if mask_pixel_values is None:
msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
else:
@@ -1316,11 +1321,6 @@ def model_fn_wan_video(
if tea_cache_update:
x = tea_cache.update(x)
else:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
def create_custom_forward_vap(block, vap):
def custom_forward(*inputs):
return vap(block, *inputs)
@@ -1334,32 +1334,24 @@ def model_fn_wan_video(
x, x_vap = torch.utils.checkpoint.checkpoint(
create_custom_forward_vap(block, vap),
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
use_reentrant=False,
use_reentrant=False
)
elif use_gradient_checkpointing:
x, x_vap = torch.utils.checkpoint.checkpoint(
create_custom_forward_vap(block, vap),
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
use_reentrant=False,
use_reentrant=False
)
else:
x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)
else:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
x = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x, context, t_mod, freqs
)
# VACE
if vace_context is not None and block_id in vace.vace_layers_mapping:
@@ -1482,32 +1474,18 @@ def model_fn_wans2v(
return custom_forward
for block_id, block in enumerate(dit.blocks):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
use_reentrant=False,
x = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x, context, t_mod, seq_len_x, pre_compute_freqs[0]
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel)
x = gradient_checkpoint_forward(
lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x
)
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)

View File

@@ -1,4 +1,4 @@
import torch, math
import torch, math, warnings
from PIL import Image
from typing import Union
from tqdm import tqdm
@@ -6,6 +6,7 @@ from einops import rearrange
import numpy as np
from typing import Union, List, Optional, Tuple, Iterable, Dict
from ..core.device.npu_compatible_device import get_device_type, IS_NPU_AVAILABLE
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.data.operators import ImageCropAndResize
@@ -25,7 +26,7 @@ from ..models.z_image_image2lora import ZImageImage2LoRAModel
class ZImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -58,10 +59,11 @@ class ZImagePipeline(BasePipeline):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
vram_limit: float = None,
enable_npu_patch: bool = True,
):
# Initialize pipeline
pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype)
@@ -83,6 +85,8 @@ class ZImagePipeline(BasePipeline):
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
# NPU patch
apply_npu_patch(enable_npu_patch)
return pipe
@@ -295,7 +299,7 @@ class ZImageUnit_PromptEmbedder(PipelineUnit):
def process(self, pipe: ZImagePipeline, prompt, edit_image):
pipe.load_models_to_device(self.onload_model_names)
if hasattr(pipe, "dit") and pipe.dit.siglip_embedder is not None:
if hasattr(pipe, "dit") and pipe.dit is not None and pipe.dit.siglip_embedder is not None:
# Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods.
# We determine which encoding method to use based on the model architecture.
# If you are using two-stage split training,
@@ -666,3 +670,19 @@ def model_fn_z_image_turbo(
x = rearrange(x, "C B H W -> B C H W")
x = -x
return x
def apply_npu_patch(enable_npu_patch: bool=True):
if IS_NPU_AVAILABLE and enable_npu_patch:
from ..models.general_modules import RMSNorm
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
from ..models.z_image_dit import Attention
from ..core.npu_patch.npu_fused_operator import (
rms_norm_forward_npu,
rms_norm_forward_transformers_npu,
rotary_emb_Zimage_npu
)
warnings.warn("Replacing RMSNorm and Rope with NPU fusion operators to improve the performance of the model on NPU.Set enable_npu_patch=False to disable this feature.")
RMSNorm.forward = rms_norm_forward_npu
Qwen3RMSNorm.forward = rms_norm_forward_transformers_npu
Attention.apply_rotary_emb = rotary_emb_Zimage_npu

View File

@@ -1,12 +1,13 @@
from typing_extensions import Literal, TypeAlias
from diffsynth.core.device.npu_compatible_device import get_device_type
Processor_id: TypeAlias = Literal[
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
]
class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device=get_device_type(), skip_processor=False):
if not skip_processor:
if processor_id == "canny":
from controlnet_aux.processor import CannyDetector

View File

@@ -116,7 +116,7 @@ class VideoData:
if self.height is not None and self.width is not None:
return self.height, self.width
else:
height, width, _ = self.__getitem__(0).shape
width, height = self.__getitem__(0).size
return height, width
def __getitem__(self, item):

View File

@@ -0,0 +1,149 @@
from fractions import Fraction
import torch
import av
from tqdm import tqdm
from PIL import Image
import numpy as np
from io import BytesIO
from collections.abc import Generator, Iterator
def _resample_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
cc = audio_stream.codec_context
# Use the encoder's format/layout/rate as the *target*
target_format = cc.format or "fltp" # AAC → usually fltp
target_layout = cc.layout or "stereo"
target_rate = cc.sample_rate or frame_in.sample_rate
audio_resampler = av.audio.resampler.AudioResampler(
format=target_format,
layout=target_layout,
rate=target_rate,
)
audio_next_pts = 0
for rframe in audio_resampler.resample(frame_in):
if rframe.pts is None:
rframe.pts = audio_next_pts
audio_next_pts += rframe.samples
rframe.sample_rate = frame_in.sample_rate
container.mux(audio_stream.encode(rframe))
# flush audio encoder
for packet in audio_stream.encode():
container.mux(packet)
def _write_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
) -> None:
if samples.ndim == 1:
samples = samples[:, None]
if samples.shape[1] != 2 and samples.shape[0] == 2:
samples = samples.T
if samples.shape[1] != 2:
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
if samples.dtype != torch.int16:
samples = torch.clip(samples, -1.0, 1.0)
samples = (samples * 32767.0).to(torch.int16)
frame_in = av.AudioFrame.from_ndarray(
samples.contiguous().reshape(1, -1).cpu().numpy(),
format="s16",
layout="stereo",
)
frame_in.sample_rate = audio_sample_rate
_resample_audio(container, audio_stream, frame_in)
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
"""
Prepare the audio stream for writing.
"""
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
audio_stream.codec_context.sample_rate = audio_sample_rate
audio_stream.codec_context.layout = "stereo"
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
return audio_stream
def write_video_audio_ltx2(
video: list[Image.Image],
audio: torch.Tensor | None,
output_path: str,
fps: int = 24,
audio_sample_rate: int | None = 24000,
) -> None:
width, height = video[0].size
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
if audio is not None:
if audio_sample_rate is None:
raise ValueError("audio_sample_rate is required when audio is provided")
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
for frame in tqdm(video, total=len(video)):
frame = av.VideoFrame.from_image(frame)
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():
container.mux(packet)
if audio is not None:
_write_audio(container, audio_stream, audio, audio_sample_rate)
container.close()
def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None:
container = av.open(output_file, "w", format="mp4")
try:
stream = container.add_stream("libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"})
# Round to nearest multiple of 2 for compatibility with video codecs
height = image_array.shape[0] // 2 * 2
width = image_array.shape[1] // 2 * 2
image_array = image_array[:height, :width]
stream.height = height
stream.width = width
av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(format="yuv420p")
container.mux(stream.encode(av_frame))
container.mux(stream.encode())
finally:
container.close()
def decode_single_frame(video_file: str) -> np.array:
container = av.open(video_file)
try:
stream = next(s for s in container.streams if s.type == "video")
frame = next(container.decode(stream))
finally:
container.close()
return frame.to_ndarray(format="rgb24")
def ltx2_preprocess(image: np.array, crf: float = 33) -> np.array:
if crf == 0:
return image
with BytesIO() as output_file:
encode_single_frame(output_file, image, crf)
video_bytes = output_file.getvalue()
with BytesIO(video_bytes) as video_file:
image_array = decode_single_frame(video_file)
return image_array

View File

@@ -149,6 +149,8 @@ class FluxLoRALoader(GeneralLoRALoader):
dtype=state_dict_[name].dtype)
else:
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
mlp = mlp.to(device=state_dict_[name].device)
if 'lora_A' in name:
param = torch.concat([
state_dict_.pop(name),

View File

@@ -1,4 +1,4 @@
import torch
import torch, warnings
class GeneralLoRALoader:
@@ -26,7 +26,11 @@ class GeneralLoRALoader:
keys.pop(0)
keys.pop(-1)
target_name = ".".join(keys)
lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key))
# Alpha: Deprecated but retained for compatibility.
key_alpha = key.replace(lora_B_key + ".weight", "alpha").replace(lora_B_key + ".default.weight", "alpha")
if key_alpha == key or key_alpha not in lora_state_dict:
key_alpha = None
lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key), key_alpha)
return lora_name_dict
@@ -36,6 +40,10 @@ class GeneralLoRALoader:
for name in name_dict:
weight_up = state_dict[name_dict[name][0]]
weight_down = state_dict[name_dict[name][1]]
if name_dict[name][2] is not None:
warnings.warn("Alpha detected in the LoRA file. This may be a LoRA model not trained by DiffSynth-Studio. To ensure compatibility, the LoRA weights will be converted to weight * alpha / rank.")
alpha = state_dict[name_dict[name][2]] / weight_down.shape[0]
weight_down = weight_down * alpha
state_dict_[name + f".lora_B{suffix}"] = weight_up
state_dict_[name + f".lora_A{suffix}"] = weight_down
return state_dict_

View File

@@ -0,0 +1,6 @@
def AnimaDiTStateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:
value = state_dict[key]
new_state_dict[key.replace("net.", "")] = value
return new_state_dict

View File

@@ -89,4 +89,109 @@ def FluxDiTStateDictConverter(state_dict):
state_dict_[rename] = state_dict[original_name]
else:
pass
return state_dict_
def FluxDiTStateDictConverterFromDiffusers(state_dict):
global_rename_dict = {
"context_embedder": "context_embedder",
"x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "final_norm_out.linear",
"proj_out": "final_proj_out",
}
rename_dict = {
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
}
rename_dict_single = {
"attn.to_q": "a_to_q",
"attn.to_k": "a_to_k",
"attn.to_v": "a_to_v",
"attn.norm_q": "norm_q_a",
"attn.norm_k": "norm_k_a",
"norm.linear": "norm.linear",
"proj_mlp": "proj_in_besides_attn",
"proj_out": "proj_out",
}
state_dict_ = {}
for name in state_dict:
param = state_dict[name]
if name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in global_rename_dict:
if global_rename_dict[prefix] == "final_norm_out.linear":
param = torch.concat([param[3072:], param[:3072]], dim=0)
state_dict_[global_rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
elif prefix.startswith("single_transformer_blocks."):
names = prefix.split(".")
names[0] = "single_blocks"
middle = ".".join(names[2:])
if middle in rename_dict_single:
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
state_dict_[name_] = param
else:
pass
else:
pass
for name in list(state_dict_.keys()):
if "single_blocks." in name and ".a_to_q." in name:
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
if mlp is None:
mlp = torch.zeros(4 * state_dict_[name].shape[0],
*state_dict_[name].shape[1:],
dtype=state_dict_[name].dtype)
else:
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
param = torch.concat([
state_dict_.pop(name),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
mlp,
], dim=0)
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
state_dict_[name_] = param
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
return state_dict_

View File

@@ -0,0 +1,32 @@
def LTX2AudioEncoderStateDictConverter(state_dict):
# Not used
state_dict_ = {}
for name in state_dict:
if name.startswith("audio_vae.encoder."):
new_name = name.replace("audio_vae.encoder.", "")
state_dict_[new_name] = state_dict[name]
elif name.startswith("audio_vae.per_channel_statistics."):
new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.")
state_dict_[new_name] = state_dict[name]
return state_dict_
def LTX2AudioDecoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("audio_vae.decoder."):
new_name = name.replace("audio_vae.decoder.", "")
state_dict_[new_name] = state_dict[name]
elif name.startswith("audio_vae.per_channel_statistics."):
new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.")
state_dict_[new_name] = state_dict[name]
return state_dict_
def LTX2VocoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("vocoder."):
new_name = name.replace("vocoder.", "")
state_dict_[new_name] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,9 @@
def LTXModelStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("model.diffusion_model."):
new_name = name.replace("model.diffusion_model.", "")
if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."):
continue
state_dict_[new_name] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,31 @@
def LTX2TextEncoderStateDictConverter(state_dict):
state_dict_ = {}
for key in state_dict:
if key.startswith("language_model.model."):
new_key = key.replace("language_model.model.", "model.language_model.")
elif key.startswith("vision_tower."):
new_key = key.replace("vision_tower.", "model.vision_tower.")
elif key.startswith("multi_modal_projector."):
new_key = key.replace("multi_modal_projector.", "model.multi_modal_projector.")
elif key.startswith("language_model.lm_head."):
new_key = key.replace("language_model.lm_head.", "lm_head.")
else:
continue
state_dict_[new_key] = state_dict[key]
state_dict_["lm_head.weight"] = state_dict_.get("model.language_model.embed_tokens.weight")
return state_dict_
def LTX2TextEncoderPostModulesStateDictConverter(state_dict):
state_dict_ = {}
for key in state_dict:
if key.startswith("text_embedding_projection."):
new_key = key.replace("text_embedding_projection.", "feature_extractor_linear.")
elif key.startswith("model.diffusion_model.video_embeddings_connector."):
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "embeddings_connector.")
elif key.startswith("model.diffusion_model.audio_embeddings_connector."):
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "audio_embeddings_connector.")
else:
continue
state_dict_[new_key] = state_dict[key]
return state_dict_

View File

@@ -0,0 +1,22 @@
def LTX2VideoEncoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("vae.encoder."):
new_name = name.replace("vae.encoder.", "")
state_dict_[new_name] = state_dict[name]
elif name.startswith("vae.per_channel_statistics."):
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
state_dict_[new_name] = state_dict[name]
return state_dict_
def LTX2VideoDecoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("vae.decoder."):
new_name = name.replace("vae.decoder.", "")
state_dict_[new_name] = state_dict[name]
elif name.startswith("vae.per_channel_statistics."):
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
state_dict_[new_name] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,6 @@
def ZImageTextEncoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name != "lm_head.weight":
state_dict_[name] = state_dict[name]
return state_dict_

View File

@@ -1,11 +1,15 @@
import torch
from typing import Optional
from einops import rearrange
from yunchang.kernels import AttnType
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ... import IS_NPU_AVAILABLE
from ...core.device import parse_nccl_backend, parse_device_type
from ...core.gradient import gradient_checkpoint_forward
def initialize_usp(device_type):
@@ -30,13 +34,16 @@ def sinusoidal_embedding_1d(dim, position):
def pad_freqs(original_tensor, target_len):
seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len
original_tensor_device = original_tensor.device
if original_tensor.device == "npu":
original_tensor = original_tensor.cpu()
padding_tensor = torch.ones(
pad_size,
s1,
s2,
dtype=original_tensor.dtype,
device=original_tensor.device)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0).to(device=original_tensor_device)
return padded_tensor
def rope_apply(x, freqs, num_heads):
@@ -50,7 +57,7 @@ def rope_apply(x, freqs, num_heads):
sp_rank = get_sequence_parallel_rank()
freqs = pad_freqs(freqs, s_per_rank * sp_size)
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device.type == "npu" else freqs_rank
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
return x_out.to(x.dtype)
@@ -81,11 +88,6 @@ def usp_dit_forward(self,
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# Context Parallel
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
@@ -94,20 +96,13 @@ def usp_dit_forward(self,
x = chunks[get_sequence_parallel_rank()]
for block in self.blocks:
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
if self.training:
x = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x, context, t_mod, freqs
)
else:
x = block(x, context, t_mod, freqs)
@@ -133,7 +128,12 @@ def usp_attn_forward(self, x, freqs):
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
x = xFuserLongContextAttention()(
attn_type = AttnType.FA
ring_impl_type = "basic"
if IS_NPU_AVAILABLE:
attn_type = AttnType.NPU
ring_impl_type = "basic_npu"
x = xFuserLongContextAttention(attn_type=attn_type, ring_impl_type=ring_impl_type)(
None,
query=q,
key=k,

5
diffsynth/version.py Normal file
View File

@@ -0,0 +1,5 @@
# Make sure to modify __release_datetime__ to release time when making official release.
__version__ = '2.0.0'
# default release datetime for branches under active development is set
# to be a time far-far-away-into-the-future
__release_datetime__ = '2099-10-13 08:56:12'

28
docs/en/.readthedocs.yaml Normal file
View File

@@ -0,0 +1,28 @@
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# Set the OS, Python version and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.10"
# Build documentation in the "docs/" directory with Sphinx
sphinx:
configuration: docs/en/conf.py
# Optionally build your docs in additional formats such as PDF and ePub
# formats:
# - pdf
# - epub
# Optional but recommended, declare the Python requirements required
# to build your documentation
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
install:
- requirements: docs/requirements.txt

View File

@@ -1,6 +1,6 @@
# `diffsynth.core.attention`: Attention Mechanism Implementation
`diffsynth.core.attention` provides routing mechanisms for attention mechanism implementations, automatically selecting efficient attention implementations based on available packages in the `Python` environment and [environment variables](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation).
`diffsynth.core.attention` provides routing mechanisms for attention mechanism implementations, automatically selecting efficient attention implementations based on available packages in the `Python` environment and [environment variables](../../Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation).
## Attention Mechanism
@@ -46,7 +46,7 @@ Note that the dimension of the Attention Score in the attention mechanism ( $\te
* xFormers: [GitHub](https://github.com/facebookresearch/xformers), [Documentation](https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops)
* PyTorch: [GitHub](https://github.com/pytorch/pytorch), [Documentation](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
To call attention implementations other than `PyTorch`, please follow the instructions on their GitHub pages to install the corresponding packages. `DiffSynth-Studio` will automatically route to the corresponding implementation based on available packages in the Python environment, or can be controlled through [environment variables](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation).
To call attention implementations other than `PyTorch`, please follow the instructions on their GitHub pages to install the corresponding packages. `DiffSynth-Studio` will automatically route to the corresponding implementation based on available packages in the Python environment, or can be controlled through [environment variables](../../Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation).
```python
from diffsynth.core.attention import attention_forward

View File

@@ -8,9 +8,9 @@ This document introduces the model download and loading functionalities in `diff
### Downloading and Loading Models from Remote Sources
Taking the model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) as an example, after filling in `model_id` and `origin_file_pattern` in `ModelConfig`, the model can be automatically downloaded. By default, it downloads to the `./models` path, which can be modified through the [environment variable DIFFSYNTH_MODEL_BASE_PATH](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path).
Taking the model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) as an example, after filling in `model_id` and `origin_file_pattern` in `ModelConfig`, the model can be automatically downloaded. By default, it downloads to the `./models` path, which can be modified through the [environment variable DIFFSYNTH_MODEL_BASE_PATH](../../Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path).
By default, even if the model has already been downloaded, the program will still query the remote for any missing files. To completely disable remote requests, set the [environment variable DIFFSYNTH_SKIP_DOWNLOAD](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`.
By default, even if the model has already been downloaded, the program will still query the remote for any missing files. To completely disable remote requests, set the [environment variable DIFFSYNTH_SKIP_DOWNLOAD](../../Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`.
```python
from diffsynth.core import ModelConfig
@@ -51,7 +51,7 @@ config = ModelConfig(path=[
### VRAM Management Configuration
`ModelConfig` also contains VRAM management configuration information. See [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md#more-usage-methods) for details.
`ModelConfig` also contains VRAM management configuration information. See [VRAM Management](../../Pipeline_Usage/VRAM_management.md#more-usage-methods) for details.
## Model File Loading
@@ -103,11 +103,11 @@ print(hash_model_file([
The model hash value is only related to the keys and tensor shapes in the state dict of the model file, and is unrelated to the numerical values of the model parameters, file saving time, and other information. When calculating the model hash value of `.safetensors` format files, `hash_model_file` is almost instantly completed without reading the model parameters. However, when calculating the model hash value of `.bin`, `.pth`, `.ckpt`, and other binary files, all model parameters need to be read, so **we do not recommend developers to continue using these formats of files.**
By [writing model Config](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config) and filling in model hash value and other information into `diffsynth/configs/model_configs.py`, developers can let `DiffSynth-Studio` automatically identify the model type and load it.
By [writing model Config](../../Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config) and filling in model hash value and other information into `diffsynth/configs/model_configs.py`, developers can let `DiffSynth-Studio` automatically identify the model type and load it.
## Model Loading
`load_model` is the external entry for loading models in `diffsynth.core.loader`. It will call [skip_model_initialization](/docs/en/API_Reference/core/vram.md#skipping-model-parameter-initialization) to skip model parameter initialization. If [Disk Offload](/docs/en/Pipeline_Usage/VRAM_management.md#disk-offload) is enabled, it calls [DiskMap](/docs/en/API_Reference/core/vram.md#state-dict-disk-mapping) for lazy loading. If Disk Offload is not enabled, it calls [load_state_dict](#model-file-loading) to load model parameters. If necessary, it will also call [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) for model format conversion. Finally, it calls `model.eval()` to switch to inference mode.
`load_model` is the external entry for loading models in `diffsynth.core.loader`. It will call [skip_model_initialization](../../API_Reference/core/vram.md#skipping-model-parameter-initialization) to skip model parameter initialization. If [Disk Offload](../../Pipeline_Usage/VRAM_management.md#disk-offload) is enabled, it calls [DiskMap](../../API_Reference/core/vram.md#state-dict-disk-mapping) for lazy loading. If Disk Offload is not enabled, it calls [load_state_dict](#model-file-loading) to load model parameters. If necessary, it will also call [state dict converter](../../Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) for model format conversion. Finally, it calls `model.eval()` to switch to inference mode.
Here is a usage example with Disk Offload enabled:

View File

@@ -31,7 +31,7 @@ state_dict = load_state_dict(path, device="cpu")
model.load_state_dict(state_dict, assign=True)
```
In `DiffSynth-Studio`, all pretrained models follow this loading logic. After developers [integrate models](/docs/en/Developer_Guide/Integrating_Your_Model.md), they can directly load models quickly using this approach.
In `DiffSynth-Studio`, all pretrained models follow this loading logic. After developers [integrate models](../../Developer_Guide/Integrating_Your_Model.md), they can directly load models quickly using this approach.
## State Dict Disk Mapping
@@ -57,10 +57,10 @@ state_dict = DiskMap(path, device="cpu") # Fast
print(state_dict["img_in.weight"])
```
`DiskMap` is the basic component of Disk Offload in `DiffSynth-Studio`. After developers [configure fine-grained VRAM management schemes](/docs/en/Developer_Guide/Enabling_VRAM_management.md), they can directly enable Disk Offload.
`DiskMap` is the basic component of Disk Offload in `DiffSynth-Studio`. After developers [configure fine-grained VRAM management schemes](../../Developer_Guide/Enabling_VRAM_management.md), they can directly enable Disk Offload.
`DiskMap` is a functionality implemented using the characteristics of `.safetensors` files. Therefore, when using `.bin`, `.pth`, `.ckpt`, and other binary files, model parameters are fully loaded, which causes Disk Offload to not support these formats of files. **We do not recommend developers to continue using these formats of files.**
## Replacable Modules for VRAM Management
When `DiffSynth-Studio`'s VRAM management is enabled, the modules inside the model will be replaced with replacable modules in `diffsynth.core.vram.layers`. For usage, see [Fine-grained VRAM Management Scheme](/docs/en/Developer_Guide/Enabling_VRAM_management.md#writing-fine-grained-vram-management-schemes).
When `DiffSynth-Studio`'s VRAM management is enabled, the modules inside the model will be replaced with replacable modules in `diffsynth.core.vram.layers`. For usage, see [Fine-grained VRAM Management Scheme](../../Developer_Guide/Enabling_VRAM_management.md#writing-fine-grained-vram-management-schemes).

View File

@@ -1,6 +1,6 @@
# Building a Pipeline
After [integrating the required models for the Pipeline](/docs/en/Developer_Guide/Integrating_Your_Model.md), you also need to build a `Pipeline` for model inference. This document provides a standardized process for building a `Pipeline`. Developers can also refer to existing `Pipeline` implementations for construction.
After [integrating the required models for the Pipeline](../Developer_Guide/Integrating_Your_Model.md), you also need to build a `Pipeline` for model inference. This document provides a standardized process for building a `Pipeline`. Developers can also refer to existing `Pipeline` implementations for construction.
The `Pipeline` implementation is located in `diffsynth/pipelines`. Each `Pipeline` contains the following essential key components:
@@ -79,7 +79,7 @@ This includes the following parts:
return pipe
```
Developers need to implement the logic for fetching models. The corresponding model names are the `"model_name"` in the [model Config filled in during model integration](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config).
Developers need to implement the logic for fetching models. The corresponding model names are the `"model_name"` in the [model Config filled in during model integration](../Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config).
Some models also need to load `tokenizer`. Extra `tokenizer_config` parameters can be added to `from_pretrained` as needed, and this part can be implemented after fetching the models.

View File

@@ -1,6 +1,6 @@
# Fine-Grained VRAM Management Scheme
This document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md).
This document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](../Pipeline_Usage/VRAM_management.md).
## How Much VRAM Does a 20B Model Need?
@@ -124,7 +124,7 @@ module_map={
}
```
In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md#more-usage-methods).
In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](../Pipeline_Usage/VRAM_management.md#more-usage-methods).
Call `enable_vram_management` to enable VRAM management. Note that the `device` when loading the model is `cpu`, consistent with `offload_device`:
@@ -171,7 +171,7 @@ The above code only requires 2G VRAM to run the `forward` of a 20B model.
## Disk Offload
[Disk Offload](/docs/en/Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. Usually, when the above code can run smoothly, Disk Offload can be directly enabled:
[Disk Offload](../Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. Usually, when the above code can run smoothly, Disk Offload can be directly enabled:
```python
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
@@ -212,7 +212,7 @@ with torch.no_grad():
output = model(**inputs)
```
Disk Offload is an extremely special VRAM management scheme. It only supports `.safetensors` format files, not binary files such as `.bin`, `.pth`, `.ckpt`, and does not support [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape.
Disk Offload is an extremely special VRAM management scheme. It only supports `.safetensors` format files, not binary files such as `.bin`, `.pth`, `.ckpt`, and does not support [state dict converter](../Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape.
If there are situations where Disk Offload cannot run normally but non-Disk Offload can run normally, please submit an issue to us on GitHub.
@@ -227,7 +227,7 @@ To make it easier for users to use the VRAM management function, we write the fi
}
```# Fine-Grained VRAM Management Scheme
This document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md).
This document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](../Pipeline_Usage/VRAM_management.md).
## How Much VRAM Does a 20B Model Need?
@@ -351,7 +351,7 @@ module_map={
}
```
In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md#more-usage-methods).
In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](../Pipeline_Usage/VRAM_management.md#more-usage-methods).
Call `enable_vram_management` to enable VRAM management. Note that the `device` when loading the model is `cpu`, consistent with `offload_device`:
@@ -398,7 +398,7 @@ The above code only requires 2G VRAM to run the `forward` of a 20B model.
## Disk Offload
[Disk Offload](/docs/en/Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. Usually, when the above code can run smoothly, Disk Offload can be directly enabled:
[Disk Offload](../Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. Usually, when the above code can run smoothly, Disk Offload can be directly enabled:
```python
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
@@ -439,7 +439,7 @@ with torch.no_grad():
output = model(**inputs)
```
Disk Offload is an extremely special VRAM management scheme. It only supports `.safetensors` format files, not binary files such as `.bin`, `.pth`, `.ckpt`, and does not support [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape.
Disk Offload is an extremely special VRAM management scheme. It only supports `.safetensors` format files, not binary files such as `.bin`, `.pth`, `.ckpt`, and does not support [state dict converter](../Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape.
If there are situations where Disk Offload cannot run normally but non-Disk Offload can run normally, please submit an issue to us on GitHub.

View File

@@ -183,4 +183,4 @@ Loaded model: {
## Step 5: Writing Model VRAM Management Scheme
`DiffSynth-Studio` supports complex VRAM management. See [Enabling VRAM Management](/docs/en/Developer_Guide/Enabling_VRAM_management.md) for details.
`DiffSynth-Studio` supports complex VRAM management. See [Enabling VRAM Management](../Developer_Guide/Enabling_VRAM_management.md) for details.

View File

@@ -1,6 +1,6 @@
# Integrating Model Training
After [integrating models](/docs/en/Developer_Guide/Integrating_Your_Model.md) and [implementing Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md), the next step is to integrate model training functionality.
After [integrating models](../Developer_Guide/Integrating_Your_Model.md) and [implementing Pipeline](../Developer_Guide/Building_a_Pipeline.md), the next step is to integrate model training functionality.
## Training-Inference Consistent Pipeline Modification

20
docs/en/Makefile Normal file
View File

@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

View File

@@ -0,0 +1,139 @@
# Anima
Anima is an image generation model trained and open-sourced by CircleStone Labs and Comfy Org.
## Installation
Before using this project for model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more installation information, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
The following code demonstrates how to quickly load the [circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima) model for inference. VRAM management is enabled by default, allowing the framework to automatically control model parameter loading based on available VRAM. Minimum 8GB VRAM required.
```python
from diffsynth.pipelines.anima_image import AnimaImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": "disk",
"onload_device": "disk",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AnimaImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors", **vram_config),
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/text_encoders/qwen_3_06b_base.safetensors", **vram_config),
ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
tokenizer_t5xxl_config=ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
image = pipe(prompt, seed=0, num_inference_steps=50)
image.save("image.jpg")
```
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-|-|
|[circlestone-labs/Anima](https://www.modelscope.cn/models/circlestone-labs/Anima)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_inference_low_vram/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/full/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_full/anima-preview.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/lora/anima-preview.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/validate_lora/anima-preview.py)|
Special training scripts:
* Differential LoRA Training: [doc](../Training/Differential_LoRA.md)
* FP8 Precision Training: [doc](../Training/FP8_Precision.md)
* Two-Stage Split Training: [doc](../Training/Split_Training.md)
* End-to-End Direct Distillation: [doc](../Training/Direct_Distill.md)
## Model Inference
Models are loaded through `AnimaImagePipeline.from_pretrained`, see [Model Inference](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
Input parameters for `AnimaImagePipeline` inference include:
* `prompt`: Text description of the desired image content.
* `negative_prompt`: Content to exclude from the generated image (default: `""`).
* `cfg_scale`: Classifier-free guidance parameter (default: 4.0).
* `input_image`: Input image for image-to-image generation (default: `None`).
* `denoising_strength`: Controls similarity to input image (default: 1.0).
* `height`: Image height (must be multiple of 16, default: 1024).
* `width`: Image width (must be multiple of 16, default: 1024).
* `seed`: Random seed (default: `None`).
* `rand_device`: Device for random noise generation (default: `"cpu"`).
* `num_inference_steps`: Inference steps (default: 30).
* `sigma_shift`: Scheduler sigma offset (default: `None`).
* `progress_bar_cmd`: Progress bar implementation (default: `tqdm.tqdm`).
For VRAM constraints, enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). Recommended low-VRAM configurations are provided in the "Model Overview" table above.
## Model Training
Anima models are trained through [`examples/anima/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/anima/model_training/train.py) with parameters including:
* General Training Parameters
* Dataset Configuration
* `--dataset_base_path`: Dataset root directory.
* `--dataset_metadata_path`: Metadata file path.
* `--dataset_repeat`: Dataset repetition per epoch.
* `--dataset_num_workers`: Dataloader worker count.
* `--data_file_keys`: Metadata fields to load (comma-separated).
* Model Loading
* `--model_paths`: Model paths (JSON format).
* `--model_id_with_origin_paths`: Model IDs with origin paths (e.g., `"anima-team/anima-1B:text_encoder/*.safetensors"`).
* `--extra_inputs`: Additional pipeline inputs (e.g., `controlnet_inputs` for ControlNet).
* `--fp8_models`: FP8-formatted models (same format as `--model_paths`).
* Training Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Training epochs.
* `--trainable_models`: Trainable components (e.g., `dit`, `vae`, `text_encoder`).
* `--find_unused_parameters`: Handle unused parameters in DDP training.
* `--weight_decay`: Weight decay value.
* `--task`: Training task (default: `sft`).
* Output Configuration
* `--output_path`: Model output directory.
* `--remove_prefix_in_ckpt`: Remove state dict prefixes.
* `--save_steps`: Model saving interval.
* LoRA Configuration
* `--lora_base_model`: Target model for LoRA.
* `--lora_target_modules`: Target modules for LoRA.
* `--lora_rank`: LoRA rank.
* `--lora_checkpoint`: LoRA checkpoint path.
* `--preset_lora_path`: Preloaded LoRA checkpoint path.
* `--preset_lora_model`: Model to merge LoRA with (e.g., `dit`).
* Gradient Configuration
* `--use_gradient_checkpointing`: Enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Offload checkpointing to CPU.
* `--gradient_accumulation_steps`: Gradient accumulation steps.
* Image Resolution
* `--height`: Image height (empty for dynamic resolution).
* `--width`: Image width (empty for dynamic resolution).
* `--max_pixels`: Maximum pixel area for dynamic resolution.
* Anima-Specific Parameters
* `--tokenizer_path`: Tokenizer path for text-to-image models.
* `--tokenizer_t5xxl_path`: T5-XXL tokenizer path.
We provide a sample image dataset for testing:
```shell
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
```
For training script details, refer to [Model Training](../Pipeline_Usage/Model_Training.md). For advanced training techniques, see [Training Framework Documentation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/).

View File

@@ -14,7 +14,7 @@ cd DiffSynth-Studio
pip install -e .
```
For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md).
For more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
@@ -81,31 +81,31 @@ graph LR;
| Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - | - |
| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py) |
| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) |
| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) |
| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) |
| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) |
| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) |
| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) |
| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) |
| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) |
| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - |
| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - |
| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](/examples/flux/model_inference/Step1X-Edit.py) | [code](/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](/examples/flux/model_training/full/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_lora/Step1X-Edit.py) |
| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](/examples/flux/model_inference/FLEX.2-preview.py) | [code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py) |
| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](/examples/flux/model_training/full/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_lora/Nexus-Gen.py) |
| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev.py) |
| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) |
| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) |
| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) |
| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) |
| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) |
| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) |
| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) |
| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) |
| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - |
| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - |
| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Step1X-Edit.py) |
| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLEX.2-preview.py) |
| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Nexus-Gen.py) |
Special Training Scripts:
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/flux/model_training/special/differential_training/)
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/flux/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/flux/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh)
* Differential LoRA Training: [doc](../Training/Differential_LoRA.md)
* FP8 Precision Training: [doc](../Training/FP8_Precision.md)
* Two-stage Split Training: [doc](../Training/Split_Training.md)
* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md)
## Model Inference
Models are loaded via `FluxImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models).
Models are loaded via `FluxImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models).
Input parameters for `FluxImagePipeline` inference include:
@@ -143,11 +143,11 @@ Input parameters for `FluxImagePipeline` inference include:
* `flex_control_stop`: Flex model control stop timestep.
* `nexus_gen_reference_image`: Nexus-Gen model reference image.
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
## Model Training
FLUX series models are uniformly trained through [`examples/flux/model_training/train.py`](/examples/flux/model_training/train.py), and the script parameters include:
FLUX series models are uniformly trained through [`examples/flux/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/train.py), and the script parameters include:
* General Training Parameters
* Dataset Basic Configuration
@@ -198,4 +198,4 @@ We have built a sample image dataset for your testing. You can download this dat
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
```
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -2,6 +2,15 @@
FLUX.2 is an image generation model trained and open-sourced by Black Forest Labs.
## Model Lineage
```mermaid
graph LR;
FLUX.2-Series-->black-forest-labs/FLUX.2-dev;
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-4B;
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-9B;
```
## Installation
Before using this project for model inference and training, please install DiffSynth-Studio first.
@@ -12,7 +21,7 @@ cd DiffSynth-Studio
pip install -e .
```
For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md).
For more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
@@ -50,20 +59,24 @@ image.save("image.jpg")
## Model Overview
| Model ID | Inference | Low VRAM Inference | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - |
| [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) | [code](/examples/flux2/model_inference/FLUX.2-dev.py) | [code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py) | [code](/examples/flux2/model_training/lora/FLUX.2-dev.sh) | [code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py) |
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-dev.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
Special Training Scripts:
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/flux/model_training/special/differential_training/)
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/flux/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/flux/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh)
* Differential LoRA Training: [doc](../Training/Differential_LoRA.md)
* FP8 Precision Training: [doc](../Training/FP8_Precision.md)
* Two-stage Split Training: [doc](../Training/Split_Training.md)
* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md)
## Model Inference
Models are loaded via `Flux2ImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models).
Models are loaded via `Flux2ImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models).
Input parameters for `Flux2ImagePipeline` inference include:
@@ -82,11 +95,11 @@ Input parameters for `Flux2ImagePipeline` inference include:
* `tile_stride`: Tile stride during VAE encoding/decoding stages, default is 64, only effective when `tiled=True`, must be less than or equal to `tile_size`.
* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`.
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
## Model Training
FLUX.2 series models are uniformly trained through [`examples/flux2/model_training/train.py`](/examples/flux2/model_training/train.py), and the script parameters include:
FLUX.2 series models are uniformly trained through [`examples/flux2/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/train.py), and the script parameters include:
* General Training Parameters
* Dataset Basic Configuration
@@ -135,4 +148,4 @@ We have built a sample image dataset for your testing. You can download this dat
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
```
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -0,0 +1,213 @@
# LTX-2
LTX-2 is a series of audio-video generation models developed by Lightricks.
## Installation
Before using this project for model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information about installation, please refer to [Installation Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
Run the following code to quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model and perform inference. VRAM management has been enabled, and the framework will automatically control model parameter loading based on remaining VRAM. It can run with a minimum of 8GB VRAM.
```python
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
vram_config = {
"offload_dtype": torch.float8_e5m2,
"offload_device": "cpu",
"onload_dtype": torch.float8_e5m2,
"onload_device": "cpu",
"preparing_dtype": torch.float8_e5m2,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
"""
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
and avoid redundant memory usage when users only want to use part of the model.
"""
# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading
pipe = LTX2AudioVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2"
# pipe = LTX2AudioVideoPipeline.from_pretrained(
# torch_dtype=torch.bfloat16,
# device="cuda",
# model_configs=[
# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
# ],
# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
# )
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
negative_prompt = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
height, width, num_frames = 512 * 2, 768 * 2, 121
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=43,
height=height,
width=width,
num_frames=num_frames,
tiled=True,
use_two_stage_pipeline=True,
)
write_video_audio_ltx2(
video=video,
audio=audio,
output_path='ltx2_twostage.mp4',
fps=24,
audio_sample_rate=24000,
)
```
## Model Overview
|Model ID|Additional Parameters|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|
|[Lightricks/LTX-2-19b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2-19b-IC-LoRA-Detailer](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-IC-LoRA-Detailer)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-IC-LoRA-Detailer.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2-T2AV-IC-LoRA.py)|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-|
## Model Inference
Models are loaded through `LTX2AudioVideoPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
Input parameters for `LTX2AudioVideoPipeline` inference include:
* `prompt`: Prompt describing the content appearing in the video.
* `negative_prompt`: Negative prompt describing content that should not appear in the video, default value is `""`.
* `cfg_scale`: Classifier-free guidance parameter, default value is 3.0.
* `input_images`: List of input images for image-to-video generation.
* `input_images_indexes`: Frame index list of input images in the video.
* `input_images_strength`: Strength of input images, default value is 1.0.
* `denoising_strength`: Denoising strength, range is 01, default value is 1.0.
* `seed`: Random seed. Default is `None`, which means completely random.
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different results will be generated on different GPUs.
* `height`: Video height, must be a multiple of 32 (single-stage) or 64 (two-stage).
* `width`: Video width, must be a multiple of 32 (single-stage) or 64 (two-stage).
* `num_frames`: Number of video frames, default value is 121, must be a multiple of 8 + 1.
* `num_inference_steps`: Number of inference steps, default value is 40.
* `tiled`: Whether to enable VAE tiling inference, default is `True`. When set to `True`, it can significantly reduce VRAM usage during VAE encoding/decoding stages, with slight errors and minor inference time extension.
* `tile_size_in_pixels`: Pixel tiling size during VAE encoding/decoding stages, default is 512.
* `tile_overlap_in_pixels`: Pixel tiling overlap size during VAE encoding/decoding stages, default is 128.
* `tile_size_in_frames`: Frame tiling size during VAE encoding/decoding stages, default is 128.
* `tile_overlap_in_frames`: Frame tiling overlap size during VAE encoding/decoding stages, default is 24.
* `use_two_stage_pipeline`: Whether to use two-stage pipeline, default is `False`.
* `use_distilled_pipeline`: Whether to use distilled pipeline, default is `False`.
* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be set to `lambda x:x` to hide the progress bar.
If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the previous "Supported Inference Scripts" section.
## Model Training
LTX-2 series models are uniformly trained through [`examples/ltx2/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/train.py), and the script parameters include:
* General Training Parameters
* Dataset Basic Configuration
* `--dataset_base_path`: Root directory of the dataset.
* `--dataset_metadata_path`: Metadata file path of the dataset.
* `--dataset_repeat`: Number of times the dataset is repeated in each epoch.
* `--dataset_num_workers`: Number of processes for each DataLoader.
* `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`.
* Model Loading Configuration
* `--model_paths`: Paths of models to be loaded. JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors"`. Separated by commas.
* `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., extra parameters when training image editing models, separated by `,`.
* `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA).
* Training Basic Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of epochs.
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
* `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training.
* `--weight_decay`: Weight decay size, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
* `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model.
* Output Configuration
* `--output_path`: Model saving path.
* `--remove_prefix_in_ckpt`: Remove prefix in the state dict of the model file.
* `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch.
* LoRA Configuration
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint.
* `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training.
* `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`.
* Gradient Configuration
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Video Width/Height Configuration
* `--height`: Height of the video. Leave `height` and `width` blank to enable dynamic resolution.
* `--width`: Width of the video. Leave `height` and `width` blank to enable dynamic resolution.
* `--max_pixels`: Maximum pixel area of video frames. When dynamic resolution is enabled, video frames with resolution larger than this value will be downscaled, and video frames with resolution smaller than this value will remain unchanged.
* `--num_frames`: Number of frames in the video.
* LTX-2 Series Specific Parameters
* `--tokenizer_path`: Path of the tokenizer, applicable to text-to-video models, leave blank to automatically download from remote.
* `--frame_rate`: frame rate of the training videos.
We have built a sample video dataset for your testing. You can download this dataset with the following command:
```shell
modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset
```
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -2,7 +2,7 @@
## Qwen-Image
Documentation: [./Qwen-Image.md](/docs/en/Model_Details/Qwen-Image.md)
Documentation: [./Qwen-Image.md](../Model_Details/Qwen-Image.md)
<details>
@@ -69,23 +69,23 @@ graph LR;
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |
| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |
| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |
| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |
| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |
| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |
## FLUX Series
Documentation: [./FLUX.md](/docs/en/Model_Details/FLUX.md)
Documentation: [./FLUX.md](../Model_Details/FLUX.md)
<details>
@@ -149,24 +149,24 @@ graph LR;
| Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - | - |
| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py) |
| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) |
| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) |
| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) |
| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) |
| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) |
| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) |
| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) |
| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) |
| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - |
| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - |
| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](/examples/flux/model_inference/Step1X-Edit.py) | [code](/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](/examples/flux/model_training/full/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_lora/Step1X-Edit.py) |
| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](/examples/flux/model_inference/FLEX.2-preview.py) | [code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py) |
| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](/examples/flux/model_training/full/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_lora/Nexus-Gen.py) |
| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev.py) |
| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) |
| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) |
| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) |
| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) |
| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) |
| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) |
| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) |
| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) |
| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - |
| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - |
| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Step1X-Edit.py) |
| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/FLEX.2-preview.py) |
| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/full/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux/model_training/validate_lora/Nexus-Gen.py) |
## Wan Series
Documentation: [./Wan.md](/docs/en/Model_Details/Wan.md)
Documentation: [./Wan.md](../Model_Details/Wan.md)
<details>
@@ -254,38 +254,38 @@ graph LR;
| Model ID | Extra Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) |
| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) |
| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) |
| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) |
| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) |
| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) |
| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) |
| [Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) |
| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) |
| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) |
| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) |
| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) |
| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) |
| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) |
| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](/examples/wanvideo/model_inference/LongCat-Video.py) | [code](/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) |
| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) |
| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) |
| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) |
| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) |
| [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) |
| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) |
| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) |
| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) |
| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) |
| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) |
| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) |
| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) |
| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) |
| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) |
| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) |
| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) |
| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) |
| [Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) |
| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) |
| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) |
| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) |
| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) |
| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) |
| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) |
| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) |
| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) |
| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) |
| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) |
| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) |
| [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) |
| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) |
| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) |
| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) |
| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) |
| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) |
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/wanvideo/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/wanvideo/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/wanvideo/model_training/special/direct_distill/)
* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/direct_distill/)

View File

@@ -14,7 +14,7 @@ cd DiffSynth-Studio
pip install -e .
```
For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md).
For more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
@@ -80,34 +80,42 @@ graph LR;
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |
| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |
| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[FireRedTeam/FireRed-Image-Edit-1.0](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.0)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/FireRed-Image-Edit-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.0.py)|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) |
| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) |
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |
| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |
| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
Special Training Scripts:
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/qwen_image/model_training/special/differential_training/)
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/qwen_image/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/qwen_image/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)
* Differential LoRA Training: [doc](../Training/Differential_LoRA.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/differential_training/)
* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/qwen_image/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)
DeepSpeed ZeRO Stage 3 Training: The Qwen-Image series models support DeepSpeed ZeRO Stage 3 training, which partitions the model across multiple GPUs. Taking full parameter training of the Qwen-Image model as an example, the following modifications are required:
* `--config_file examples/qwen_image/model_training/full/accelerate_config_zero3.yaml`
* `--initialize_model_on_cpu`
## Model Inference
Models are loaded via `QwenImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models).
Models are loaded via `QwenImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models).
Input parameters for `QwenImagePipeline` inference include:
@@ -138,11 +146,11 @@ Input parameters for `QwenImagePipeline` inference include:
* `tile_stride`: Tile stride during VAE encoding/decoding stages, default is 64, only effective when `tiled=True`, must be less than or equal to `tile_size`.
* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`.
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
## Model Training
Qwen-Image series models are uniformly trained through [`examples/qwen_image/model_training/train.py`](/examples/qwen_image/model_training/train.py), and the script parameters include:
Qwen-Image series models are uniformly trained through [`examples/qwen_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/train.py), and the script parameters include:
* General Training Parameters
* Dataset Basic Configuration
@@ -192,4 +200,4 @@ We have built a sample image dataset for your testing. You can download this dat
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
```
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -14,7 +14,7 @@ cd DiffSynth-Studio
pip install -e .
```
For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md).
For more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
@@ -106,45 +106,50 @@ graph LR;
| Model ID | Extra Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) |
| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) |
| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) |
| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) |
| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) |
| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) |
| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) |
| [Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) |
| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) |
| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) |
| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) |
| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) |
| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) |
| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) |
| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](/examples/wanvideo/model_inference/LongCat-Video.py) | [code](/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) |
| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) |
| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) |
| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) |
| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) |
| [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) |
| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) |
| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) |
| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) |
| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) |
| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) |
| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) |
| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) |
| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) |
| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) |
| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) |
| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) |
| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) |
| [Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) |
| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) |
| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) |
| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) |
| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) |
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) |
| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) |
| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) |
| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) |
| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/LongCat-Video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) |
| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) |
| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) |
| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) |
| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) |
| [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) |
| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) |
| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) |
| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) |
| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) |
| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) |
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/wanvideo/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/wanvideo/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/wanvideo/model_training/special/direct_distill/)
* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/)
* End-to-end Direct Distillation: [doc](../Training/Direct_Distill.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/direct_distill/)
DeepSpeed ZeRO Stage 3 Training: The Wan series models support DeepSpeed ZeRO Stage 3 training, which partitions the model across multiple GPUs. Taking full parameter training of the Wan2.1-T2V-14B model as an example, the following modifications are required:
* `--config_file examples/wanvideo/model_training/full/accelerate_config_zero3.yaml`
* `--initialize_model_on_cpu`
## Model Inference
Models are loaded via `WanVideoPipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models).
Models are loaded via `WanVideoPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models).
Input parameters for `WanVideoPipeline` inference include:
@@ -194,11 +199,11 @@ Input parameters for `WanVideoPipeline` inference include:
* `tea_cache_model_id`: Model ID used by TeaCache.
* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`.
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
## Model Training
Wan series models are uniformly trained through [`examples/wanvideo/model_training/train.py`](/examples/wanvideo/model_training/train.py), and the script parameters include:
Wan series models are uniformly trained through [`examples/wanvideo/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/train.py), and the script parameters include:
* General Training Parameters
* Dataset Basic Configuration
@@ -249,4 +254,4 @@ We have built a sample video dataset for your testing. You can download this dat
modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset
```
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -12,7 +12,7 @@ cd DiffSynth-Studio
pip install -e .
```
For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md).
For more information about installation, please refer to [Install Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
@@ -50,18 +50,23 @@ image.save("image.jpg")
## Model Overview
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
| [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) | [code](/examples/z_image/model_inference/Z-Image-Turbo.py) | [code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py) | [code](/examples/z_image/model_training/full/Z-Image-Turbo.sh) | [code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py) | [code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh) | [code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py) |
|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-i2L.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
Special Training Scripts:
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/z_image/model_training/special/differential_training/)
* Trajectory Imitation Distillation Training (Experimental Feature): [code](/examples/z_image/model_training/special/trajectory_imitation/)
* Differential LoRA Training: [doc](../Training/Differential_LoRA.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/differential_training/)
* Trajectory Imitation Distillation Training (Experimental Feature): [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/trajectory_imitation/)
## Model Inference
Models are loaded via `ZImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models).
Models are loaded via `ZImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models).
Input parameters for `ZImagePipeline` inference include:
@@ -75,12 +80,15 @@ Input parameters for `ZImagePipeline` inference include:
* `seed`: Random seed. Default is `None`, meaning completely random.
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results.
* `num_inference_steps`: Number of inference steps, default value is 8.
* `controlnet_inputs`: Inputs for ControlNet models.
* `edit_image`: Edit images for image editing models, supporting multiple images.
* `positive_only_lora`: LoRA weights used only in positive prompts.
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
## Model Training
Z-Image series models are uniformly trained through [`examples/z_image/model_training/train.py`](/examples/z_image/model_training/train.py), and the script parameters include:
Z-Image series models are uniformly trained through [`examples/z_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/train.py), and the script parameters include:
* General Training Parameters
* Dataset Basic Configuration
@@ -129,13 +137,13 @@ We have built a sample image dataset for your testing. You can download this dat
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
```
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
Training Tips:
* [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) is a distilled acceleration model. Therefore, direct training will quickly cause the model to lose its acceleration capability. The effect of inference with "acceleration configuration" (`num_inference_steps=8`, `cfg_scale=1`) becomes worse, while the effect of inference with "no acceleration configuration" (`num_inference_steps=30`, `cfg_scale=2`) becomes better. The following training and inference schemes can be adopted:
* Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + No Acceleration Configuration Inference
* Differential LoRA Training ([code](/examples/z_image/model_training/special/differential_training/)) + Acceleration Configuration Inference
* Standard SFT Training ([code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + No Acceleration Configuration Inference
* Differential LoRA Training ([code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/differential_training/)) + Acceleration Configuration Inference
* An additional LoRA needs to be loaded in differential LoRA training, e.g., [ostris/zimage_turbo_training_adapter](https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter)
* Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Trajectory Imitation Distillation Training ([code](/examples/z_image/model_training/special/trajectory_imitation/)) + Acceleration Configuration Inference
* Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Load Distillation Acceleration LoRA During Inference ([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillPatch)) + Acceleration Configuration Inference
* Standard SFT Training ([code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Trajectory Imitation Distillation Training ([code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/z_image/model_training/special/trajectory_imitation/)) + Acceleration Configuration Inference
* Standard SFT Training ([code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Load Distillation Acceleration LoRA During Inference ([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillPatch)) + Acceleration Configuration Inference

View File

@@ -28,7 +28,7 @@ Model download root directory. Can be set to any local path. If `local_model_pat
## `DIFFSYNTH_ATTENTION_IMPLEMENTATION`
Attention mechanism implementation method. Can be set to `flash_attention_3`, `flash_attention_2`, `sage_attention`, `xformers`, or `torch`. See [`./core/attention.md`](/docs/en/API_Reference/core/attention.md) for details.
Attention mechanism implementation method. Can be set to `flash_attention_3`, `flash_attention_2`, `sage_attention`, `xformers`, or `torch`. See [`./core/attention.md`](../API_Reference/core/attention.md) for details.
## `DIFFSYNTH_DISK_MAP_BUFFER_SIZE`

View File

@@ -2,7 +2,7 @@
`DiffSynth-Studio` supports various GPUs and NPUs. This document explains how to run model inference and training on these devices.
Before you begin, please follow the [Installation Guide](/docs/en/Pipeline_Usage/Setup.md) to install the required GPU/NPU dependencies.
Before you begin, please follow the [Installation Guide](../Pipeline_Usage/Setup.md) to install the required GPU/NPU dependencies.
## NVIDIA GPU
@@ -13,7 +13,7 @@ All sample code provided by this project supports NVIDIA GPUs by default, requir
AMD provides PyTorch packages based on ROCm, so most models can run without code changes. A small number of models may not be compatible due to their reliance on CUDA-specific instructions.
## Ascend NPU
### Inference
When using Ascend NPU, you need to replace `"cuda"` with `"npu"` in your code.
For example, here is the inference code for **Wan2.1-T2V-1.3B**, modified for Ascend NPU:
@@ -22,6 +22,7 @@ For example, here is the inference code for **Wan2.1-T2V-1.3B**, modified for As
import torch
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from diffsynth.core.device.npu_compatible_device import get_device_name
vram_config = {
"offload_dtype": "disk",
@@ -46,7 +47,7 @@ pipe = WanVideoPipeline.from_pretrained(
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
+ vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
+ vram_limit=torch.npu.mem_get_info(get_device_name())[1] / (1024 ** 3) - 2,
)
video = pipe(
@@ -56,3 +57,38 @@ video = pipe(
)
save_video(video, "video.mp4", fps=15, quality=5)
```
#### USP(Unified Sequence Parallel)
If you want to use this feature on NPU, please install additional third-party libraries as follows:
```shell
pip install git+https://github.com/feifeibear/long-context-attention.git
pip install git+https://github.com/xdit-project/xDiT.git
```
### Training
NPU startup script samples have been added for each type of model,the scripts are stored in the `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`.
In the NPU training scripts, NPU specific environment variables that can optimize performance have been added, and relevant parameters have been enabled for specific models.
#### Environment variables
```shell
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
```
`expandable_segments:<value>`: Enable the memory pool expansion segment function, which is the virtual memory feature.
```shell
export CPU_AFFINITY_CONF=1
```
Set 0 or not set: indicates not enabling the binding function
1: Indicates enabling coarse-grained kernel binding
2: Indicates enabling fine-grained kernel binding
#### Parameters for specific models
| Model | Parameter | Note |
|----------------|---------------------------|-------------------|
| Wan 14B series | --initialize_model_on_cpu | The 14B model needs to be initialized on the CPU |
| Qwen-Image series | --initialize_model_on_cpu | The model needs to be initialized on the CPU |
| Z-Image series | --enable_npu_patch | Using NPU fusion operator to replace the corresponding operator in Z-image model to improve the performance of the model on NPU |

View File

@@ -22,7 +22,7 @@ pipe = QwenImagePipeline.from_pretrained(
)
```
Where `torch_dtype` and `device` are computation precision and computation device (not model precision and device). `model_configs` can be configured in multiple ways for model paths. For how models are loaded internally in this project, please refer to [`diffsynth.core.loader`](/docs/en/API_Reference/core/loader.md).
Where `torch_dtype` and `device` are computation precision and computation device (not model precision and device). `model_configs` can be configured in multiple ways for model paths. For how models are loaded internally in this project, please refer to [`diffsynth.core.loader`](../API_Reference/core/loader.md).
<details>
@@ -34,7 +34,7 @@ Where `torch_dtype` and `device` are computation precision and computation devic
> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
> ```
>
> Model files are downloaded to the `./models` path by default, which can be modified through [environment variable DIFFSYNTH_MODEL_BASE_PATH](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path).
> Model files are downloaded to the `./models` path by default, which can be modified through [environment variable DIFFSYNTH_MODEL_BASE_PATH](../Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path).
</details>
@@ -61,7 +61,7 @@ Where `torch_dtype` and `device` are computation precision and computation devic
</details>
By default, even after models have been downloaded, the program will still query remotely for missing files. To completely disable remote requests, set [environment variable DIFFSYNTH_SKIP_DOWNLOAD](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`.
By default, even after models have been downloaded, the program will still query remotely for missing files. To completely disable remote requests, set [environment variable DIFFSYNTH_SKIP_DOWNLOAD](../Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`.
```shell
import os
@@ -69,7 +69,7 @@ os.environ["DIFFSYNTH_SKIP_DOWNLOAD"] = "True"
import diffsynth
```
To download models from [HuggingFace](https://huggingface.co/), set [environment variable DIFFSYNTH_DOWNLOAD_SOURCE](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source) to `huggingface`.
To download models from [HuggingFace](https://huggingface.co/), set [environment variable DIFFSYNTH_DOWNLOAD_SOURCE](../Pipeline_Usage/Environment_Variables.md#diffsynth_download_source) to `huggingface`.
```shell
import os
@@ -102,4 +102,65 @@ image.save("image.jpg")
Each model `Pipeline` has different input parameters. Please refer to the documentation for each model.
If the model parameters are too large, causing insufficient VRAM, please enable [VRAM management](/docs/en/Pipeline_Usage/VRAM_management.md).
If the model parameters are too large, causing insufficient VRAM, please enable [VRAM management](../Pipeline_Usage/VRAM_management.md).
## Loading LoRA
LoRA is a lightweight model training method that produces a small number of parameters to extend model capabilities. DiffSynth-Studio supports two ways to load LoRA: cold loading and hot loading.
* Cold loading: When the base model does not have [VRAM management](../Pipeline_Usage/VRAM_management.md) enabled, LoRA will be fused into the base model weights. In this case, inference speed remains unchanged, but LoRA cannot be unloaded after loading.
```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
lora = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1", origin_file_pattern="model.safetensors")
pipe.load_lora(pipe.dit, lora, alpha=1)
prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal."
image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")
```
* Hot loading: When the base model has [VRAM management](../Pipeline_Usage/VRAM_management.md) enabled, LoRA will not be fused into the base model weights. In this case, inference speed will be slower, but LoRA can be unloaded through `pipe.clear_lora()` after loading.
```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cuda",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
lora = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1", origin_file_pattern="model.safetensors")
pipe.load_lora(pipe.dit, lora, alpha=1)
prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal."
image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")
pipe.clear_lora()
```

View File

@@ -65,7 +65,7 @@ image_1.jpg,"a dog"
image_2.jpg,"a cat"
```
We have built sample datasets for your testing. To understand how the universal dataset architecture is implemented, please refer to [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md).
We have built sample datasets for your testing. To understand how the universal dataset architecture is implemented, please refer to [`diffsynth.core.data`](../API_Reference/core/data.md).
<details>
@@ -93,7 +93,7 @@ We have built sample datasets for your testing. To understand how the universal
## Loading Models
Similar to [model loading during inference](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models), we support multiple ways to configure model paths, and the two methods can be mixed.
Similar to [model loading during inference](../Pipeline_Usage/Model_Inference.md#loading-models), we support multiple ways to configure model paths, and the two methods can be mixed.
<details>
@@ -115,9 +115,9 @@ Similar to [model loading during inference](/docs/en/Pipeline_Usage/Model_Infere
> --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors"
> ```
>
> Model files are downloaded to the `./models` path by default, which can be modified through [environment variable DIFFSYNTH_MODEL_BASE_PATH](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path).
> Model files are downloaded to the `./models` path by default, which can be modified through [environment variable DIFFSYNTH_MODEL_BASE_PATH](../Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path).
>
> By default, even after models have been downloaded, the program will still query remotely for missing files. To completely disable remote requests, set [environment variable DIFFSYNTH_SKIP_DOWNLOAD](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`.
> By default, even after models have been downloaded, the program will still query remotely for missing files. To completely disable remote requests, set [environment variable DIFFSYNTH_SKIP_DOWNLOAD](../Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`.
</details>
@@ -237,11 +237,11 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera
## Training Considerations
* In addition to the `csv` format, dataset metadata also supports `json` and `jsonl` formats. For how to choose the best metadata format, please refer to [/docs/en/API_Reference/core/data.md#metadata](/docs/en/API_Reference/core/data.md#metadata)
* In addition to the `csv` format, dataset metadata also supports `json` and `jsonl` formats. For how to choose the best metadata format, please refer to [../API_Reference/core/data.md#metadata](../API_Reference/core/data.md#metadata)
* Training effectiveness is usually strongly correlated with training steps and weakly correlated with epoch count. Therefore, we recommend using the `--save_steps` parameter to save model files at training step intervals.
* When data volume * `dataset_repeat` exceeds $10^9$, we observed that the dataset speed becomes significantly slower, which seems to be a `PyTorch` bug. We are not sure if newer versions of `PyTorch` have fixed this issue.
* For learning rate `--learning_rate`, it is recommended to set to `1e-4` in LoRA training and `1e-5` in full training.
* The training framework does not support batch size > 1. The reasons are complex. See [Q&A: Why doesn't the training framework support batch size > 1?](/docs/en/QA.md#why-doesnt-the-training-framework-support-batch-size--1)
* The training framework does not support batch size > 1. The reasons are complex. See [Q&A: Why doesn't the training framework support batch size > 1?](../QA.md#why-doesnt-the-training-framework-support-batch-size--1)
* Some models contain redundant parameters. For example, the text encoding part of the last layer of Qwen-Image's DiT part. When training these models, `--find_unused_parameters` needs to be set to avoid errors in multi-GPU training. For compatibility with community models, we do not intend to remove these redundant parameters.
* The loss function value of Diffusion models has little relationship with actual effects. Therefore, we do not record loss function values during training. We recommend setting `--num_epochs` to a sufficiently large value, testing while training, and manually closing the training program after the effect converges.
* `--use_gradient_checkpointing` is usually enabled unless GPU VRAM is sufficient; `--use_gradient_checkpointing_offload` is enabled as needed. See [`diffsynth.core.gradient`](/docs/en/API_Reference/core/gradient.md) for details.
* `--use_gradient_checkpointing` is usually enabled unless GPU VRAM is sufficient; `--use_gradient_checkpointing_offload` is enabled as needed. See [`diffsynth.core.gradient`](../API_Reference/core/gradient.md) for details.

View File

@@ -30,13 +30,18 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6
* **Ascend NPU**
Ascend NPU support is provided via the `torch-npu` package. Taking version `2.1.0.post17` (as of the article update date: December 15, 2025) as an example, run the following command:
1. Install [CANN](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/softwareinst/instg/instg_quick.html?Mode=PmIns&InstallType=local&OS=openEuler&Software=cannToolKit) through official documentation.
```shell
pip install torch-npu==2.1.0.post17
```
2. Install from source
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
# aarch64/ARM
pip install -e .[npu_aarch64]
# x86
pip install -e .[npu] --extra-index-url "https://download.pytorch.org/whl/cpu"
When using Ascend NPU, please replace `"cuda"` with `"npu"` in your Python code. For details, see [NPU Support](/docs/en/Pipeline_Usage/GPU_support.md#ascend-npu).
When using Ascend NPU, please replace `"cuda"` with `"npu"` in your Python code. For details, see [NPU Support](../Pipeline_Usage/GPU_support.md#ascend-npu).
## Other Installation Issues

View File

@@ -140,7 +140,7 @@ image.save("image.jpg")
In more extreme cases, when memory is also insufficient to store the entire model, the Disk Offload feature allows lazy loading of model parameters, meaning each Layer of the model only reads the corresponding parameters from disk when the forward function is called. When enabling this feature, we recommend using high-speed SSD drives.
Disk Offload is a very special VRAM management solution that only supports `.safetensors` format files, not `.bin`, `.pth`, `.ckpt`, or other binary files, and does not support [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape.
Disk Offload is a very special VRAM management solution that only supports `.safetensors` format files, not `.bin`, `.pth`, `.ckpt`, or other binary files, and does not support [state dict converter](../Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape.
```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
@@ -196,7 +196,7 @@ Specifically, the VRAM management module divides model Layers into the following
* Preparing: Intermediate state between Onload and Computation. A temporary storage state when VRAM allows. This state is controlled by the VRAM management mechanism and enters this state if and only if [vram_limit is set to unlimited] or [vram_limit is set and there is spare VRAM]
* Computation: The model is being computed. This state is controlled by the VRAM management mechanism and is temporarily entered only during `forward`
If you are a model developer and want to control the VRAM management granularity of a specific model, please refer to [../Developer_Guide/Enabling_VRAM_management.md](/docs/en/Developer_Guide/Enabling_VRAM_management.md).
If you are a model developer and want to control the VRAM management granularity of a specific model, please refer to [../Developer_Guide/Enabling_VRAM_management.md](../Developer_Guide/Enabling_VRAM_management.md).
## Best Practices

View File

@@ -25,4 +25,11 @@ Even with suitable hardware conditions, we currently have no plans to support na
* The main challenge of native FP8 precision training is precision overflow caused by gradient explosion. To ensure training stability, the model structure needs to be redesigned accordingly. However, no model developers are willing to do so at present.
* Additionally, models trained with native FP8 precision can only be computed with BF16 precision during inference without Hopper architecture GPUs, theoretically resulting in generation quality inferior to FP8.
Therefore, native FP8 precision training technology is extremely immature. We will observe the technological developments in the open-source community.
Therefore, native FP8 precision training technology is extremely immature. We will observe the technological developments in the open-source community.
## How to dynamically load LoRA models during inference?
We support two loading methods for LoRA models. See [LoRA Loading](./Pipeline_Usage/Model_Inference.md#loading-lora) for details:
* Cold Loading: When [VRAM Management](./Pipeline_Usage/VRAM_management.md) is not enabled for the base model, LoRA will be fused into the base model weights. In this case, inference speed remains unchanged, and LoRA cannot be unloaded after loading.
* Hot Loading: When [VRAM Management](./Pipeline_Usage/VRAM_management.md) is enabled for the base model, LoRA will not be fused into the base model weights. In this case, inference speed will slow down, and LoRA can be unloaded after loading via `pipe.clear_lora()`.

View File

@@ -26,58 +26,60 @@ graph LR;
This section introduces the basic usage of `DiffSynth-Studio`, including how to enable VRAM management for inference on GPUs with extremely low VRAM, and how to train various base models, LoRAs, ControlNets, and other models.
* [Installation Dependencies](/docs/en/Pipeline_Usage/Setup.md)
* [Model Inference](/docs/en/Pipeline_Usage/Model_Inference.md)
* [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md)
* [Model Training](/docs/en/Pipeline_Usage/Model_Training.md)
* [Environment Variables](/docs/en/Pipeline_Usage/Environment_Variables.md)
* [GPU/NPU Support](/docs/en/Pipeline_Usage/GPU_support.md)
* [Installation Dependencies](./Pipeline_Usage/Setup.md)
* [Model Inference](./Pipeline_Usage/Model_Inference.md)
* [VRAM Management](./Pipeline_Usage/VRAM_management.md)
* [Model Training](./Pipeline_Usage/Model_Training.md)
* [Environment Variables](./Pipeline_Usage/Environment_Variables.md)
* [GPU/NPU Support](./Pipeline_Usage/GPU_support.md)
## Section 2: Model Details
This section introduces the Diffusion models supported by `DiffSynth-Studio`. Some model pipelines feature special functionalities such as controllable generation and parallel acceleration.
* [FLUX.1](/docs/en/Model_Details/FLUX.md)
* [Wan](/docs/en/Model_Details/Wan.md)
* [Qwen-Image](/docs/en/Model_Details/Qwen-Image.md)
* [FLUX.2](/docs/en/Model_Details/FLUX2.md)
* [Z-Image](/docs/en/Model_Details/Z-Image.md)
* [FLUX.1](./Model_Details/FLUX.md)
* [Wan](./Model_Details/Wan.md)
* [Qwen-Image](./Model_Details/Qwen-Image.md)
* [FLUX.2](./Model_Details/FLUX2.md)
* [Z-Image](./Model_Details/Z-Image.md)
* [Anima](./Model_Details/Anima.md)
* [LTX-2](./Model_Details/LTX-2.md)
## Section 3: Training Framework
This section introduces the design philosophy of the training framework in `DiffSynth-Studio`, helping developers understand the principles of Diffusion model training algorithms.
* [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md)
* [Standard Supervised Training](/docs/en/Training/Supervised_Fine_Tuning.md)
* [Enabling FP8 Precision in Training](/docs/en/Training/FP8_Precision.md)
* [End-to-End Distillation Accelerated Training](/docs/en/Training/Direct_Distill.md)
* [Two-Stage Split Training](/docs/en/Training/Split_Training.md)
* [Differential LoRA Training](/docs/en/Training/Differential_LoRA.md)
* [Basic Principles of Diffusion Models](./Training/Understanding_Diffusion_models.md)
* [Standard Supervised Training](./Training/Supervised_Fine_Tuning.md)
* [Enabling FP8 Precision in Training](./Training/FP8_Precision.md)
* [End-to-End Distillation Accelerated Training](./Training/Direct_Distill.md)
* [Two-Stage Split Training](./Training/Split_Training.md)
* [Differential LoRA Training](./Training/Differential_LoRA.md)
## Section 4: Model Integration
This section introduces how to integrate models into `DiffSynth-Studio` to utilize the framework's basic functions, helping developers provide support for new models in this project or perform inference and training of private models.
* [Integrating Model Architecture](/docs/en/Developer_Guide/Integrating_Your_Model.md)
* [Building a Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md)
* [Enabling Fine-Grained VRAM Management](/docs/en/Developer_Guide/Enabling_VRAM_management.md)
* [Model Training Integration](/docs/en/Developer_Guide/Training_Diffusion_Models.md)
* [Integrating Model Architecture](./Developer_Guide/Integrating_Your_Model.md)
* [Building a Pipeline](./Developer_Guide/Building_a_Pipeline.md)
* [Enabling Fine-Grained VRAM Management](./Developer_Guide/Enabling_VRAM_management.md)
* [Model Training Integration](./Developer_Guide/Training_Diffusion_Models.md)
## Section 5: API Reference
This section introduces the independent core module `diffsynth.core` in `DiffSynth-Studio`, explaining how internal functions are designed and operate. Developers can use these functional modules in other codebase developments if needed.
* [`diffsynth.core.attention`](/docs/en/API_Reference/core/attention.md): Attention mechanism implementation
* [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md): Data processing operators and general datasets
* [`diffsynth.core.gradient`](/docs/en/API_Reference/core/gradient.md): Gradient checkpointing
* [`diffsynth.core.loader`](/docs/en/API_Reference/core/loader.md): Model download and loading
* [`diffsynth.core.vram`](/docs/en/API_Reference/core/vram.md): VRAM management
* [`diffsynth.core.attention`](./API_Reference/core/attention.md): Attention mechanism implementation
* [`diffsynth.core.data`](./API_Reference/core/data.md): Data processing operators and general datasets
* [`diffsynth.core.gradient`](./API_Reference/core/gradient.md): Gradient checkpointing
* [`diffsynth.core.loader`](./API_Reference/core/loader.md): Model download and loading
* [`diffsynth.core.vram`](./API_Reference/core/vram.md): VRAM management
## Section 6: Academic Guide
This section introduces how to use `DiffSynth-Studio` to train new models, helping researchers explore new model technologies.
* Training models from scratch 【coming soon】
* [Training models from scratch](./Research_Tutorial/train_from_scratch.md)
* Inference improvement techniques 【coming soon】
* Designing controllable generation models 【coming soon】
* Creating new training paradigms 【coming soon】
@@ -86,4 +88,4 @@ This section introduces how to use `DiffSynth-Studio` to train new models, helpi
This section summarizes common developer questions. If you encounter issues during usage or development, please refer to this section. If you still cannot resolve the problem, please submit an issue on GitHub.
* [Frequently Asked Questions](/docs/en/QA.md)
* [Frequently Asked Questions](./QA.md)

View File

@@ -0,0 +1,476 @@
# Training Models from Scratch
DiffSynth-Studio's training engine supports training foundation models from scratch. This article introduces how to train a small text-to-image model with only 0.1B parameters from scratch.
## 1. Building Model Architecture
### 1.1 Diffusion Model
From UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) to DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206), the mainstream model architectures of Diffusion have undergone multiple evolutions. Typically, a Diffusion model's inputs include:
* Image tensor (`latents`): The encoding of images, generated by the VAE model, containing partial noise
* Text tensor (`prompt_embeds`): The encoding of text, generated by the text encoder
* Timestep (`timestep`): A scalar used to mark which stage of the Diffusion process we are currently at
The model's output is a tensor with the same shape as the image tensor, representing the denoising direction predicted by the model. For details about Diffusion model theory, please refer to [Basic Principles of Diffusion Models](../Training/Understanding_Diffusion_models.md). In this article, we build a DiT model with only 0.1B parameters: `AAADiT`.
<details>
<summary>Model Architecture Code</summary>
```python
import torch, accelerate
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange, repeat
from transformers import AutoProcessor, AutoTokenizer
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
from diffsynth.models.general_modules import TimestepEmbeddings
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
from diffsynth.models.flux2_vae import Flux2VAE
class AAAPositionalEmbedding(torch.nn.Module):
def __init__(self, height=16, width=16, dim=1024):
super().__init__()
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
def forward(self, image, text):
height, width = image.shape[-2:]
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
emb = torch.concat([image_emb, text_emb], dim=1)
return emb
class AAABlock(torch.nn.Module):
def __init__(self, dim=1024, num_heads=32):
super().__init__()
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.to_q = torch.nn.Linear(dim, dim)
self.to_k = torch.nn.Linear(dim, dim)
self.to_v = torch.nn.Linear(dim, dim)
self.to_out = torch.nn.Linear(dim, dim)
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.ff = torch.nn.Sequential(
torch.nn.Linear(dim, dim*3),
torch.nn.SiLU(),
torch.nn.Linear(dim*3, dim),
)
self.to_gate = torch.nn.Linear(dim, dim * 2)
self.num_heads = num_heads
def attention(self, emb, pos_emb):
emb = self.norm_attn(emb + pos_emb)
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
emb = attention_forward(
q, k, v,
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
dims={"n": self.num_heads},
)
emb = self.to_out(emb)
return emb
def feed_forward(self, emb, pos_emb):
emb = self.norm_mlp(emb + pos_emb)
emb = self.ff(emb)
return emb
def forward(self, emb, pos_emb, t_emb):
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
return emb
class AAADiT(torch.nn.Module):
def __init__(self, dim=1024):
super().__init__()
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
self.timestep_embedder = TimestepEmbeddings(256, dim)
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
self.proj_out = torch.nn.Linear(dim, 128)
def forward(
self,
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
pos_emb = self.pos_embedder(latents, prompt_embeds)
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
text = self.text_embedder(prompt_embeds)
emb = torch.concat([image, text], dim=1)
for block_id, block in enumerate(self.blocks):
emb = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
emb=emb,
pos_emb=pos_emb,
t_emb=t_emb,
)
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
emb = self.proj_out(emb)
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
return emb
```
</details>
### 1.2 Encoder-Decoder Models
Besides the Diffusion model used for denoising, we also need two other models:
* Text Encoder: Used to encode text into tensors. We adopt the [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) model.
* VAE Encoder-Decoder: The encoder part is used to encode images into tensors, and the decoder part is used to decode image tensors into images. We adopt the VAE model from [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B).
The architectures of these two models are already integrated in DiffSynth-Studio, located at [/diffsynth/models/z_image_text_encoder.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/models/z_image_text_encoder.py) and [/diffsynth/models/flux2_vae.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/models/flux2_vae.py), so we don't need to modify any code.
## 2. Building Pipeline
We introduced how to build a model Pipeline in the document [Integrating Pipeline](../Developer_Guide/Building_a_Pipeline.md). For the model in this article, we also need to build a Pipeline to connect the text encoder, Diffusion model, and VAE encoder-decoder.
<details>
<summary>Pipeline Code</summary>
```python
class AAAImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("FLUX.2")
self.text_encoder: ZImageTextEncoder = None
self.dit: AAADiT = None
self.vae: Flux2VAE = None
self.tokenizer: AutoProcessor = None
self.in_iteration_models = ("dit",)
self.units = [
AAAUnit_PromptEmbedder(),
AAAUnit_NoiseInitializer(),
AAAUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_aaa
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = None,
vram_limit: float = None,
):
# Initialize pipeline
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
pipe.dit = model_pool.fetch_model("aaa_dit")
pipe.vae = model_pool.fetch_model("flux2_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 30,
# Progress bar
progress_bar_cmd = tqdm,
):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
# Parameters
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"cfg_scale": cfg_scale,
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"])
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class AAAUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds",),
onload_model_names=("text_encoder",)
)
self.hidden_states_layers = (-1,)
def process(self, pipe: AAAImagePipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
text = pipe.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
return {"prompt_embeds": prompt_embeds}
class AAAUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class AAAUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: AAAImagePipeline, input_image, noise):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(image)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": input_latents}
def model_fn_aaa(
dit: AAADiT,
latents=None,
prompt_embeds=None,
timestep=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
model_output = dit(
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
return model_output
```
</details>
## 3. Preparing Dataset
To quickly verify training effectiveness, we use the dataset [Pokemon-First Generation](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1), which is reproduced from the open-source project [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh), containing 151 first-generation Pokemon from Bulbasaur to Mew. If you want to use other datasets, please refer to the document [Preparing Datasets](../Pipeline_Usage/Model_Training.md#preparing-datasets) and [`diffsynth.core.data`](../API_Reference/core/data.md).
```shell
modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data
```
### 4. Start Training
The training process can be quickly implemented using Pipeline. We have placed the complete code at [../Research_Tutorial/train_from_scratch.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/docs/en/Research_Tutorial/train_from_scratch.py), which can be directly started with `python docs/en/Research_Tutorial/train_from_scratch.py` for single GPU training.
To enable multi-GPU parallel training, please run `accelerate config` to set relevant parameters, then use the command `accelerate launch docs/en/Research_Tutorial/train_from_scratch.py` to start training.
This training script has no stopping condition, please manually close it when needed. The model converges after training approximately 60,000 steps, requiring 10-20 hours for single GPU training.
<details>
<summary>Training Code</summary>
```python
class AAATrainingModule(DiffusionTrainingModule):
def __init__(self, device):
super().__init__()
self.pipe = AAAImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device=device,
model_configs=[
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
self.pipe.freeze_except(["dit"])
self.pipe.scheduler.set_timesteps(1000, training=True)
def forward(self, data):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
inputs_shared = {
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
"cfg_scale": 1,
"use_gradient_checkpointing": False,
"use_gradient_checkpointing_offload": False,
}
for unit in self.pipe.units:
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
return loss
if __name__ == "__main__":
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
dataset = UnifiedDataset(
base_path="data/images",
metadata_path="data/metadata_merged.csv",
max_data_items=10000000,
data_file_keys=("image",),
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
)
model = AAATrainingModule(device=accelerator.device)
model_logger = ModelLogger(
"models/AAA/v1",
remove_prefix_in_ckpt="pipe.dit.",
)
launch_training_task(
accelerator, dataset, model, model_logger,
learning_rate=2e-4,
num_workers=4,
save_steps=50000,
num_epochs=999999,
)
```
</details>
## 5. Verifying Training Results
If you don't want to wait for the model training to complete, you can directly download [our pre-trained model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel).
```shell
modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel
```
Loading the model
```python
from diffsynth import load_model
pipe = AAAImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda")
```
Model inference, generating the first-generation Pokemon "starter trio". At this point, the images generated by the model basically match the training data.
```python
for seed, prompt in enumerate([
"green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws",
"orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws",
"blue, beige, brown, turtle, water type, shell, big eyes, short limbs, curled tail",
]):
image = pipe(
prompt=prompt,
negative_prompt=" ",
num_inference_steps=30,
cfg_scale=10,
seed=seed,
height=256, width=256,
)
image.save(f"image_{seed}.jpg")
```
|![Image](https://github.com/user-attachments/assets/3c620fbf-5d28-4a1a-b887-519d85ac7d1c)|![Image](https://github.com/user-attachments/assets/909efd4c-9e61-4b33-9321-39da0e499b00)|![Image](https://github.com/user-attachments/assets/f3474bcd-b474-4a90-a1ea-579f67e161e3)|
|-|-|-|
Model inference, generating Pokemon with "sharp claws". At this point, different random seeds can produce different image results.
```python
for seed, prompt in enumerate([
"sharp claws",
"sharp claws",
"sharp claws",
]):
image = pipe(
prompt=prompt,
negative_prompt=" ",
num_inference_steps=30,
cfg_scale=10,
seed=seed+4,
height=256, width=256,
)
image.save(f"image_sharp_claws_{seed}.jpg")
```
|![Image](https://github.com/user-attachments/assets/94862edd-96ae-4276-a38f-795249f11a13)|![Image](https://github.com/user-attachments/assets/b2291f23-20ba-42de-8bfd-76cb4afc6eea)|![Image](https://github.com/user-attachments/assets/f2aab9a4-85ec-498e-8039-648b1289796e)|
|-|-|-|
Now, we have obtained a 0.1B small text-to-image model. This model can already generate 151 Pokemon, but cannot generate other image content. If you increase the amount of data, model parameters, and number of GPUs based on this, you can train a more powerful text-to-image model!

View File

@@ -0,0 +1,341 @@
import torch, accelerate
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange, repeat
from transformers import AutoProcessor, AutoTokenizer
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
from diffsynth.models.general_modules import TimestepEmbeddings
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
from diffsynth.models.flux2_vae import Flux2VAE
class AAAPositionalEmbedding(torch.nn.Module):
def __init__(self, height=16, width=16, dim=1024):
super().__init__()
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
def forward(self, image, text):
height, width = image.shape[-2:]
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
emb = torch.concat([image_emb, text_emb], dim=1)
return emb
class AAABlock(torch.nn.Module):
def __init__(self, dim=1024, num_heads=32):
super().__init__()
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.to_q = torch.nn.Linear(dim, dim)
self.to_k = torch.nn.Linear(dim, dim)
self.to_v = torch.nn.Linear(dim, dim)
self.to_out = torch.nn.Linear(dim, dim)
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.ff = torch.nn.Sequential(
torch.nn.Linear(dim, dim*3),
torch.nn.SiLU(),
torch.nn.Linear(dim*3, dim),
)
self.to_gate = torch.nn.Linear(dim, dim * 2)
self.num_heads = num_heads
def attention(self, emb, pos_emb):
emb = self.norm_attn(emb + pos_emb)
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
emb = attention_forward(
q, k, v,
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
dims={"n": self.num_heads},
)
emb = self.to_out(emb)
return emb
def feed_forward(self, emb, pos_emb):
emb = self.norm_mlp(emb + pos_emb)
emb = self.ff(emb)
return emb
def forward(self, emb, pos_emb, t_emb):
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
return emb
class AAADiT(torch.nn.Module):
def __init__(self, dim=1024):
super().__init__()
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
self.timestep_embedder = TimestepEmbeddings(256, dim)
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
self.proj_out = torch.nn.Linear(dim, 128)
def forward(
self,
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
pos_emb = self.pos_embedder(latents, prompt_embeds)
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
text = self.text_embedder(prompt_embeds)
emb = torch.concat([image, text], dim=1)
for block_id, block in enumerate(self.blocks):
emb = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
emb=emb,
pos_emb=pos_emb,
t_emb=t_emb,
)
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
emb = self.proj_out(emb)
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
return emb
class AAAImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("FLUX.2")
self.text_encoder: ZImageTextEncoder = None
self.dit: AAADiT = None
self.vae: Flux2VAE = None
self.tokenizer: AutoProcessor = None
self.in_iteration_models = ("dit",)
self.units = [
AAAUnit_PromptEmbedder(),
AAAUnit_NoiseInitializer(),
AAAUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_aaa
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = None,
vram_limit: float = None,
):
# Initialize pipeline
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
pipe.dit = model_pool.fetch_model("aaa_dit")
pipe.vae = model_pool.fetch_model("flux2_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 30,
# Progress bar
progress_bar_cmd = tqdm,
):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
# Parameters
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"cfg_scale": cfg_scale,
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"])
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class AAAUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds",),
onload_model_names=("text_encoder",)
)
self.hidden_states_layers = (-1,)
def process(self, pipe: AAAImagePipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
text = pipe.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
return {"prompt_embeds": prompt_embeds}
class AAAUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class AAAUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: AAAImagePipeline, input_image, noise):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(image)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": input_latents}
def model_fn_aaa(
dit: AAADiT,
latents=None,
prompt_embeds=None,
timestep=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
model_output = dit(
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
return model_output
class AAATrainingModule(DiffusionTrainingModule):
def __init__(self, device):
super().__init__()
self.pipe = AAAImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device=device,
model_configs=[
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
self.pipe.freeze_except(["dit"])
self.pipe.scheduler.set_timesteps(1000, training=True)
def forward(self, data):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
inputs_shared = {
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
"cfg_scale": 1,
"use_gradient_checkpointing": False,
"use_gradient_checkpointing_offload": False,
}
for unit in self.pipe.units:
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
return loss
if __name__ == "__main__":
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
dataset = UnifiedDataset(
base_path="data/images",
metadata_path="data/metadata_merged.csv",
max_data_items=10000000,
data_file_keys=("image",),
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
)
model = AAATrainingModule(device=accelerator.device)
model_logger = ModelLogger(
"models/AAA/v1",
remove_prefix_in_ckpt="pipe.dit.",
)
launch_training_task(
accelerator, dataset, model, model_logger,
learning_rate=2e-4,
num_workers=4,
save_steps=50000,
num_epochs=999999,
)

View File

@@ -8,8 +8,8 @@ We were unable to identify the original proposer of differential LoRA training,
Assume we have two similar-content images: Image 1 and Image 2. For example, both images contain a car, but Image 1 has fewer details while Image 2 has more details. In differential LoRA training, we perform two-step training:
* Train LoRA 1 using Image 1 as training data with [standard supervised training](/docs/en/Training/Supervised_Fine_Tuning.md)
* Train LoRA 2 using Image 2 as training data, after integrating LoRA 1 into the base model, with [standard supervised training](/docs/en/Training/Supervised_Fine_Tuning.md)
* Train LoRA 1 using Image 1 as training data with [standard supervised training](../Training/Supervised_Fine_Tuning.md)
* Train LoRA 2 using Image 2 as training data, after integrating LoRA 1 into the base model, with [standard supervised training](../Training/Supervised_Fine_Tuning.md)
In the first training step, since there is only one training image, the LoRA model easily overfits. Therefore, after training, LoRA 1 will cause the model to generate Image 1 without hesitation, regardless of the random seed. In the second training step, the LoRA model overfits again. Thus, after training, with the combined effect of LoRA 1 and LoRA 2, the model will generate Image 2 without hesitation. In short:

View File

@@ -44,7 +44,7 @@ Click on the model links to go to the model pages and view the model effects.
## Using Distillation Accelerated Training in the Training Framework
First, you need to generate training data. Please refer to the [Model Inference](/docs/en/Pipeline_Usage/Model_Inference.md) section to write inference code and generate training data with a sufficient number of inference steps.
First, you need to generate training data. Please refer to the [Model Inference](../Pipeline_Usage/Model_Inference.md) section to write inference code and generate training data with a sufficient number of inference steps.
Taking Qwen-Image as an example, the following code can generate an image:
@@ -67,7 +67,7 @@ image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")
```
Then, we compile the necessary information into [metadata files](/docs/en/API_Reference/core/data.md#metadata):
Then, we compile the necessary information into [metadata files](../API_Reference/core/data.md#metadata):
```csv
image,prompt,seed,rand_device,num_inference_steps,cfg_scale
@@ -86,11 +86,11 @@ Then start LoRA distillation accelerated training:
bash examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh
```
Please note that in the [training script parameters](/docs/en/Pipeline_Usage/Model_Training.md#script-parameters), the image resolution setting for the dataset should avoid triggering scaling processing. When setting `--height` and `--width` to enable fixed resolution, all training data must be generated with exactly the same width and height. When setting `--max_pixels` to enable dynamic resolution, the value of `--max_pixels` must be greater than or equal to the pixel area of any training image.
Please note that in the [training script parameters](../Pipeline_Usage/Model_Training.md#script-parameters), the image resolution setting for the dataset should avoid triggering scaling processing. When setting `--height` and `--width` to enable fixed resolution, all training data must be generated with exactly the same width and height. When setting `--max_pixels` to enable dynamic resolution, the value of `--max_pixels` must be greater than or equal to the pixel area of any training image.
## Framework Design Concept
Compared to [Standard Supervised Training](/docs/en/Training/Supervised_Fine_Tuning.md), Direct Distillation only differs in the training loss function. The loss function for Direct Distillation is `DirectDistillLoss` in `diffsynth.diffusion.loss`.
Compared to [Standard Supervised Training](../Training/Supervised_Fine_Tuning.md), Direct Distillation only differs in the training loss function. The loss function for Direct Distillation is `DirectDistillLoss` in `diffsynth.diffusion.loss`.
## Future Work

View File

@@ -1,12 +1,12 @@
# Enabling FP8 Precision in Training
Although `DiffSynth-Studio` supports [VRAM management](/docs/en/Pipeline_Usage/VRAM_management.md) in model inference, most of the techniques for reducing VRAM usage are not suitable for training. Offloading would cause extremely slow training processes.
Although `DiffSynth-Studio` supports [VRAM management](../Pipeline_Usage/VRAM_management.md) in model inference, most of the techniques for reducing VRAM usage are not suitable for training. Offloading would cause extremely slow training processes.
FP8 precision is the only VRAM management strategy that can be enabled during training. However, this framework currently does not support native FP8 precision training. For reasons, see [Q&A: Why doesn't the training framework support native FP8 precision training?](/docs/en/QA.md#why-doesnt-the-training-framework-support-native-fp8-precision-training). It only supports storing models whose parameters are not updated by gradients (models that do not require gradient backpropagation, or whose gradients only update their LoRA) in FP8 precision.
FP8 precision is the only VRAM management strategy that can be enabled during training. However, this framework currently does not support native FP8 precision training. For reasons, see [Q&A: Why doesn't the training framework support native FP8 precision training?](../QA.md#why-doesnt-the-training-framework-support-native-fp8-precision-training). It only supports storing models whose parameters are not updated by gradients (models that do not require gradient backpropagation, or whose gradients only update their LoRA) in FP8 precision.
## Enabling FP8
In our provided training scripts, you can quickly set models to be stored in FP8 precision through the `--fp8_models` parameter. Taking Qwen-Image LoRA training as an example, we provide a script for enabling FP8 training located at [`/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh`](/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh). After training is completed, you can verify the training results with the script [`/examples/qwen_image/model_training/special/fp8_training/validate.py`](/examples/qwen_image/model_training/special/fp8_training/validate.py).
In our provided training scripts, you can quickly set models to be stored in FP8 precision through the `--fp8_models` parameter. Taking Qwen-Image LoRA training as an example, we provide a script for enabling FP8 training located at [`/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh). After training is completed, you can verify the training results with the script [`/examples/qwen_image/model_training/special/fp8_training/validate.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/qwen_image/model_training/special/fp8_training/validate.py).
Please note that this FP8 VRAM management strategy does not support gradient updates. When a model is set to be trainable, FP8 precision cannot be enabled for that model. Models that support FP8 include two types:

View File

@@ -8,7 +8,7 @@ This document introduces split training, which can automatically divide the trai
In the training process of most models, a large amount of computation occurs in "preprocessing," i.e., "computations unrelated to the denoising model," including VAE encoding, text encoding, etc. When the corresponding model parameters are fixed, the results of these computations are repetitive. For each data sample, the computational results are identical across multiple epochs. Therefore, we provide a "split training" feature that can automatically analyze and split the training process.
For standard supervised training of ordinary text-to-image models, the splitting process is straightforward. It only requires splitting the computation of all [`Pipeline Units`](/docs/en/Developer_Guide/Building_a_Pipeline.md#units) into the first stage, storing the computational results to disk, and then reading these results from disk in the second stage for subsequent computations. However, if gradient backpropagation is required during preprocessing, the situation becomes extremely complex. To address this, we introduced a computational graph splitting algorithm to analyze how to split the computation.
For standard supervised training of ordinary text-to-image models, the splitting process is straightforward. It only requires splitting the computation of all [`Pipeline Units`](../Developer_Guide/Building_a_Pipeline.md#units) into the first stage, storing the computational results to disk, and then reading these results from disk in the second stage for subsequent computations. However, if gradient backpropagation is required during preprocessing, the situation becomes extremely complex. To address this, we introduced a computational graph splitting algorithm to analyze how to split the computation.
## Computational Graph Splitting Algorithm
@@ -16,7 +16,7 @@ For standard supervised training of ordinary text-to-image models, the splitting
## Using Split Training
Split training already supports [Standard Supervised Training](/docs/en/Training/Supervised_Fine_Tuning.md) and [Direct Distillation Training](/docs/en/Training/Direct_Distill.md). The `--task` parameter in the training command controls this. Taking LoRA training of the Qwen-Image model as an example, the pre-split training command is:
Split training already supports [Standard Supervised Training](../Training/Supervised_Fine_Tuning.md) and [Direct Distillation Training](../Training/Direct_Distill.md). The `--task` parameter in the training command controls this. Taking LoRA training of the Qwen-Image model as an example, the pre-split training command is:
```shell
accelerate launch examples/qwen_image/model_training/train.py \

View File

@@ -1,10 +1,10 @@
# Standard Supervised Training
After understanding the [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md), this document introduces how the framework implements Diffusion model training. This document explains the framework's principles to help developers write new training code. If you want to use our provided default training functions, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md).
After understanding the [Basic Principles of Diffusion Models](../Training/Understanding_Diffusion_models.md), this document introduces how the framework implements Diffusion model training. This document explains the framework's principles to help developers write new training code. If you want to use our provided default training functions, please refer to [Model Training](../Pipeline_Usage/Model_Training.md).
Recalling the model training pseudocode from earlier, when we actually write code, the situation becomes extremely complex. Some models require additional guidance conditions and preprocessing, such as ControlNet; some models require cross-computation with the denoising model, such as VACE; some models require Gradient Checkpointing due to excessive VRAM demands, such as Qwen-Image's DiT.
To achieve strict consistency between inference and training, we abstractly encapsulate components like `Pipeline`, reusing inference code extensively during training. Please refer to [Integrating Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md) to understand the design of `Pipeline` components. Next, we'll introduce how the training framework utilizes `Pipeline` components to build training algorithms.
To achieve strict consistency between inference and training, we abstractly encapsulate components like `Pipeline`, reusing inference code extensively during training. Please refer to [Integrating Pipeline](../Developer_Guide/Building_a_Pipeline.md) to understand the design of `Pipeline` components. Next, we'll introduce how the training framework utilizes `Pipeline` components to build training algorithms.
## Framework Design Concept
@@ -48,13 +48,13 @@ In `__init__`, model initialization is required. First load the model, then swit
)
```
The logic for loading models is basically consistent with inference, supporting loading models from remote and local paths. See [Model Inference](/docs/en/Pipeline_Usage/Model_Inference.md) for details, but please note not to enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md).
The logic for loading models is basically consistent with inference, supporting loading models from remote and local paths. See [Model Inference](../Pipeline_Usage/Model_Inference.md) for details, but please note not to enable [VRAM Management](../Pipeline_Usage/VRAM_management.md).
`switch_pipe_to_training_mode` can switch the model to training mode. See `switch_pipe_to_training_mode` for details.
### `forward`
In `forward`, the loss function value needs to be calculated. First perform preprocessing, then compute the loss function through the `Pipeline`'s [`model_fn`](/docs/en/Developer_Guide/Building_a_Pipeline.md#model_fn).
In `forward`, the loss function value needs to be calculated. First perform preprocessing, then compute the loss function through the `Pipeline`'s [`model_fn`](../Developer_Guide/Building_a_Pipeline.md#model_fn).
```python
def forward(self, data):
@@ -90,7 +90,7 @@ The loss function calculation reuses `FlowMatchSFTLoss` from `diffsynth.diffusio
The training framework requires other modules, including:
* accelerator: Training launcher provided by `accelerate`, see [`accelerate`](https://huggingface.co/docs/accelerate/index) for details
* dataset: Generic dataset, see [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md) for details
* dataset: Generic dataset, see [`diffsynth.core.data`](../API_Reference/core/data.md) for details
* model_logger: Model logger, see `diffsynth.diffusion.logger` for details
```python

View File

@@ -6,7 +6,7 @@ This document introduces the basic principles of Diffusion models to help you un
Diffusion models generate clear images or video content through iterative denoising. We start by explaining the generation process of a data sample $x_0$. Intuitively, in a complete round of denoising, we start from random Gaussian noise $x_T$ and iteratively obtain $x_{T-1}$, $x_{T-2}$, $x_{T-3}$, $\cdots$, gradually reducing the noise content at each step until we finally obtain the noise-free data sample $x_0$.
(Figure)
![Image](https://github.com/user-attachments/assets/6471ae4c-a635-4924-8b36-b0bd4d42043d)
This process is intuitive, but to understand the details, we need to answer several questions:
@@ -28,7 +28,7 @@ As for the intermediate values $\sigma_{T-1}$, $\sigma_{T-2}$, $\cdots$, $\sigma
At an intermediate step, we can directly synthesize noisy data samples $x_t=(1-\sigma_t)x_0+\sigma_t x_T$.
(Figure)
![Image](https://github.com/user-attachments/assets/e25a2f71-123c-4e18-8b34-3a066af15667)
## How is the iterative denoising computation performed?
@@ -40,8 +40,6 @@ Before understanding the iterative denoising computation, we need to clarify wha
Among these, the guidance condition $c$ is a newly introduced parameter that is input by the user. It can be text describing the image content or a sketch outlining the image structure.
(Figure)
The model's output $\hat \epsilon(x_t,c,t)$ approximately equals $x_T-x_0$, which is the direction of the entire diffusion process (the reverse process of denoising).
Next, we analyze the computation occurring in one iteration. At time step $t$, after the model computes an approximation of $x_T-x_0$, we calculate the next $x_{t-1}$:
@@ -91,8 +89,6 @@ After understanding the iterative denoising process, we next consider how to tra
The training process differs from the generation process. If we retain multi-step iterations during training, the gradient would need to backpropagate through multiple steps, bringing catastrophic time and space complexity. To improve computational efficiency, we randomly select a time step $t$ for training.
(Figure)
The following is pseudocode for the training process:
> Obtain data sample $x_0$ and guidance condition $c$ from the dataset
@@ -113,7 +109,7 @@ The following is pseudocode for the training process:
From theory to practice, more details need to be filled in. Modern Diffusion model architectures have matured, with mainstream architectures following the "three-stage" architecture proposed by Latent Diffusion, including data encoder-decoder, guidance condition encoder, and denoising model.
(Figure)
![Image](https://github.com/user-attachments/assets/43855430-6427-4aca-83a0-f684e01438b1)
### Data Encoder-Decoder
@@ -142,4 +138,4 @@ The denoising model is the true essence of Diffusion models, with diverse model
## How does this project encapsulate and implement model training?
Please read the next document: [Standard Supervised Training](/docs/en/Training/Supervised_Fine_Tuning.md)
Please read the next document: [Standard Supervised Training](../Training/Supervised_Fine_Tuning.md)

124
docs/en/conf.py Normal file
View File

@@ -0,0 +1,124 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
# import sphinx_book_theme
sys.path.insert(0, os.path.abspath('../../'))
# -- Project information -----------------------------------------------------
project = 'diffsynth'
copyright = '2022-2025, Alibaba ModelScope'
author = 'ModelScope Authors'
version_file = '../../diffsynth/version.py'
html_theme = 'sphinx_rtd_theme'
language = 'en'
def get_version():
with open(version_file, 'r', encoding='utf-8') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__']
# The full version, including alpha/beta/rc tags
version = get_version()
release = version
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.napoleon',
'sphinx.ext.autosummary',
'sphinx.ext.autodoc',
'sphinx.ext.viewcode',
'sphinx_markdown_tables',
'sphinx_copybutton',
"sphinx_rtd_theme",
'sphinx.ext.mathjax',
'myst_parser',
]
# build the templated autosummary files
autosummary_generate = True
numpydoc_show_class_members = False
# Enable overriding of function signatures in the first line of the docstring.
autodoc_docstring_signature = True
# Disable docstring inheritance
autodoc_inherit_docstrings = False
# Show type hints in the description
autodoc_typehints = 'description'
# Add parameter types if the parameter is documented in the docstring
autodoc_typehints_description_target = 'documented_params'
autodoc_default_options = {
'member-order': 'bysource',
}
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = ['.rst', '.md']
# The master toctree document.
root_doc = 'index'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['build']
# A list of glob-style patterns [1] that are used to find source files.
# They are matched against the source file names relative to the source directory,
# using slashes as directory separators on all platforms.
# The default is **, meaning that all files are recursively included from the source directory.
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
# html_theme = 'sphinx_book_theme'
# html_theme_path = [sphinx_book_theme.get_html_theme_path()]
# html_theme_options = {}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# html_css_files = ['css/readthedocs.css']
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
# -- Extension configuration -------------------------------------------------
# Ignore >>> when copying code
copybutton_prompt_text = r'>>> |\.\.\. '
copybutton_prompt_is_regexp = True
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {'https://docs.python.org/': None}
myst_enable_extensions = [
'amsmath',
'dollarmath',
'colon_fence',
]

79
docs/en/index.rst Normal file
View File

@@ -0,0 +1,79 @@
Welcome to DiffSynth-Studio's Documentation
==========================================
.. toctree::
:maxdepth: 2
:caption: Documentation Introduction
README
.. toctree::
:maxdepth: 2
:caption: Getting Started
Pipeline_Usage/Setup
Pipeline_Usage/Model_Inference
Pipeline_Usage/VRAM_management
Pipeline_Usage/Model_Training
Pipeline_Usage/Environment_Variables
Pipeline_Usage/GPU_support
.. toctree::
:maxdepth: 2
:caption: Model Details
Model_Details/FLUX
Model_Details/Wan
Model_Details/Qwen-Image
Model_Details/FLUX2
Model_Details/Z-Image
Model_Details/Anima
Model_Details/LTX-2
.. toctree::
:maxdepth: 2
:caption: Training Framework
Training/Understanding_Diffusion_models
Training/Supervised_Fine_Tuning
Training/FP8_Precision
Training/Direct_Distill
Training/Split_Training
Training/Differential_LoRA
.. toctree::
:maxdepth: 2
:caption: Model Integration
Developer_Guide/Integrating_Your_Model
Developer_Guide/Building_a_Pipeline
Developer_Guide/Enabling_VRAM_management
Developer_Guide/Training_Diffusion_Models
.. toctree::
:maxdepth: 2
:caption: API Reference
API_Reference/core/attention
API_Reference/core/data
API_Reference/core/gradient
API_Reference/core/loader
API_Reference/core/vram
.. toctree::
:maxdepth: 2
:caption: Research Guide
Research_Tutorial/train_from_scratch
.. toctree::
:maxdepth: 2
:caption: FAQ
QA
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`

11
docs/requirements.txt Normal file
View File

@@ -0,0 +1,11 @@
docutils>=0.16.0
myst_parser
recommonmark
sphinx>=5.3.0
sphinx-book-theme
sphinx-copybutton
sphinx-autobuild
sphinx-rtd-theme
sphinx_markdown_tables
sphinxcontrib-mermaid
pymdown-extensions

28
docs/zh/.readthedocs.yaml Normal file
View File

@@ -0,0 +1,28 @@
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# Set the OS, Python version and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.10"
# Build documentation in the "docs/" directory with Sphinx
sphinx:
configuration: docs/zh/conf.py
# Optionally build your docs in additional formats such as PDF and ePub
# formats:
# - pdf
# - epub
# Optional but recommended, declare the Python requirements required
# to build your documentation
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
install:
- requirements: docs/requirements.txt

View File

@@ -1,6 +1,6 @@
# `diffsynth.core.attention`: 注意力机制实现
`diffsynth.core.attention` 提供了注意力机制实现的路由机制,根据 `Python` 环境中的可用包和[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation)自动选择高效的注意力机制实现。
`diffsynth.core.attention` 提供了注意力机制实现的路由机制,根据 `Python` 环境中的可用包和[环境变量](../../Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation)自动选择高效的注意力机制实现。
## 注意力机制
@@ -46,7 +46,7 @@ output_1 = attention(query, key, value)
* xFormers[GitHub](https://github.com/facebookresearch/xformers)、[文档](https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops)
* PyTorch[GitHub](https://github.com/pytorch/pytorch)、[文档](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
如需调用除 `PyTorch` 外的其他注意力实现,请按照其 GitHub 页面的指引安装对应的包。`DiffSynth-Studio` 会自动根据 Python 环境中的可用包路由到对应的实现上,也可通过[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation)控制。
如需调用除 `PyTorch` 外的其他注意力实现,请按照其 GitHub 页面的指引安装对应的包。`DiffSynth-Studio` 会自动根据 Python 环境中的可用包路由到对应的实现上,也可通过[环境变量](../../Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation)控制。
```python
from diffsynth.core.attention import attention_forward

Some files were not shown because too many files have changed in this diff Show More