Compare commits

...

239 Commits

Author SHA1 Message Date
mi804
2cefc20ed6 wanx tiled encode 2025-02-21 12:58:45 +08:00
mi804
02a4c8df9f wanx vae tile decode 2025-02-21 11:27:30 +08:00
mi804
582e33ad51 save_video 2025-02-20 17:57:38 +08:00
mi804
491bbf5369 support wanxvae 2025-02-20 17:44:20 +08:00
mi804
0c92f3b2cc support wanx prompter 2025-02-20 16:08:22 +08:00
Zhongjie Duan
427232cbc0 Merge pull request #328 from modelscope/stepvideo
Stepvideo low VRAM support!
2025-02-18 18:01:40 +08:00
Zhongjie Duan
2899283c01 Update stepvideo examples 2025-02-18 18:00:08 +08:00
Artiprocher
9cff769fbd optimize stepvideo vae 2025-02-18 17:28:05 +08:00
Zhongjie Duan
23e33273f1 Merge pull request #327 from modelscope/stepvideo
support stepvideo quantized
2025-02-17 19:44:41 +08:00
Artiprocher
f191353cf4 support stepvideo quantized 2025-02-17 19:43:47 +08:00
Zhongjie Duan
66a094fc84 Merge pull request #326 from modelscope/stepvideo
support stepvideo
2025-02-17 17:35:26 +08:00
Artiprocher
3681adc5ac support stepvideo 2025-02-17 17:32:25 +08:00
Zhongjie Duan
7434ec8fcd Merge pull request #324 from modelscope/vram_management
support vram management in flux
2025-02-14 10:54:55 +08:00
Artiprocher
0699212665 support vram management in flux 2025-02-13 15:11:39 +08:00
Zhongjie Duan
f47de78b59 Merge pull request #323 from mi804/eligen
update eligen dataset
2025-02-12 19:14:02 +08:00
mi804
5fdc8039ec update eligen dataset 2025-02-11 13:53:51 +08:00
Zhongjie Duan
46d4616e23 Update setup.py 2025-02-06 20:12:01 +08:00
Zhongjie Duan
2e597335be Merge pull request #320 from mi804/eligen
update eligen ui and readme
2025-01-24 16:40:45 +08:00
mi804
d346300162 update eligen ui and readme 2025-01-24 11:26:48 +08:00
Zhongjie Duan
1df7387f1b Merge pull request #318 from modelscope/hunyuanvideo-seed
fix rand device
2025-01-15 20:07:51 +08:00
Artiprocher
75d62a02d1 fix rand device 2025-01-15 19:30:38 +08:00
Zhongjie Duan
9db26879df Merge pull request #317 from mi804/eligen
update eligen logo_transfer
2025-01-14 17:49:03 +08:00
mi804
7beac7972e update eligen logo_transfer 2025-01-14 17:47:39 +08:00
Zhongjie Duan
72cac18d3e Merge pull request #316 from modelscope/teacache-hunyuanvideo
support teacache-hunyuanvideo
2025-01-14 14:48:04 +08:00
Artiprocher
9f8112ec34 support teacache-hunyuanvideo 2025-01-14 14:46:35 +08:00
Zhongjie Duan
d9fad821b2 Merge pull request #314 from modelscope/teacache
support teacache
2025-01-13 15:59:01 +08:00
Artiprocher
c0889c2564 support teacache 2025-01-13 15:56:33 +08:00
Zhongjie Duan
913591c13e Merge pull request #313 from modelscope/Artiprocher-patch-2
Update model_config.py
2025-01-12 11:15:18 +08:00
Zhongjie Duan
aaf13d6e4a Update model_config.py 2025-01-12 11:14:57 +08:00
Zhongjie Duan
90c07fec61 Merge pull request #312 from modelscope/HunyuanVideo-fp8
Update model_config.py
2025-01-11 20:41:20 +08:00
Zhongjie Duan
cc6c3c0807 Update model_config.py 2025-01-11 20:40:53 +08:00
Zhongjie Duan
ce2476ab9b Merge pull request #311 from mi804/eligen
update eligen readme's visualization
2025-01-09 16:26:54 +08:00
mi804
9e70c49317 update eligen readme 2025-01-09 16:22:39 +08:00
Zhongjie Duan
bf1c99645b Merge pull request #308 from mi804/eligen
fix bug for enable_eligen_on_negative
2025-01-09 15:52:03 +08:00
mi804
c2478ff284 update eligen examples and readme 2025-01-09 15:47:23 +08:00
mi804
a60bf3cd5f fix bug for enable_eligen_on_negative 2025-01-08 19:04:33 +08:00
Hong Zhang
34231907d0 Merge pull request #304 from modelscope/eligen-entity-transfer
add entity transfer example
2025-01-03 15:10:59 +08:00
Artiprocher
840dab58cd add entity transfer example 2025-01-03 14:40:37 +08:00
Zhongjie Duan
d5ceca0663 Merge pull request #303 from modelscope/eligen
Eligen
2025-01-03 10:47:26 +08:00
mi804
8cf3422688 update eligen ui 2025-01-03 10:37:34 +08:00
Artiprocher
6f743fc4b6 refine code 2025-01-02 19:54:09 +08:00
Zhongjie Duan
991b133bff Merge pull request #302 from modelscope/cache_latents
Update text_to_image.py
2025-01-02 14:23:33 +08:00
Zhongjie Duan
3b010043de Update text_to_image.py 2025-01-02 14:23:02 +08:00
Zhongjie Duan
088ea29e6e Merge pull request #301 from modelscope/Artiprocher-patch-1
Update model_config.py
2025-01-02 10:54:46 +08:00
Zhongjie Duan
b8b135ff73 Update model_config.py 2025-01-02 10:54:22 +08:00
mi804
2872fdaf48 update video of entity control 2024-12-31 18:09:29 +08:00
mi804
9853f83454 update readme video 2024-12-31 18:02:49 +08:00
mi804
fd6e661203 update readme 2024-12-31 17:50:20 +08:00
mi804
c087f68d74 update readme 2024-12-31 17:08:44 +08:00
mi804
b6620f3dde update_example entity control 2024-12-31 14:04:28 +08:00
Zhongjie Duan
3228c3e085 Support MERJIC's new model (#298)
* Update flux_dit.py
* Update model_config.py
2024-12-28 21:21:25 +08:00
Zhongjie Duan
6cc5fd6d1e Merge pull request #297 from modelscope/dev
Dev
2024-12-26 10:21:50 +08:00
Artiprocher
4f6d5e7074 hunyuanvideo step_processor 2024-12-26 10:20:59 +08:00
Artiprocher
6a999e1127 hunyuanvideo step_processor 2024-12-26 10:13:46 +08:00
mi804
e3d89cec0c temp commit for entity control 2024-12-25 17:19:31 +08:00
Zhongjie Duan
1b6e96a820 Merge pull request #296 from modelscope/dev
update hunyuanvideo examples
2024-12-24 10:48:11 +08:00
Artiprocher
e38ccf4c2f update hunyuanvideo examples 2024-12-24 10:47:26 +08:00
Zhongjie Duan
010c801081 Update hunyuanvideo_v2v_6G.py 2024-12-23 20:57:58 +08:00
Zhongjie Duan
edc9272e55 Merge pull request #295 from modelscope/dev
support hunyuanvideo v2v
2024-12-23 20:56:04 +08:00
Artiprocher
405ca6be33 support hunyuanvideo v2v 2024-12-23 20:43:47 +08:00
Zhongjie Duan
c06ea2271a Merge pull request #293 from modelscope/dev
hunyuanvideo quantization
2024-12-19 16:20:35 +08:00
Artiprocher
0692e8b1e1 hunyuanvideo quantization 2024-12-19 16:20:11 +08:00
Zhongjie Duan
aa23356420 Merge pull request #292 from modelscope/dev
hunyuanvideo examples
2024-12-19 13:29:51 +08:00
Zhongjie Duan
00a610e5ad Merge branch 'main' into dev 2024-12-19 13:29:40 +08:00
Artiprocher
2e39dcc0d3 hunyuanvideo examples 2024-12-19 13:28:44 +08:00
Zhongjie Duan
03d3a26f6f Merge pull request #291 from modelscope/dev
hunyuanvideo examples
2024-12-19 13:20:18 +08:00
Artiprocher
309fa9cf51 hunyuanvideo examples 2024-12-19 13:19:39 +08:00
Zhongjie Duan
65aab8adea Merge pull request #290 from modelscope/dev
Dev
2024-12-19 13:16:55 +08:00
Artiprocher
3d48b287a3 hunyuanvideo examples 2024-12-19 13:15:06 +08:00
Zhongjie Duan
29cebf0bec Update artaug_flux.py 2024-12-18 20:43:53 +08:00
Zhongjie Duan
95a0f0bedc Update README.md 2024-12-18 20:42:50 +08:00
Zhongjie Duan
77e0617861 Merge pull request #289 from modelscope/artaug
Artaug
2024-12-18 20:40:13 +08:00
Artiprocher
469a0405a1 ArtAug 2024-12-18 20:32:23 +08:00
Zhongjie Duan
46f191ffe7 Merge pull request #288 from mi804/hunyuanvideo
Hunyuanvideo
2024-12-18 19:40:23 +08:00
Artiprocher
ec7ac20def hunyuanvideo text encoder offload 2024-12-18 19:35:04 +08:00
mi804
3f410b0b77 hunyuanvideo_vae_encoder 2024-12-18 19:03:04 +08:00
mi804
8e06cac0df vae_encoder_weightsloading 2024-12-18 17:37:46 +08:00
Artiprocher
e5099f4e74 hunyuanvideo 2024-12-18 16:43:06 +08:00
Zhongjie Duan
447adef472 Merge pull request #287 from modelscope/dev-dzj
hunyuanvideo pipeline
2024-12-18 11:47:44 +08:00
Zhongjie Duan
a849b05e5a Merge branch 'dev' into dev-dzj 2024-12-18 11:47:34 +08:00
Artiprocher
b048f1b1de hunyuanvideo pipeline 2024-12-18 11:42:43 +08:00
Zhongjie Duan
f7848f9560 Merge pull request #286 from mi804/hunyuanvideo
hunyuanvideo_vae_decoder
2024-12-18 11:35:06 +08:00
mi804
236b56d285 hunyuanvideo_vae_decoder_model 2024-12-18 11:31:33 +08:00
Zhongjie Duan
42a717054a Merge branch 'dev' into hunyuanvideo 2024-12-18 11:21:33 +08:00
mi804
263166768e hunyuanvideo_vae_decoder 2024-12-18 11:14:57 +08:00
Zhongjie Duan
7a45b7efa7 Merge pull request #284 from modelscope/dev-dzj
hunyuanvideo dit
2024-12-17 14:50:21 +08:00
Zhongjie Duan
54ed532e3e Merge branch 'dev' into dev-dzj 2024-12-17 14:49:46 +08:00
Artiprocher
05e2028c5d hunyuanvideo dit 2024-12-17 14:45:23 +08:00
Zhongjie Duan
79249063b8 Merge pull request #283 from mi804/hunyuanvideo
hunyuanvideo text encoder
2024-12-17 14:42:46 +08:00
Zhongjie Duan
31ebec7a72 Merge pull request #282 from modelscope/lora-patch-2
support resume from opensource format
2024-12-16 12:26:37 +08:00
Artiprocher
919d399fdb support resume from opensource format 2024-12-16 12:25:05 +08:00
Zhongjie Duan
32a7a1487d Merge pull request #281 from modelscope/lora-patch
support resume training
2024-12-16 11:10:32 +08:00
Artiprocher
8c2671ce40 support resume training 2024-12-16 11:08:14 +08:00
root
5d1005a7c8 hunyuanvideo text encoder 2024-12-11 18:52:42 +08:00
Artiprocher
b84f906964 support artaug 2024-12-03 15:30:01 +08:00
Zhongjie Duan
7c0520d029 Merge pull request #277 from modelscope/sd35-lora
support sd35-lora
2024-11-29 12:35:32 +08:00
Artiprocher
9d09121fbc support sd35-lora 2024-11-29 11:45:40 +08:00
Zhongjie Duan
7f2a5424d4 Merge pull request #276 from modelscope/Artiprocher-patch-2
Update flux_ipadapter example
2024-11-28 10:44:29 +08:00
Zhongjie Duan
00830f0ecd Update flux_ipadapter.py 2024-11-28 10:44:07 +08:00
Zhongjie Duan
fd7737af7d Merge pull request #275 from mi804/flux_ipadapter
Flux ipadapter
2024-11-28 10:43:06 +08:00
root
f2130c4c25 minor 2024-11-26 19:08:41 +08:00
root
4f40683fd8 support flux ipadapter 2024-11-26 18:08:50 +08:00
Zhongjie Duan
5fc9e53eec Merge pull request #272 from modelscope/fix_kolors_pad
fix_kolors_pad
2024-11-21 14:50:21 +08:00
tc2000731
27e3cea285 fix_kolors_pad 2024-11-21 11:39:28 +08:00
Zhongjie Duan
ee770fa68f Merge pull request #271 from modelscope/sd35-series
Sd35 series
2024-11-20 09:54:41 +08:00
Artiprocher
9cb4aa16eb fix cogvideo height width checker 2024-11-20 09:51:31 +08:00
Zhongjie Duan
92d990629f Merge pull request #269 from modelscope/fix_image_resize
fix_image_resize
2024-11-18 19:24:57 +08:00
tc2000731
ba58f1bc0b fix_image_resize 2024-11-18 18:34:21 +08:00
Artiprocher
02fcfd530f support sd3.5 medium and large-turbo 2024-11-15 14:20:39 +08:00
Zhongjie Duan
095e8a3de8 Merge pull request #265 from modelscope/dev
support height width checker
2024-11-13 12:39:56 +08:00
Artiprocher
e17ad83fb5 support height width checker 2024-11-13 12:39:09 +08:00
Zhongjie Duan
e7c41151ec Merge pull request #264 from modelscope/dev
Dev
2024-11-13 09:53:49 +08:00
Artiprocher
7f4ba62d4f support size checker 2024-11-12 19:41:09 +08:00
Artiprocher
71b17a3a53 update mask blur 2024-11-12 19:20:17 +08:00
Artiprocher
d46b8b8fd7 bux fix 2024-11-12 10:17:01 +08:00
Artiprocher
a671070a28 bug fix 2024-11-11 21:01:38 +08:00
Zhongjie Duan
4600d5351b Update model_config.py 2024-11-11 19:26:30 +08:00
Zhongjie Duan
75bba5b8e5 Merge pull request #263 from modelscope/super-alignment
support mask blur
2024-11-11 19:24:30 +08:00
Artiprocher
8d1d1536d3 support mask blur 2024-11-11 18:59:55 +08:00
Zhongjie Duan
a7050a185b Merge pull request #262 from modelscope/sd3.5
Sd3.5
2024-11-11 18:47:49 +08:00
Zhongjie Duan
d345541c2d Merge pull request #261 from modelscope/omnigen
support omnigen
2024-11-11 18:47:09 +08:00
Artiprocher
bd028e4c66 support omnigen 2024-11-11 18:39:40 +08:00
Zhongjie Duan
d6f4fb67cc Merge pull request #260 from mi804/sd3.5
update default t5_sequence_length to 77
2024-11-11 16:39:31 +08:00
mi804
4378b540cf update t5_sequence_length 2024-11-11 16:28:17 +08:00
Artiprocher
39ddb7c3e3 support sd3.5 2024-11-06 19:57:01 +08:00
Zhongjie Duan
344cbd3286 Merge pull request #258 from modelscope/Artiprocher-patch-2
Update README.md
2024-11-05 19:09:04 +08:00
Zhongjie Duan
d4ba173b53 Update README.md 2024-11-05 19:08:52 +08:00
Zhongjie Duan
c56ce656b2 Merge pull request #252 from modelscope/Flux_ControlNet_Quantization
add Flux_ControlNet_Quantization
2024-11-01 14:51:10 +08:00
tc2000731
9377214518 update controlnet_frames, downloads 2024-10-31 17:38:57 +08:00
tc2000731
900a1c095f add Flux_ControlNet_Quantization 2024-10-29 17:29:24 +08:00
Zhongjie Duan
7e97a96840 Merge pull request #249 from modelscope/newpush
update noise generate
2024-10-25 16:43:37 +08:00
Zhongjie Duan
69f272d7ba Merge pull request #251 from modelscope/flux-examples
Flux examples
2024-10-25 16:35:47 +08:00
Artiprocher
a653554bd9 update examples 2024-10-25 16:30:35 +08:00
Artiprocher
6a25006544 update examples 2024-10-25 16:27:19 +08:00
Qianyi Zhao
8cfe4820f6 Update sd_video.py 2024-10-25 03:23:01 -05:00
Qianyi Zhao
c8021d4224 Update svd_video.py 2024-10-25 01:44:09 -05:00
Zhongjie Duan
3a64cc27b5 Merge pull request #250 from modelscope/flux-controlnet
Flux controlnet
2024-10-25 10:58:37 +08:00
Zhongjie Duan
2edc485ec1 Update requirements.txt 2024-10-25 00:16:11 +08:00
Artiprocher
a6d6553cee bug fix 2024-10-24 17:36:22 +08:00
Artiprocher
45feef9413 update model config 2024-10-24 16:10:15 +08:00
Artiprocher
105fe3961c update examples 2024-10-24 15:42:46 +08:00
Qianyi Zhao
d381c7b186 Update svd_video.py 2024-10-23 03:27:59 -05:00
Zhongjie Duan
5e8334c0bf Merge pull request #248 from modelscope/Artiprocher-patch-1
Update requirements.txt
2024-10-23 16:03:35 +08:00
Zhongjie Duan
2ea8a16afb Update requirements.txt 2024-10-23 16:03:21 +08:00
Artiprocher
aa054db1c7 bug fix 2024-10-23 14:24:41 +08:00
Artiprocher
07d70a6a56 support flux-controlnet 2024-10-22 18:52:24 +08:00
Qing112
747572e62c update noise generate 2024-10-21 15:09:21 +08:00
Zhongjie Duan
72ed76e89e Merge pull request #243 from modelscope/flux-lora
support preset lora
2024-10-21 14:04:44 +08:00
Artiprocher
a403cb04f3 support preset lora 2024-10-21 14:03:58 +08:00
Zhongjie Duan
ed71184854 Merge pull request #242 from modelscope/accelerate_load_model
accelerate load model
2024-10-21 10:00:09 +08:00
tc2000731
dfbf43e463 accelerate load model 2024-10-18 15:29:50 +08:00
Zhongjie Duan
7d7d72dcfe Merge pull request #239 from modelscope/flux-lora-update
Flux lora update
2024-10-14 19:12:33 +08:00
Artiprocher
540c036988 add alpha to lora converter 2024-10-14 18:57:54 +08:00
Artiprocher
58f89ceec9 update examples 2024-10-14 17:51:12 +08:00
Artiprocher
4e3a184199 update flux training 2024-10-14 10:00:32 +08:00
Zhongjie Duan
22e4ae99e8 Flux lora update (#237)
* update flux lora

---------

Co-authored-by: tc2000731 <tc2000731@163.com>
2024-10-11 18:41:24 +08:00
Zhongjie Duan
75ab786afc Merge pull request #234 from modelscope/doc-patch
Patch
2024-10-10 19:17:00 +08:00
Artiprocher
e5c72ba1f2 update examples 2024-10-10 18:26:37 +08:00
Artiprocher
66873d7d64 update examples 2024-10-10 18:23:43 +08:00
Artiprocher
a0d1d5bcea update examples 2024-10-10 17:25:55 +08:00
Artiprocher
fa0fa95bb6 update flux pipeline 2024-10-10 17:05:04 +08:00
Artiprocher
41ea2f811a update ESRGAN 2024-10-08 18:23:39 +08:00
Artiprocher
ec352cfce2 update model loader 2024-10-08 16:46:44 +08:00
Zhongjie Duan
aade874241 Merge pull request #232 from modelscope/Artiprocher-patch-1
Update README.md
2024-10-08 13:37:12 +08:00
Zhongjie Duan
c01eb653d7 Update README.md 2024-10-08 13:36:56 +08:00
Zhongjie Duan
892f80c265 Merge pull request #230 from modelscope/Artiprocher-dev
support ExVideo-CogVideoX-LoRA-129f-v1
2024-09-30 17:42:49 +08:00
Artiprocher
2e487a2c55 support ExVideo-CogVideoX-LoRA-129f-v1 2024-09-30 17:33:15 +08:00
Zhongjie Duan
a34e3ba338 Merge pull request #229 from modelscope/flux-enhance
support t5 sequence length
2024-09-30 15:33:51 +08:00
Artiprocher
c414f4cb12 support t5 sequence length 2024-09-30 14:45:30 +08:00
Zhongjie Duan
d91c603875 Flux fp8 lora training (#221)
* flux fp8 lora training

---------

Co-authored-by: tc2000731 <tc2000731@163.com>
2024-09-24 11:12:32 +08:00
Zhongjie Duan
7f899dcfca Merge pull request #216 from modelscope/Artiprocher-bugfix
bug fix
2024-09-19 12:27:22 +08:00
Artiprocher
5f12fd4346 bug fix 2024-09-19 12:26:46 +08:00
Zhongjie Duan
a7197f846b Merge pull request #215 from modelscope/flux-fp8
Support FLUX fp8
2024-09-19 10:36:16 +08:00
Artiprocher
ac81fa7a9f update examples 2024-09-19 10:33:30 +08:00
Artiprocher
091df1f1e7 support flux-fp8 2024-09-19 10:32:16 +08:00
tc2000731
a9fbfa108f float8_flux 2024-09-18 16:10:59 +08:00
Zhongjie Duan
44a8bf4143 Merge pull request #210 from modelscope/opensource-alignment
staticmethod
2024-09-14 17:18:19 +08:00
Artiprocher
3da8aa257b staticmethod 2024-09-14 17:16:59 +08:00
Zhongjie Duan
884dd749a0 Merge pull request #209 from modelscope/Artiprocher-patch-1
Update model_config.py
2024-09-14 11:42:30 +08:00
Zhongjie Duan
c697591d6e Update model_config.py 2024-09-14 11:41:47 +08:00
Zhongjie Duan
0b706e03e7 Merge pull request #208 from Qing112/main
update model_config and downloader
2024-09-14 11:40:42 +08:00
Qing112
447e75cd06 update model_config and downloader 2024-09-14 11:35:01 +08:00
Zhongjie Duan
7f76c8809c Merge pull request #207 from modelscope/flux-schnell
support flux-schnell
2024-09-14 11:17:59 +08:00
Artiprocher
cde1f81df6 support flux-schnell 2024-09-14 11:16:03 +08:00
Zhongjie Duan
c21ed1e478 Flux lora (#205) 2024-09-12 16:49:30 +08:00
Zhongjie Duan
a8cb4a21d1 align flux lora format (#204) 2024-09-12 16:01:27 +08:00
Zhongjie Duan
0b9e673fa2 Merge pull request #199 from modelscope/examples
update examples
2024-09-10 17:45:44 +08:00
Artiprocher
d242af8e22 update examples 2024-09-10 17:36:35 +08:00
Hong Zhang
76bd931d79 refine system_prompt for QwenPrompt (#198) 2024-09-10 15:15:23 +08:00
ZhouTianchen
995f3374f1 update omost (#190)
* update omost
2024-09-09 17:39:46 +08:00
Zhongjie Duan
1887885274 Merge pull request #197 from mi804/cpuoffload
add cpuoffload support for image pipelines
2024-09-09 14:48:26 +08:00
mi804
ce43cf412d add cpuoffload support for image pipelines 2024-09-09 13:50:52 +08:00
Zhongjie Duan
d1712f0594 Merge pull request #194 from modelscope/flux-lora
support flux training
2024-09-06 19:15:42 +08:00
Artiprocher
416b73b8c0 support flux training 2024-09-06 10:37:28 +08:00
Zhongjie Duan
4654aa0cab Merge pull request #188 from modelscope/qwen
support Qwen prompt refine
2024-09-04 17:22:56 +08:00
Zhongjie Duan
6f9d8f465a Merge branch 'main' into qwen 2024-09-04 17:22:38 +08:00
Artiprocher
e5e55345dc support qwen prompt refiner 2024-09-04 17:12:01 +08:00
Zhongjie Duan
8d6eb6d41a Merge pull request #187 from modelscope/omost
support Omost LLM
2024-09-04 12:52:23 +08:00
Zhongjie Duan
1118e67cec Merge branch 'main' into omost 2024-09-04 12:52:03 +08:00
Artiprocher
d70cd04b15 fix bugs 2024-09-04 12:48:32 +08:00
Zhongjie Duan
3d1db23224 Merge pull request #186 from modelscope/flux-lora
support flux lora inference
2024-09-04 09:47:08 +08:00
Artiprocher
a488810693 support flux lora inference 2024-09-04 09:39:39 +08:00
tc2000731
0b066d3cb4 add omost.py + omost_flux_example 2024-09-03 19:40:40 +08:00
Zhongjie Duan
d154bee18a support CogVideoX-5B (#184)
* support cogvideo

* update examples
2024-09-03 11:37:54 +08:00
Yudi
3a8694b642 add qwen prompt refiner 2024-08-27 17:28:32 +08:00
Zhongjie Duan
fe485b3fa1 Merge pull request #176 from modelscope/Artiprocher-dev
remove packages from requirements.txt
2024-08-26 15:02:59 +08:00
Artiprocher
e70eaa6a31 remove packages from requirements.txt 2024-08-26 15:01:35 +08:00
Zhongjie Duan
27ef67306d Merge pull request #175 from modelscope/Artiprocher-dev
model cache
2024-08-26 13:57:48 +08:00
Artiprocher
547aca3db2 model cache 2024-08-26 13:57:03 +08:00
Zhongjie Duan
5f7360e2ce Merge pull request #171 from modelscope/Artiprocher-dev
update README
2024-08-23 16:47:13 +08:00
Artiprocher
23f9675218 update README 2024-08-23 16:46:26 +08:00
Zhongjie Duan
ef1e82076c Merge pull request #170 from modelscope/Artiprocher-dev
update model config
2024-08-23 14:18:15 +08:00
Artiprocher
65d4588cc7 update model config 2024-08-23 14:17:10 +08:00
Zhongjie Duan
0488f90c8f Merge pull request #169 from modelscope/Artiprocher-dev
fix bug
2024-08-23 09:28:46 +08:00
Artiprocher
03d91f6618 fix bug 2024-08-23 09:28:10 +08:00
Zhongjie Duan
ae5e4b67dc Merge pull request #166 from modelscope/Artiprocher-dev
update examples
2024-08-22 11:48:50 +08:00
Artiprocher
a6c6e33d88 update examples 2024-08-22 11:41:48 +08:00
Zhongjie Duan
79d9bf7109 Merge pull request #165 from modelscope/Artiprocher-dev
update UI
2024-08-22 10:45:23 +08:00
Artiprocher
66e1b382cd update examples 2024-08-22 10:37:30 +08:00
Artiprocher
66f1ff43e9 update examples 2024-08-22 10:35:58 +08:00
Artiprocher
d6d14859e3 update UI 2024-08-21 16:57:56 +08:00
Zhongjie Duan
4478bb9bbe Merge pull request #164 from modelscope/Artiprocher-dev
FLUX highres-fix
2024-08-20 13:40:23 +08:00
Artiprocher
a6aaf9da2a support flux UI 2024-08-19 14:24:23 +08:00
Artiprocher
aa908ae0c2 support flux highresfix 2024-08-19 13:35:40 +08:00
Artiprocher
778a2d8f84 support flux highresfix 2024-08-19 13:35:27 +08:00
Zhongjie Duan
508baabf9a Merge pull request #160 from modelscope/Artiprocher-dev
support FLUX
2024-08-17 17:52:59 +08:00
Artiprocher
80aa4d8e19 update examples 2024-08-17 17:51:31 +08:00
Artiprocher
99e11112a7 support FLUX 2024-08-16 20:04:10 +08:00
Zhongjie Duan
1116e6dbc7 Merge pull request #155 from Qing112/main
add Flux text encoder
2024-08-14 11:28:14 +08:00
Qianyi Zhao
d1ac96c1ab add flux_text_encoder.py 2024-08-13 22:26:10 -05:00
Qianyi Zhao
abe88c899e add Flux text encoder 2024-08-14 10:46:52 +08:00
Zhongjie Duan
b1709fcbdb Merge pull request #145 from modelscope/Artiprocher-dev
chatglm quantize
2024-08-02 15:09:41 +08:00
Artiprocher
ec877bf490 chatglm quantize 2024-08-02 14:46:29 +08:00
Zhongjie Duan
a8f1812acf Merge pull request #144 from modelscope/Artiprocher-dev
UI update
2024-08-02 13:49:48 +08:00
Artiprocher
6877b460c4 fix bugs 2024-08-02 13:47:07 +08:00
Artiprocher
f189f9f1be update UI 2024-08-02 10:31:25 +08:00
Artiprocher
6f79fd6d77 support sdxl controlnet union 2024-08-01 10:01:39 +08:00
Zhongjie Duan
60d7bb52d6 Update README.md 2024-07-30 10:42:43 +08:00
Yingda Chen
65a2a0643a add badges 2024-07-30 10:32:03 +08:00
150 changed files with 1599088 additions and 951 deletions

150
README.md
View File

@@ -1,15 +1,26 @@
# DiffSynth Studio # DiffSynth Studio
[![PyPI](https://img.shields.io/pypi/v/DiffSynth)](https://pypi.org/project/DiffSynth/)
[![license](https://img.shields.io/github/license/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
[![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)
[![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
[![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
<p align="center"> <p align="center">
<a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</p> </p>
Document: https://diffsynth-studio.readthedocs.io/zh-cn/latest/index.html
## Introduction ## Introduction
DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models! DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
Until now, DiffSynth Studio has supported the following models: Until now, DiffSynth Studio has supported the following models:
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo)
* [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) * [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors) * [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
* [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium) * [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
@@ -25,6 +36,38 @@ Until now, DiffSynth Studio has supported the following models:
## News ## News
- **February 17, 2024** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! State-of-the-art video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).
- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
- Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
- Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
- **December 19, 2024** We implement advanced VRAM management for HunyuanVideo, making it possible to generate videos at a resolution of 129x720x1280 using 24GB of VRAM, or at 129x512x384 resolution with just 6GB of VRAM. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
- **December 18, 2024** We propose ArtAug, an approach designed to improve text-to-image synthesis models through synthesis-understanding interactions. We have trained an ArtAug enhancement module for FLUX.1-dev in the format of LoRA. This model integrates the aesthetic understanding of Qwen2-VL-72B into FLUX.1-dev, leading to an improvement in the quality of generated images.
- Paper: https://arxiv.org/abs/2412.12888
- Examples: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
- Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (Coming soon)
- **October 25, 2024** We provide extensive FLUX ControlNet support. This project supports many different ControlNet models that can be freely combined, even if their structures differ. Additionally, ControlNet models are compatible with high-resolution refinement and partition control techniques, enabling very powerful controllable image generation. See [`./examples/ControlNet/`](./examples/ControlNet/).
- **October 8, 2024.** We release the extended LoRA based on CogVideoX-5B and ExVideo. You can download this model from [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) or [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1).
- **August 22, 2024.** CogVideoX-5B is supported in this project. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including
- Text to video
- Video editing
- Self-upscaling
- Video interpolation
- **August 22, 2024.** We have implemented an interesting painter that supports all text-to-image models. Now you can create stunning images using the painter, with assistance from AI!
- Use it in our [WebUI](#usage-in-webui).
- **August 21, 2024.** FLUX is supported in DiffSynth-Studio.
- Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
- LoRA, ControlNet, and additional models will be available soon.
- **June 21, 2024.** 🔥🔥🔥 We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames. - **June 21, 2024.** 🔥🔥🔥 We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/) - [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
@@ -67,37 +110,64 @@ Until now, DiffSynth Studio has supported the following models:
## Installation ## Installation
Install from source code (recommended):
``` ```
git clone https://github.com/modelscope/DiffSynth-Studio.git git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio cd DiffSynth-Studio
pip install -e . pip install -e .
``` ```
Or install from pypi:
```
pip install diffsynth
```
## Usage (in Python code) ## Usage (in Python code)
The Python examples are in [`examples`](./examples/). We provide an overview here. The Python examples are in [`examples`](./examples/). We provide an overview here.
### Long Video Synthesis ### Download Models
We trained an extended video synthesis model, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/) Download the pre-set models. Model IDs can be found in [config file](/diffsynth/configs/model_config.py).
```python
from diffsynth import download_models
download_models(["FLUX.1-dev", "Kolors"])
```
Download your own models.
```python
from diffsynth.models.downloader import download_from_huggingface, download_from_modelscope
# From Modelscope (recommended)
download_from_modelscope("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp16.bin", "models/kolors/Kolors/vae")
# From Huggingface
download_from_huggingface("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp16.safetensors", "models/kolors/Kolors/vae")
```
### Video Synthesis
#### Text-to-video using CogVideoX-5B
CogVideoX-5B is released by ZhiPu. We provide an improved pipeline, supporting text-to-video, video editing, self-upscaling and video interpolation. [`examples/video_synthesis`](./examples/video_synthesis/)
The video on the left is generated using the original text-to-video pipeline, while the video on the right is the result after editing and frame interpolation.
https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
#### Long Video Synthesis
We trained extended video synthesis models, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
### Image Synthesis https://github.com/user-attachments/assets/321ee04b-8c17-479e-8a95-8cbcf21f8d7e
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/). #### Toon Shading
LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
|Model|Example|
|-|-|
|Stable Diffusion|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|
|Stable Diffusion XL|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|
|Stable Diffusion 3|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|
|Kolors|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/53ef6f41-da11-4701-8665-9f64392607bf)|
|Hunyuan-DiT|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)|
### Toon Shading
Render realistic videos in a flatten style and enable video editing features. [`examples/Diffutoon`](./examples/Diffutoon/) Render realistic videos in a flatten style and enable video editing features. [`examples/Diffutoon`](./examples/Diffutoon/)
@@ -105,16 +175,60 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c
### Video Stylization #### Video Stylization
Video stylization without video models. [`examples/diffsynth`](./examples/diffsynth/) Video stylization without video models. [`examples/diffsynth`](./examples/diffsynth/)
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
### Image Synthesis
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/).
LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
|FLUX|Stable Diffusion 3|
|-|-|
|![image_1024_cfg](https://github.com/user-attachments/assets/984561e9-553d-4952-9443-79ce144f379f)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|
|Kolors|Hunyuan-DiT|
|-|-|
|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/53ef6f41-da11-4701-8665-9f64392607bf)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)|
|Stable Diffusion|Stable Diffusion XL|
|-|-|
|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|
## Usage (in WebUI) ## Usage (in WebUI)
Create stunning images using the painter, with assistance from AI!
https://github.com/user-attachments/assets/95265d21-cdd6-4125-a7cb-9fbcf6ceb7b0
**This video is not rendered in real-time.**
Before launching the WebUI, please download models to the folder `./models`. See [here](#download-models).
* `Gradio` version
``` ```
python -m streamlit run DiffSynth_Studio.py pip install gradio
```
```
python apps/gradio/DiffSynth_Studio.py
```
![20240822102002](https://github.com/user-attachments/assets/59613157-de51-4109-99b3-97cbffd88076)
* `Streamlit` version
```
pip install streamlit streamlit-drawable-canvas
```
```
python -m streamlit run apps/streamlit/DiffSynth_Studio.py
``` ```
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954 https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954

View File

@@ -0,0 +1,252 @@
import gradio as gr
from diffsynth import ModelManager, SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
import os, torch
from PIL import Image
import numpy as np
config = {
"model_config": {
"Stable Diffusion": {
"model_folder": "models/stable_diffusion",
"pipeline_class": SDImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
"height": 512,
"width": 512,
}
},
"Stable Diffusion XL": {
"model_folder": "models/stable_diffusion_xl",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"Stable Diffusion 3": {
"model_folder": "models/stable_diffusion_3",
"pipeline_class": SD3ImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"Stable Diffusion XL Turbo": {
"model_folder": "models/stable_diffusion_xl_turbo",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"negative_prompt": "",
"cfg_scale": 1.0,
"num_inference_steps": 1,
"height": 512,
"width": 512,
}
},
"Kolors": {
"model_folder": "models/kolors",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"HunyuanDiT": {
"model_folder": "models/HunyuanDiT",
"pipeline_class": HunyuanDiTImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"FLUX": {
"model_folder": "models/FLUX",
"pipeline_class": FluxImagePipeline,
"default_parameters": {
"cfg_scale": 1.0,
}
}
},
"max_num_painter_layers": 8,
"max_num_model_cache": 1,
}
def load_model_list(model_type):
if model_type is None:
return []
folder = config["model_config"][model_type]["model_folder"]
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
file_list = sorted(file_list)
return file_list
def load_model(model_type, model_path):
global model_dict
model_key = f"{model_type}:{model_path}"
if model_key in model_dict:
return model_dict[model_key]
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
model_manager = ModelManager()
if model_type == "HunyuanDiT":
model_manager.load_models([
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
os.path.join(model_path, "mt5/pytorch_model.bin"),
os.path.join(model_path, "model/pytorch_model_ema.pt"),
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
])
elif model_type == "Kolors":
model_manager.load_models([
os.path.join(model_path, "text_encoder"),
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
])
elif model_type == "FLUX":
model_manager.torch_dtype = torch.bfloat16
file_list = [
os.path.join(model_path, "text_encoder/model.safetensors"),
os.path.join(model_path, "text_encoder_2"),
]
for file_name in os.listdir(model_path):
if file_name.endswith(".safetensors"):
file_list.append(os.path.join(model_path, file_name))
model_manager.load_models(file_list)
else:
model_manager.load_model(model_path)
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
while len(model_dict) + 1 > config["max_num_model_cache"]:
key = next(iter(model_dict.keys()))
model_manager_to_release, _ = model_dict[key]
model_manager_to_release.to("cpu")
del model_dict[key]
torch.cuda.empty_cache()
model_dict[model_key] = model_manager, pipe
return model_manager, pipe
model_dict = {}
with gr.Blocks() as app:
gr.Markdown("# DiffSynth-Studio Painter")
with gr.Row():
with gr.Column(scale=382, min_width=100):
with gr.Accordion(label="Model"):
model_type = gr.Dropdown(choices=[i for i in config["model_config"]], label="Model type")
model_path = gr.Dropdown(choices=[], interactive=True, label="Model path")
@gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
def model_type_to_model_path(model_type):
return gr.Dropdown(choices=load_model_list(model_type))
with gr.Accordion(label="Prompt"):
prompt = gr.Textbox(label="Prompt", lines=3)
negative_prompt = gr.Textbox(label="Negative prompt", lines=1)
cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale (only for FLUX)")
with gr.Accordion(label="Image"):
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
with gr.Column():
use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)
@gr.on(
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
triggers=model_path.change
)
def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
load_model(model_type, model_path)
cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale)
embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
height = config["model_config"][model_type]["default_parameters"].get("height", height)
width = config["model_config"][model_type]["default_parameters"].get("width", width)
return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
with gr.Column(scale=618, min_width=100):
with gr.Accordion(label="Painter"):
enable_local_prompt_list = []
local_prompt_list = []
mask_scale_list = []
canvas_list = []
for painter_layer_id in range(config["max_num_painter_layers"]):
with gr.Tab(label=f"Layer {painter_layer_id}"):
enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
brush=gr.Brush(default_size=100, default_color="#000000", colors=["#000000"]),
label="Painter", key=f"canvas_{painter_layer_id}")
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden")
def resize_canvas(height, width, canvas):
h, w = canvas["background"].shape[:2]
if h != height or width != w:
return np.ones((height, width, 3), dtype=np.uint8) * 255
else:
return canvas
enable_local_prompt_list.append(enable_local_prompt)
local_prompt_list.append(local_prompt)
mask_scale_list.append(mask_scale)
canvas_list.append(canvas)
with gr.Accordion(label="Results"):
run_button = gr.Button(value="Generate", variant="primary")
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
with gr.Row():
with gr.Column():
output_to_painter_button = gr.Button(value="Set as painter's background")
with gr.Column():
output_to_input_button = gr.Button(value="Set as input image")
painter_background = gr.State(None)
input_background = gr.State(None)
@gr.on(
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
outputs=[output_image],
triggers=run_button.click
)
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
_, pipe = load_model(model_type, model_path)
input_params = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"cfg_scale": cfg_scale,
"num_inference_steps": num_inference_steps,
"height": height,
"width": width,
"progress_bar_cmd": progress.tqdm,
}
if isinstance(pipe, FluxImagePipeline):
input_params["embedded_guidance"] = embedded_guidance
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]],
args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]]
)
local_prompts, masks, mask_scales = [], [], []
for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list
):
if enable_local_prompt:
local_prompts.append(local_prompt)
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
mask_scales.append(mask_scale)
input_params.update({
"local_prompts": local_prompts,
"masks": masks,
"mask_scales": mask_scales,
})
torch.manual_seed(seed)
image = pipe(**input_params)
return image
@gr.on(inputs=[output_image] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
def send_output_to_painter_background(output_image, *canvas_list):
for canvas in canvas_list:
h, w = canvas["background"].shape[:2]
canvas["background"] = output_image.resize((w, h))
return tuple(canvas_list)
app.launch()

View File

@@ -0,0 +1,390 @@
import os
import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import random
import json
import gradio as gr
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
from modelscope import dataset_snapshot_download
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/*")
example_json = 'data/examples/eligen/entity_control/ui_examples.json'
with open(example_json, 'r') as f:
examples = json.load(f)['examples']
for idx in range(len(examples)):
example_id = examples[idx]['example_id']
entity_prompts = examples[idx]['local_prompt_list']
examples[idx]['mask_lists'] = [Image.open(f"data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
def create_canvas_data(background, masks):
if background.shape[-1] == 3:
background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)])
layers = []
for mask in masks:
if mask is not None:
mask_single_channel = mask if mask.ndim == 2 else mask[..., 0]
layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), dtype=np.uint8)
layer[..., -1] = mask_single_channel
layers.append(layer)
else:
layers.append(np.zeros_like(background))
composite = background.copy()
for layer in layers:
if layer.size > 0:
composite = np.where(layer[..., -1:] > 0, layer, composite)
return {
"background": background,
"layers": layers,
"composite": composite,
}
def load_example(load_example_button):
example_idx = int(load_example_button.split()[-1]) - 1
example = examples[example_idx]
result = [
50,
example["global_prompt"],
example["negative_prompt"],
example["seed"],
*example["local_prompt_list"],
]
num_entities = len(example["local_prompt_list"])
result += [""] * (config["max_num_painter_layers"] - num_entities)
masks = []
for mask in example["mask_lists"]:
mask_single_channel = np.array(mask.convert("L"))
masks.append(mask_single_channel)
for _ in range(config["max_num_painter_layers"] - len(masks)):
blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
masks.append(blank_mask)
background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255
canvas_data_list = []
for mask in masks:
canvas_data = create_canvas_data(background, [mask])
canvas_data_list.append(canvas_data)
result.extend(canvas_data_list)
return result
def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
save_dir = os.path.join('workdirs/tmp_mask', random_dir)
print(f'save to {save_dir}')
os.makedirs(save_dir, exist_ok=True)
for i, mask in enumerate(masks):
save_path = os.path.join(save_dir, f'{i}.png')
mask.save(save_path)
sample = {
"global_prompt": global_prompt,
"mask_prompts": mask_prompts,
"seed": seed,
}
with open(os.path.join(save_dir, f"prompts.json"), 'w') as f:
json.dump(sample, f, indent=4)
def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
# Create a blank image for overlays
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
colors = [
(165, 238, 173, 80),
(76, 102, 221, 80),
(221, 160, 77, 80),
(204, 93, 71, 80),
(145, 187, 149, 80),
(134, 141, 172, 80),
(157, 137, 109, 80),
(153, 104, 95, 80),
(165, 238, 173, 80),
(76, 102, 221, 80),
(221, 160, 77, 80),
(204, 93, 71, 80),
(145, 187, 149, 80),
(134, 141, 172, 80),
(157, 137, 109, 80),
(153, 104, 95, 80),
]
# Generate random colors for each mask
if use_random_colors:
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
# Font settings
try:
font = ImageFont.truetype("arial", font_size) # Adjust as needed
except IOError:
font = ImageFont.load_default(font_size)
# Overlay each mask onto the overlay image
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
if mask is None:
continue
# Convert mask to RGBA mode
mask_rgba = mask.convert('RGBA')
mask_data = mask_rgba.getdata()
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
mask_rgba.putdata(new_data)
# Draw the mask prompt text on the mask
draw = ImageDraw.Draw(mask_rgba)
mask_bbox = mask.getbbox() # Get the bounding box of the mask
if mask_bbox is None:
continue
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
# Alpha composite the overlay with this mask
overlay = Image.alpha_composite(overlay, mask_rgba)
# Composite the overlay onto the original image
result = Image.alpha_composite(image.convert('RGBA'), overlay)
return result
config = {
"model_config": {
"FLUX": {
"model_folder": "models/FLUX",
"pipeline_class": FluxImagePipeline,
"default_parameters": {
"cfg_scale": 3.0,
"embedded_guidance": 3.5,
"num_inference_steps": 30,
}
},
},
"max_num_painter_layers": 8,
"max_num_model_cache": 1,
}
model_dict = {}
def load_model(model_type='FLUX', model_path='FLUX.1-dev'):
global model_dict
model_key = f"{model_type}:{model_path}"
if model_key in model_dict:
return model_dict[model_key]
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
model_manager.load_lora(
download_customized_models(
model_id="DiffSynth-Studio/Eligen",
origin_file_path="model_bf16.safetensors",
local_dir="models/lora/entity_control",
),
lora_alpha=1,
)
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
model_dict[model_key] = model_manager, pipe
return model_manager, pipe
with gr.Blocks() as app:
gr.Markdown(
"""## EliGen: Entity-Level Controllable Text-to-Image Model
1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
"""
)
loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
main_interface = gr.Column(visible=False)
def initialize_model():
try:
load_model()
return {
loading_status: gr.update(value="Model loaded successfully!", visible=False),
main_interface: gr.update(visible=True),
}
except Exception as e:
print(f'Failed to load model with error: {e}')
return {
loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
main_interface: gr.update(visible=True),
}
app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])
with main_interface:
with gr.Row():
local_prompt_list = []
canvas_list = []
random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
with gr.Column(scale=382, min_width=100):
model_type = gr.State('FLUX')
model_path = gr.State('FLUX.1-dev')
with gr.Accordion(label="Global prompt"):
prompt = gr.Textbox(label="Global Prompt", lines=3)
negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur,", lines=3)
with gr.Accordion(label="Inference Options", open=True):
seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=3.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=3.5, step=0.1, interactive=True, label="Embedded guidance scale")
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
with gr.Accordion(label="Inpaint Input Image", open=False):
input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)
with gr.Column():
reset_input_button = gr.Button(value="Reset Inpaint Input")
send_input_to_painter = gr.Button(value="Set as painter's background")
@gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
def reset_input_image(input_image):
return None
with gr.Column(scale=618, min_width=100):
with gr.Accordion(label="Entity Painter"):
for painter_layer_id in range(config["max_num_painter_layers"]):
with gr.Tab(label=f"Entity {painter_layer_id}"):
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
canvas = gr.ImageEditor(
canvas_size=(512, 512),
sources=None,
layers=False,
interactive=True,
image_mode="RGBA",
brush=gr.Brush(
default_size=50,
default_color="#000000",
colors=["#000000"],
),
label="Entity Mask Painter",
key=f"canvas_{painter_layer_id}",
width=width,
height=height,
)
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
def resize_canvas(height, width, canvas):
h, w = canvas["background"].shape[:2]
if h != height or width != w:
return np.ones((height, width, 3), dtype=np.uint8) * 255
else:
return canvas
local_prompt_list.append(local_prompt)
canvas_list.append(canvas)
with gr.Accordion(label="Results"):
run_button = gr.Button(value="Generate", variant="primary")
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
with gr.Row():
with gr.Column():
output_to_painter_button = gr.Button(value="Set as painter's background")
with gr.Column():
return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
real_output = gr.State(None)
mask_out = gr.State(None)
@gr.on(
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
outputs=[output_image, real_output, mask_out],
triggers=run_button.click
)
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
_, pipe = load_model(model_type, model_path)
input_params = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"cfg_scale": cfg_scale,
"num_inference_steps": num_inference_steps,
"height": height,
"width": width,
"progress_bar_cmd": progress.tqdm,
}
if isinstance(pipe, FluxImagePipeline):
input_params["embedded_guidance"] = embedded_guidance
if input_image is not None:
input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
input_params["enable_eligen_inpaint"] = True
local_prompt_list, canvas_list = (
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
)
local_prompts, masks = [], []
for local_prompt, canvas in zip(local_prompt_list, canvas_list):
if isinstance(local_prompt, str) and len(local_prompt) > 0:
local_prompts.append(local_prompt)
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
entity_masks = None if len(masks) == 0 else masks
entity_prompts = None if len(local_prompts) == 0 else local_prompts
input_params.update({
"eligen_entity_prompts": entity_prompts,
"eligen_entity_masks": entity_masks,
})
torch.manual_seed(seed)
# save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir)
image = pipe(**input_params)
masks = [mask.resize(image.size) for mask in masks]
image_with_mask = visualize_masks(image, masks, local_prompts)
real_output = gr.State(image)
mask_out = gr.State(image_with_mask)
if return_with_mask:
return image_with_mask, real_output, mask_out
return image, real_output, mask_out
@gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
def send_input_to_painter_background(input_image, *canvas_list):
if input_image is None:
return tuple(canvas_list)
for canvas in canvas_list:
h, w = canvas["background"].shape[:2]
canvas["background"] = input_image.resize((w, h))
return tuple(canvas_list)
@gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
def send_output_to_painter_background(real_output, *canvas_list):
if real_output is None:
return tuple(canvas_list)
for canvas in canvas_list:
h, w = canvas["background"].shape[:2]
canvas["background"] = real_output.value.resize((w, h))
return tuple(canvas_list)
@gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
def show_output(return_with_mask, real_output, mask_out):
if return_with_mask:
return mask_out.value
else:
return real_output.value
@gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
def send_output_to_pipe_input(real_output):
return real_output.value
with gr.Column():
gr.Markdown("## Examples")
for i in range(0, len(examples), 2):
with gr.Row():
if i < len(examples):
example = examples[i]
with gr.Column():
example_image = gr.Image(
value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
label=example["description"],
interactive=False,
width=1024,
height=512
)
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
load_example_button.click(
load_example,
inputs=[load_example_button],
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
)
if i + 1 < len(examples):
example = examples[i + 1]
with gr.Column():
example_image = gr.Image(
value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
label=example["description"],
interactive=False,
width=1024,
height=512
)
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
load_example_button.click(
load_example,
inputs=[load_example_button],
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
)
app.config["show_progress"] = "hidden"
app.launch()

View File

@@ -1,11 +1,11 @@
import torch, os, io import torch, os, io, json, time
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import streamlit as st import streamlit as st
st.set_page_config(layout="wide") st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas from streamlit_drawable_canvas import st_canvas
from diffsynth.models import ModelManager from diffsynth.models import ModelManager
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
from diffsynth.data.video import crop_and_resize from diffsynth.data.video import crop_and_resize
@@ -49,13 +49,20 @@ config = {
"width": 1024, "width": 1024,
} }
}, },
"FLUX": {
"model_folder": "models/FLUX",
"pipeline_class": FluxImagePipeline,
"fixed_parameters": {
"cfg_scale": 1.0,
}
}
} }
def load_model_list(model_type): def load_model_list(model_type):
folder = config[model_type]["model_folder"] folder = config[model_type]["model_folder"]
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")] file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
if model_type in ["HunyuanDiT", "Kolors"]: if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))] file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
file_list = sorted(file_list) file_list = sorted(file_list)
return file_list return file_list
@@ -85,6 +92,16 @@ def load_model(model_type, model_path):
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"), os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"), os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
]) ])
elif model_type == "FLUX":
model_manager.torch_dtype = torch.bfloat16
file_list = [
os.path.join(model_path, "text_encoder/model.safetensors"),
os.path.join(model_path, "text_encoder_2"),
]
for file_name in os.listdir(model_path):
if file_name.endswith(".safetensors"):
file_list.append(os.path.join(model_path, file_name))
model_manager.load_models(file_list)
else: else:
model_manager.load_model(model_path) model_manager.load_model(model_path)
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager) pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
@@ -255,6 +272,48 @@ with column_input:
key="canvas" key="canvas"
) )
num_painter_layer = st.number_input("Number of painter layers", min_value=0, max_value=10, step=1, value=0)
local_prompts, masks, mask_scales = [], [], []
white_board = Image.fromarray(np.ones((512, 512, 3), dtype=np.uint8) * 255)
painter_layers_json_data = []
for painter_tab_id in range(num_painter_layer):
with st.expander(f"Painter layer {painter_tab_id}", expanded=True):
enable_local_prompt = st.checkbox(f"Enable prompt {painter_tab_id}", value=True)
local_prompt = st.text_area(f"Prompt {painter_tab_id}")
mask_scale = st.slider(f"Mask scale {painter_tab_id}", min_value=0.0, max_value=3.0, value=1.0)
stroke_width = st.slider(f"Stroke width {painter_tab_id}", min_value=1, max_value=300, value=100)
canvas_result_local = st_canvas(
fill_color="#000000",
stroke_width=stroke_width,
stroke_color="#000000",
background_color="rgba(255, 255, 255, 0)",
background_image=white_board,
update_streamlit=True,
height=512,
width=512,
drawing_mode="freedraw",
key=f"canvas_{painter_tab_id}"
)
if canvas_result_local.json_data is not None:
painter_layers_json_data.append(canvas_result_local.json_data.copy())
painter_layers_json_data[-1]["prompt"] = local_prompt
if enable_local_prompt:
local_prompts.append(local_prompt)
if canvas_result_local.image_data is not None:
mask = apply_stroke_to_image(canvas_result_local.image_data, white_board)
else:
mask = white_board
mask = Image.fromarray(255 - np.array(mask))
masks.append(mask)
mask_scales.append(mask_scale)
save_painter_layers = st.button("Save painter layers")
if save_painter_layers:
os.makedirs("data/painter_layers", exist_ok=True)
json_file_path = f"data/painter_layers/{time.time_ns()}.json"
with open(json_file_path, "w") as f:
json.dump(painter_layers_json_data, f, indent=4)
st.markdown(f"Painter layers are saved in {json_file_path}.")
with column_output: with column_output:
run_button = st.button("Generate image", type="primary") run_button = st.button("Generate image", type="primary")
@@ -282,6 +341,7 @@ with column_output:
progress_bar_st = st.progress(0.0) progress_bar_st = st.progress(0.0)
image = pipeline( image = pipeline(
prompt, negative_prompt=negative_prompt, prompt, negative_prompt=negative_prompt,
local_prompts=local_prompts, masks=masks, mask_scales=mask_scales,
cfg_scale=cfg_scale, num_inference_steps=num_inference_steps, cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
height=height, width=width, height=height, width=width,
input_image=input_image, denoising_strength=denoising_strength, input_image=input_image, denoising_strength=denoising_strength,

View File

@@ -16,6 +16,7 @@ from ..models.sd3_vae_decoder import SD3VAEDecoder
from ..models.sd3_vae_encoder import SD3VAEEncoder from ..models.sd3_vae_encoder import SD3VAEEncoder
from ..models.sd_controlnet import SDControlNet from ..models.sd_controlnet import SDControlNet
from ..models.sdxl_controlnet import SDXLControlNetUnion
from ..models.sd_motion import SDMotionModel from ..models.sd_motion import SDMotionModel
from ..models.sdxl_motion import SDXLMotionModel from ..models.sdxl_motion import SDXLMotionModel
@@ -31,7 +32,29 @@ from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from ..models.hunyuan_dit import HunyuanDiT from ..models.hunyuan_dit import HunyuanDiT
from ..models.flux_dit import FluxDiT
from ..models.flux_text_encoder import FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.flux_controlnet import FluxControlNet
from ..models.flux_ipadapter import FluxIpAdapter
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT
from ..models.omnigen import OmniGenTransformer
from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
from ..models.hunyuan_video_dit import HunyuanVideoDiT
from ..models.stepvideo_vae import StepVideoVAE
from ..models.stepvideo_dit import StepVideoModel
from ..models.wanx_vae import WanXVideoVAE
model_loader_configs = [ model_loader_configs = [
# These configs are provided for detecting model type automatically. # These configs are provided for detecting model type automatically.
@@ -60,13 +83,47 @@ model_loader_configs = [
(None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"), (None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"), (None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"), (None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
(None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
(None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers"),
(None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
(None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
(None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
(None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
(None, "1378ea763357eea97acdef78e65d6d96", ["wanxvideo_vae"], [WanXVideoVAE], "civitai")
] ]
huggingface_model_loader_configs = [ huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically. # These configs are provided for detecting model type automatically.
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name) # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder"), ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator"), ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt"), ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
# ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
] ]
patch_model_loader_configs = [ patch_model_loader_configs = [
# These configs are provided for detecting model type automatically. # These configs are provided for detecting model type automatically.
@@ -87,6 +144,175 @@ preset_models_on_huggingface = {
"ExVideo-SVD-128f-v1": [ "ExVideo-SVD-128f-v1": [
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"), ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
], ],
# Stable Diffusion
"StableDiffusion_v15": [
("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
],
"DreamShaper_8": [
("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
],
# Textual Inversion
"TextualInversion_VeryBadImageNegative_v1.3": [
("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
],
# Stable Diffusion XL
"StableDiffusionXL_v1": [
("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
],
"BluePencilXL_v200": [
("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
],
"StableDiffusionXL_Turbo": [
("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
],
# Stable Diffusion 3
"StableDiffusion3": [
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
],
"StableDiffusion3_without_T5": [
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
],
# ControlNet
"ControlNet_v11f1p_sd15_depth": [
("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
],
"ControlNet_v11p_sd15_softedge": [
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
],
"ControlNet_v11f1e_sd15_tile": [
("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
],
"ControlNet_v11p_sd15_lineart": [
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
],
"ControlNet_union_sdxl_promax": [
("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
],
# AnimateDiff
"AnimateDiff_v2": [
("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
],
"AnimateDiff_xl_beta": [
("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
],
# Qwen Prompt
"QwenPrompt": [
("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
],
# Beautiful Prompt
"BeautifulPrompt": [
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
],
# Omost prompt
"OmostPrompt":[
("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
],
# Translator
"opus-mt-zh-en": [
("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
],
# IP-Adapter
"IP-Adapter-SD": [
("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
],
"IP-Adapter-SDXL": [
("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
],
"SDXL-vae-fp16-fix": [
("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
],
# Kolors
"Kolors": [
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
],
# FLUX
"FLUX.1-dev": [
("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
],
"InstantX/FLUX.1-dev-IP-Adapter": {
"file_list": [
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
],
"load_path": [
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
],
},
# RIFE
"RIFE": [
("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
],
# CogVideo
"CogVideoX-5B": [
("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
],
# Stable Diffusion 3.5
"StableDiffusion3.5-large": [
("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
],
} }
preset_models_on_modelscope = { preset_models_on_modelscope = {
# Hunyuan DiT # Hunyuan DiT
@@ -104,6 +330,9 @@ preset_models_on_modelscope = {
"ExVideo-SVD-128f-v1": [ "ExVideo-SVD-128f-v1": [
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"), ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
], ],
"ExVideo-CogVideoX-LoRA-129f-v1": [
("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
],
# Stable Diffusion # Stable Diffusion
"StableDiffusion_v15": [ "StableDiffusion_v15": [
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"), ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
@@ -131,6 +360,9 @@ preset_models_on_modelscope = {
"StableDiffusionXL_Turbo": [ "StableDiffusionXL_Turbo": [
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"), ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
], ],
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
],
# Stable Diffusion 3 # Stable Diffusion 3
"StableDiffusion3": [ "StableDiffusion3": [
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"), ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
@@ -155,6 +387,28 @@ preset_models_on_modelscope = {
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"), ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators") ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
], ],
"ControlNet_union_sdxl_promax": [
("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
],
"Annotators:Depth": [
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
],
"Annotators:Softedge": [
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
],
"Annotators:Lineart": [
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
],
"Annotators:Normal": [
("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
],
"Annotators:Openpose": [
("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
],
# AnimateDiff # AnimateDiff
"AnimateDiff_v2": [ "AnimateDiff_v2": [
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"), ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
@@ -166,8 +420,25 @@ preset_models_on_modelscope = {
"RIFE": [ "RIFE": [
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"), ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
], ],
# Qwen Prompt
"QwenPrompt": {
"file_list": [
("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
],
"load_path": [
"models/QwenPrompt/qwen2-1.5b-instruct",
],
},
# Beautiful Prompt # Beautiful Prompt
"BeautifulPrompt": [ "BeautifulPrompt": {
"file_list": [
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
@@ -175,8 +446,29 @@ preset_models_on_modelscope = {
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
], ],
"load_path": [
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
],
},
# Omost prompt
"OmostPrompt": {
"file_list": [
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
],
"load_path": [
"models/OmostPrompt/omost-llama-3-8b-4bits",
],
},
# Translator # Translator
"opus-mt-zh-en": [ "opus-mt-zh-en": {
"file_list": [
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"), ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"), ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"), ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
@@ -186,6 +478,10 @@ preset_models_on_modelscope = {
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"), ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"), ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
], ],
"load_path": [
"models/translator/opus-mt-zh-en",
],
},
# IP-Adapter # IP-Adapter
"IP-Adapter-SD": [ "IP-Adapter-SD": [
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"), ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
@@ -196,7 +492,8 @@ preset_models_on_modelscope = {
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"), ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
], ],
# Kolors # Kolors
"Kolors": [ "Kolors": {
"file_list": [
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"), ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"), ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"), ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
@@ -209,14 +506,190 @@ preset_models_on_modelscope = {
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"), ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"), ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
], ],
"load_path": [
"models/kolors/Kolors/text_encoder",
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
],
},
"SDXL-vae-fp16-fix": [ "SDXL-vae-fp16-fix": [
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix") ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
], ],
# FLUX
"FLUX.1-dev": {
"file_list": [
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
],
"load_path": [
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
],
},
"FLUX.1-schnell": {
"file_list": [
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
],
"load_path": [
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
],
},
"InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
],
"jasperai/Flux.1-dev-Controlnet-Depth": [
("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
],
"jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
],
"jasperai/Flux.1-dev-Controlnet-Upscaler": [
("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
],
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
],
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
],
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
],
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
],
"InstantX/FLUX.1-dev-IP-Adapter": {
"file_list": [
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
],
"load_path": [
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
],
},
# ESRGAN
"ESRGAN_x4": [
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
],
# RIFE
"RIFE": [
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
],
# Omnigen
"OmniGen-v1": {
"file_list": [
("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
],
"load_path": [
"models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
"models/OmniGen/OmniGen-v1/model.safetensors",
]
},
# CogVideo
"CogVideoX-5B": {
"file_list": [
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
],
"load_path": [
"models/CogVideo/CogVideoX-5b/text_encoder",
"models/CogVideo/CogVideoX-5b/transformer",
"models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
],
},
# Stable Diffusion 3.5
"StableDiffusion3.5-large": [
("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
],
"StableDiffusion3.5-medium": [
("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
],
"StableDiffusion3.5-large-turbo": [
("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
],
"HunyuanVideo":{
"file_list": [
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
],
"load_path": [
"models/HunyuanVideo/text_encoder/model.safetensors",
"models/HunyuanVideo/text_encoder_2",
"models/HunyuanVideo/vae/pytorch_model.pt",
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
],
},
"HunyuanVideo-fp8":{
"file_list": [
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
],
"load_path": [
"models/HunyuanVideo/text_encoder/model.safetensors",
"models/HunyuanVideo/text_encoder_2",
"models/HunyuanVideo/vae/pytorch_model.pt",
"models/HunyuanVideo/transformers/model.fp8.safetensors"
],
},
} }
Preset_model_id: TypeAlias = Literal[ Preset_model_id: TypeAlias = Literal[
"HunyuanDiT", "HunyuanDiT",
"stable-video-diffusion-img2vid-xt", "stable-video-diffusion-img2vid-xt",
"ExVideo-SVD-128f-v1", "ExVideo-SVD-128f-v1",
"ExVideo-CogVideoX-LoRA-129f-v1",
"StableDiffusion_v15", "StableDiffusion_v15",
"DreamShaper_8", "DreamShaper_8",
"AingDiffusion_v12", "AingDiffusion_v12",
@@ -240,4 +713,32 @@ Preset_model_id: TypeAlias = Literal[
"StableDiffusion3_without_T5", "StableDiffusion3_without_T5",
"Kolors", "Kolors",
"SDXL-vae-fp16-fix", "SDXL-vae-fp16-fix",
"ControlNet_union_sdxl_promax",
"FLUX.1-dev",
"FLUX.1-schnell",
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
"jasperai/Flux.1-dev-Controlnet-Depth",
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
"jasperai/Flux.1-dev-Controlnet-Upscaler",
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
"InstantX/FLUX.1-dev-IP-Adapter",
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
"QwenPrompt",
"OmostPrompt",
"ESRGAN_x4",
"RIFE",
"OmniGen-v1",
"CogVideoX-5B",
"Annotators:Depth",
"Annotators:Softedge",
"Annotators:Lineart",
"Annotators:Normal",
"Annotators:Openpose",
"StableDiffusion3.5-large",
"StableDiffusion3.5-medium",
"HunyuanVideo",
"HunyuanVideo-fp8",
] ]

View File

@@ -1,2 +1,2 @@
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
from .processors import Annotator from .processors import Annotator

View File

@@ -4,10 +4,11 @@ from .processors import Processor_id
class ControlNetConfigUnit: class ControlNetConfigUnit:
def __init__(self, processor_id: Processor_id, model_path, scale=1.0): def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
self.processor_id = processor_id self.processor_id = processor_id
self.model_path = model_path self.model_path = model_path
self.scale = scale self.scale = scale
self.skip_processor = skip_processor
class ControlNetUnit: class ControlNetUnit:
@@ -23,6 +24,16 @@ class MultiControlNetManager:
self.models = [unit.model for unit in controlnet_units] self.models = [unit.model for unit in controlnet_units]
self.scales = [unit.scale for unit in controlnet_units] self.scales = [unit.scale for unit in controlnet_units]
def cpu(self):
for model in self.models:
model.cpu()
def to(self, device):
for model in self.models:
model.to(device)
for processor in self.processors:
processor.to(device)
def process_image(self, image, processor_id=None): def process_image(self, image, processor_id=None):
if processor_id is None: if processor_id is None:
processed_image = [processor(image) for processor in self.processors] processed_image = [processor(image) for processor in self.processors]
@@ -37,13 +48,14 @@ class MultiControlNetManager:
def __call__( def __call__(
self, self,
sample, timestep, encoder_hidden_states, conditionings, sample, timestep, encoder_hidden_states, conditionings,
tiled=False, tile_size=64, tile_stride=32 tiled=False, tile_size=64, tile_stride=32, **kwargs
): ):
res_stack = None res_stack = None
for conditioning, model, scale in zip(conditionings, self.models, self.scales): for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
res_stack_ = model( res_stack_ = model(
sample, timestep, encoder_hidden_states, conditioning, sample, timestep, encoder_hidden_states, conditioning, **kwargs,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
processor_id=processor.processor_id
) )
res_stack_ = [res * scale for res in res_stack_] res_stack_ = [res * scale for res in res_stack_]
if res_stack is None: if res_stack is None:
@@ -51,3 +63,29 @@ class MultiControlNetManager:
else: else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)] res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
return res_stack return res_stack
class FluxMultiControlNetManager(MultiControlNetManager):
def __init__(self, controlnet_units=[]):
super().__init__(controlnet_units=controlnet_units)
def process_image(self, image, processor_id=None):
if processor_id is None:
processed_image = [processor(image) for processor in self.processors]
else:
processed_image = [self.processors[processor_id](image)]
return processed_image
def __call__(self, conditionings, **kwargs):
res_stack, single_res_stack = None, None
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
res_stack_ = [res * scale for res in res_stack_]
single_res_stack_ = [res * scale for res in single_res_stack_]
if res_stack is None:
res_stack = res_stack_
single_res_stack = single_res_stack_
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
return res_stack, single_res_stack

View File

@@ -3,16 +3,17 @@ import warnings
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
from controlnet_aux.processor import ( from controlnet_aux.processor import (
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector, NormalBaeDetector
) )
Processor_id: TypeAlias = Literal[ Processor_id: TypeAlias = Literal[
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile" "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
] ]
class Annotator: class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'): def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
if not skip_processor:
if processor_id == "canny": if processor_id == "canny":
self.processor = CannyDetector() self.processor = CannyDetector()
elif processor_id == "depth": elif processor_id == "depth":
@@ -25,15 +26,24 @@ class Annotator:
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device) self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
elif processor_id == "openpose": elif processor_id == "openpose":
self.processor = OpenposeDetector.from_pretrained(model_path).to(device) self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
elif processor_id == "tile": elif processor_id == "normal":
self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
self.processor = None self.processor = None
else: else:
raise ValueError(f"Unsupported processor_id: {processor_id}") raise ValueError(f"Unsupported processor_id: {processor_id}")
else:
self.processor = None
self.processor_id = processor_id self.processor_id = processor_id
self.detect_resolution = detect_resolution self.detect_resolution = detect_resolution
def __call__(self, image): def to(self,device):
if hasattr(self.processor,"model") and hasattr(self.processor.model,"to"):
self.processor.model.to(device)
def __call__(self, image, mask=None):
width, height = image.size width, height = image.size
if self.processor_id == "openpose": if self.processor_id == "openpose":
kwargs = { kwargs = {

View File

@@ -1,4 +1,4 @@
import torch, os import torch, os, torchvision
from torchvision import transforms from torchvision import transforms
import pandas as pd import pandas as pd
from PIL import Image from PIL import Image
@@ -11,9 +11,10 @@ class TextImageDataset(torch.utils.data.Dataset):
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv")) metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]] self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
self.text = metadata["text"].to_list() self.text = metadata["text"].to_list()
self.height = height
self.width = width
self.image_processor = transforms.Compose( self.image_processor = transforms.Compose(
[ [
transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)), transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x), transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(), transforms.ToTensor(),
@@ -27,6 +28,11 @@ class TextImageDataset(torch.utils.data.Dataset):
data_id = (data_id + index) % len(self.path) # For fixed seed. data_id = (data_id + index) % len(self.path) # For fixed seed.
text = self.text[data_id] text = self.text[data_id]
image = Image.open(self.path[data_id]).convert("RGB") image = Image.open(self.path[data_id]).convert("RGB")
target_height, target_width = self.height, self.width
width, height = image.size
scale = max(target_width / width, target_height / height)
shape = [round(height*scale),round(width*scale)]
image = torchvision.transforms.functional.resize(image,shape,interpolation=transforms.InterpolationMode.BILINEAR)
image = self.image_processor(image) image = self.image_processor(image)
return {"text": text, "image": image} return {"text": text, "image": image}

View File

@@ -135,8 +135,8 @@ class VideoData:
frame.save(os.path.join(folder, f"{i}.png")) frame.save(os.path.join(folder, f"{i}.png"))
def save_video(frames, save_path, fps, quality=9): def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
writer = imageio.get_writer(save_path, fps=fps, quality=quality) writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
for frame in tqdm(frames, desc="Saving video"): for frame in tqdm(frames, desc="Saving video"):
frame = np.array(frame) frame = np.array(frame)
writer.append_data(frame) writer.append_data(frame)

View File

@@ -41,7 +41,7 @@ class RRDB(torch.nn.Module):
class RRDBNet(torch.nn.Module): class RRDBNet(torch.nn.Module):
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32): def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
super(RRDBNet, self).__init__() super(RRDBNet, self).__init__()
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)]) self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
@@ -66,6 +66,21 @@ class RRDBNet(torch.nn.Module):
out = self.conv_last(self.lrelu(self.conv_hr(feat))) out = self.conv_last(self.lrelu(self.conv_hr(feat)))
return out return out
@staticmethod
def state_dict_converter():
return RRDBNetStateDictConverter()
class RRDBNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict, {"upcast_to_float32": True}
def from_civitai(self, state_dict):
return state_dict, {"upcast_to_float32": True}
class ESRGAN(torch.nn.Module): class ESRGAN(torch.nn.Module):
def __init__(self, model): def __init__(self, model):
@@ -73,12 +88,8 @@ class ESRGAN(torch.nn.Module):
self.model = model self.model = model
@staticmethod @staticmethod
def from_pretrained(model_path): def from_model_manager(model_manager):
model = RRDBNet() return ESRGAN(model_manager.fetch_model("esrgan"))
state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
model.load_state_dict(state_dict)
model.eval()
return ESRGAN(model)
def process_image(self, image): def process_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1) image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
@@ -96,6 +107,12 @@ class ESRGAN(torch.nn.Module):
@torch.no_grad() @torch.no_grad()
def upscale(self, images, batch_size=4, progress_bar=lambda x:x): def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
if not isinstance(images, list):
images = [images]
is_single_image = True
else:
is_single_image = False
# Preprocess # Preprocess
input_tensor = self.process_images(images) input_tensor = self.process_images(images)
@@ -115,4 +132,6 @@ class ESRGAN(torch.nn.Module):
# To images # To images
output_images = self.decode_images(output_tensor) output_images = self.decode_images(output_tensor)
if is_single_image:
output_images = output_images[0]
return output_images return output_images

View File

@@ -58,7 +58,7 @@ class IFBlock(nn.Module):
class IFNet(nn.Module): class IFNet(nn.Module):
def __init__(self): def __init__(self, **kwargs):
super(IFNet, self).__init__() super(IFNet, self).__init__()
self.block0 = IFBlock(7+4, c=90) self.block0 = IFBlock(7+4, c=90)
self.block1 = IFBlock(7+4, c=90) self.block1 = IFBlock(7+4, c=90)
@@ -113,7 +113,7 @@ class IFNetStateDictConverter:
return state_dict_ return state_dict_
def from_civitai(self, state_dict): def from_civitai(self, state_dict):
return self.from_diffusers(state_dict) return self.from_diffusers(state_dict), {"upcast_to_float32": True}
class RIFEInterpolater: class RIFEInterpolater:
@@ -125,7 +125,7 @@ class RIFEInterpolater:
@staticmethod @staticmethod
def from_model_manager(model_manager): def from_model_manager(model_manager):
return RIFEInterpolater(model_manager.RIFE, device=model_manager.device) return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
def process_image(self, image): def process_image(self, image):
width, height = image.size width, height = image.size
@@ -203,7 +203,7 @@ class RIFESmoother(RIFEInterpolater):
@staticmethod @staticmethod
def from_model_manager(model_manager): def from_model_manager(model_manager):
return RIFESmoother(model_manager.RIFE, device=model_manager.device) return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
def process_tensors(self, input_tensor, scale=1.0, batch_size=4): def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
output_tensor = [] output_tensor = []

408
diffsynth/models/cog_dit.py Normal file
View File

@@ -0,0 +1,408 @@
import torch
from einops import rearrange, repeat
from .sd3_dit import TimestepEmbeddings
from .attention import Attention
from .utils import load_state_dict_from_folder
from .tiler import TileWorker2Dto3D
import numpy as np
class CogPatchify(torch.nn.Module):
def __init__(self, dim_in, dim_out, patch_size) -> None:
super().__init__()
self.proj = torch.nn.Conv3d(dim_in, dim_out, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = rearrange(hidden_states, "B C T H W -> B (T H W) C")
return hidden_states
class CogAdaLayerNorm(torch.nn.Module):
def __init__(self, dim, dim_cond, single=False):
super().__init__()
self.single = single
self.linear = torch.nn.Linear(dim_cond, dim * (2 if single else 6))
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5)
def forward(self, hidden_states, prompt_emb, emb):
emb = self.linear(torch.nn.functional.silu(emb))
if self.single:
shift, scale = emb.unsqueeze(1).chunk(2, dim=2)
hidden_states = self.norm(hidden_states) * (1 + scale) + shift
return hidden_states
else:
shift_a, scale_a, gate_a, shift_b, scale_b, gate_b = emb.unsqueeze(1).chunk(6, dim=2)
hidden_states = self.norm(hidden_states) * (1 + scale_a) + shift_a
prompt_emb = self.norm(prompt_emb) * (1 + scale_b) + shift_b
return hidden_states, prompt_emb, gate_a, gate_b
class CogDiTBlock(torch.nn.Module):
def __init__(self, dim, dim_cond, num_heads):
super().__init__()
self.norm1 = CogAdaLayerNorm(dim, dim_cond)
self.attn1 = Attention(q_dim=dim, num_heads=48, head_dim=dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
self.norm_q = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
self.norm_k = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
self.norm2 = CogAdaLayerNorm(dim, dim_cond)
self.ff = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
def apply_rotary_emb(self, x, freqs_cis):
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
def process_qkv(self, q, k, v, image_rotary_emb, text_seq_length):
q = self.norm_q(q)
k = self.norm_k(k)
q[:, :, text_seq_length:] = self.apply_rotary_emb(q[:, :, text_seq_length:], image_rotary_emb)
k[:, :, text_seq_length:] = self.apply_rotary_emb(k[:, :, text_seq_length:], image_rotary_emb)
return q, k, v
def forward(self, hidden_states, prompt_emb, time_emb, image_rotary_emb):
# Attention
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm1(
hidden_states, prompt_emb, time_emb
)
attention_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
attention_io = self.attn1(
attention_io,
qkv_preprocessor=lambda q, k, v: self.process_qkv(q, k, v, image_rotary_emb, prompt_emb.shape[1])
)
hidden_states = hidden_states + gate_a * attention_io[:, prompt_emb.shape[1]:]
prompt_emb = prompt_emb + gate_b * attention_io[:, :prompt_emb.shape[1]]
# Feed forward
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm2(
hidden_states, prompt_emb, time_emb
)
ff_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_io = self.ff(ff_io)
hidden_states = hidden_states + gate_a * ff_io[:, prompt_emb.shape[1]:]
prompt_emb = prompt_emb + gate_b * ff_io[:, :prompt_emb.shape[1]]
return hidden_states, prompt_emb
class CogDiT(torch.nn.Module):
def __init__(self):
super().__init__()
self.patchify = CogPatchify(16, 3072, 2)
self.time_embedder = TimestepEmbeddings(3072, 512)
self.context_embedder = torch.nn.Linear(4096, 3072)
self.blocks = torch.nn.ModuleList([CogDiTBlock(3072, 512, 48) for _ in range(42)])
self.norm_final = torch.nn.LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
self.norm_out = CogAdaLayerNorm(3072, 512, single=True)
self.proj_out = torch.nn.Linear(3072, 64, bias=True)
def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
def get_3d_rotary_pos_embed(
self, embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
):
start, stop = crops_coords
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
# Compute dimensions for each axis
dim_t = embed_dim // 4
dim_h = embed_dim // 8 * 3
dim_w = embed_dim // 8 * 3
# Temporal frequencies
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
grid_t = torch.from_numpy(grid_t).float()
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
# Spatial frequencies for height and width
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
# Broadcast and concatenate tensors along specified dimension
def broadcast(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = {len(t.shape) for t in tensors}
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*(list(t.shape) for t in tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
), "invalid dimensions for broadcastable concatenation"
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
return torch.cat(tensors, dim=dim)
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
t, h, w, d = freqs.shape
freqs = freqs.view(t * h * w, d)
# Generate sine and cosine components
sin = freqs.sin()
cos = freqs.cos()
if use_real:
return cos, sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
device: torch.device,
):
grid_height = height // 2
grid_width = width // 2
base_size_width = 720 // (8 * 2)
base_size_height = 480 // (8 * 2)
grid_crops_coords = self.get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = self.get_3d_rotary_pos_embed(
embed_dim=64,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
use_real=True,
)
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
def unpatchify(self, hidden_states, height, width):
hidden_states = rearrange(hidden_states, "B (T H W) (C P Q) -> B C T (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
return hidden_states
def build_mask(self, T, H, W, dtype, device, is_bound):
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
border_width = (H + W) // 4
pad = torch.ones_like(h) * border_width
mask = torch.stack([
pad if is_bound[0] else t + 1,
pad if is_bound[1] else T - t,
pad if is_bound[2] else h + 1,
pad if is_bound[3] else H - h,
pad if is_bound[4] else w + 1,
pad if is_bound[5] else W - w
]).min(dim=0).values
mask = mask.clip(1, border_width)
mask = (mask / border_width).to(dtype=dtype, device=device)
mask = rearrange(mask, "T H W -> 1 1 T H W")
return mask
def tiled_forward(self, hidden_states, timestep, prompt_emb, tile_size=(60, 90), tile_stride=(30, 45)):
B, C, T, H, W = hidden_states.shape
value = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
weight = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
# Split tasks
tasks = []
for h in range(0, H, tile_stride):
for w in range(0, W, tile_stride):
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
continue
h_, w_ = h + tile_size, w + tile_size
if h_ > H: h, h_ = max(H - tile_size, 0), H
if w_ > W: w, w_ = max(W - tile_size, 0), W
tasks.append((h, h_, w, w_))
# Run
for hl, hr, wl, wr in tasks:
mask = self.build_mask(
value.shape[2], (hr-hl), (wr-wl),
hidden_states.dtype, hidden_states.device,
is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W)
)
model_output = self.forward(hidden_states[:, :, :, hl:hr, wl:wr], timestep, prompt_emb)
value[:, :, :, hl:hr, wl:wr] += model_output * mask
weight[:, :, :, hl:hr, wl:wr] += mask
value = value / weight
return value
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30, use_gradient_checkpointing=False):
if tiled:
return TileWorker2Dto3D().tiled_forward(
forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
model_input=hidden_states,
tile_size=tile_size, tile_stride=tile_stride,
tile_device=hidden_states.device, tile_dtype=hidden_states.dtype,
computation_device=self.context_embedder.weight.device, computation_dtype=self.context_embedder.weight.dtype
)
num_frames, height, width = hidden_states.shape[-3:]
if image_rotary_emb is None:
image_rotary_emb = self.prepare_rotary_positional_embeddings(height, width, num_frames, device=self.context_embedder.weight.device)
hidden_states = self.patchify(hidden_states)
time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, time_emb, image_rotary_emb,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = self.norm_out(hidden_states, prompt_emb, time_emb)
hidden_states = self.proj_out(hidden_states)
hidden_states = self.unpatchify(hidden_states, height, width)
return hidden_states
@staticmethod
def state_dict_converter():
return CogDiTStateDictConverter()
@staticmethod
def from_pretrained(file_path, torch_dtype=torch.bfloat16):
model = CogDiT().to(torch_dtype)
state_dict = load_state_dict_from_folder(file_path, torch_dtype=torch_dtype)
state_dict = CogDiT.state_dict_converter().from_diffusers(state_dict)
model.load_state_dict(state_dict)
return model
class CogDiTStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"patch_embed.proj.weight": "patchify.proj.weight",
"patch_embed.proj.bias": "patchify.proj.bias",
"patch_embed.text_proj.weight": "context_embedder.weight",
"patch_embed.text_proj.bias": "context_embedder.bias",
"time_embedding.linear_1.weight": "time_embedder.timestep_embedder.0.weight",
"time_embedding.linear_1.bias": "time_embedder.timestep_embedder.0.bias",
"time_embedding.linear_2.weight": "time_embedder.timestep_embedder.2.weight",
"time_embedding.linear_2.bias": "time_embedder.timestep_embedder.2.bias",
"norm_final.weight": "norm_final.weight",
"norm_final.bias": "norm_final.bias",
"norm_out.linear.weight": "norm_out.linear.weight",
"norm_out.linear.bias": "norm_out.linear.bias",
"norm_out.norm.weight": "norm_out.norm.weight",
"norm_out.norm.bias": "norm_out.norm.bias",
"proj_out.weight": "proj_out.weight",
"proj_out.bias": "proj_out.bias",
}
suffix_dict = {
"norm1.linear.weight": "norm1.linear.weight",
"norm1.linear.bias": "norm1.linear.bias",
"norm1.norm.weight": "norm1.norm.weight",
"norm1.norm.bias": "norm1.norm.bias",
"attn1.norm_q.weight": "norm_q.weight",
"attn1.norm_q.bias": "norm_q.bias",
"attn1.norm_k.weight": "norm_k.weight",
"attn1.norm_k.bias": "norm_k.bias",
"attn1.to_q.weight": "attn1.to_q.weight",
"attn1.to_q.bias": "attn1.to_q.bias",
"attn1.to_k.weight": "attn1.to_k.weight",
"attn1.to_k.bias": "attn1.to_k.bias",
"attn1.to_v.weight": "attn1.to_v.weight",
"attn1.to_v.bias": "attn1.to_v.bias",
"attn1.to_out.0.weight": "attn1.to_out.weight",
"attn1.to_out.0.bias": "attn1.to_out.bias",
"norm2.linear.weight": "norm2.linear.weight",
"norm2.linear.bias": "norm2.linear.bias",
"norm2.norm.weight": "norm2.norm.weight",
"norm2.norm.bias": "norm2.norm.bias",
"ff.net.0.proj.weight": "ff.0.weight",
"ff.net.0.proj.bias": "ff.0.bias",
"ff.net.2.weight": "ff.2.weight",
"ff.net.2.bias": "ff.2.bias",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
if name == "patch_embed.proj.weight":
param = param.unsqueeze(2)
state_dict_[rename_dict[name]] = param
else:
names = name.split(".")
if names[0] == "transformer_blocks":
suffix = ".".join(names[2:])
state_dict_[f"blocks.{names[1]}." + suffix_dict[suffix]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

518
diffsynth/models/cog_vae.py Normal file
View File

@@ -0,0 +1,518 @@
import torch
from einops import rearrange, repeat
from .tiler import TileWorker2Dto3D
class Downsample3D(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 2,
padding: int = 0,
compress_time: bool = False,
):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
if self.compress_time:
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
if x.shape[-1] % 2 == 1:
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
else:
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
# Pad the tensor
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
x = self.conv(x)
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
return x
class Upsample3D(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: int = 1,
compress_time: bool = False,
) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, inputs: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
if self.compress_time:
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
# split first frame
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0)
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0)
x_first = x_first[:, :, None, :, :]
inputs = torch.cat([x_first, x_rest], dim=2)
elif inputs.shape[2] > 1:
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
else:
inputs = inputs.squeeze(2)
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
inputs = inputs[:, :, None, :, :]
else:
# only interpolate 2D
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = self.conv(inputs)
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
return inputs
class CogVideoXSpatialNorm3D(torch.nn.Module):
def __init__(self, f_channels, zq_channels, groups):
super().__init__()
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
self.conv_y = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
self.conv_b = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
z_first = torch.nn.functional.interpolate(z_first, size=f_first_size)
z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size)
zq = torch.cat([z_first, z_rest], dim=2)
else:
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:])
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
class Resnet3DBlock(torch.nn.Module):
def __init__(self, in_channels, out_channels, spatial_norm_dim, groups, eps=1e-6, use_conv_shortcut=False):
super().__init__()
self.nonlinearity = torch.nn.SiLU()
if spatial_norm_dim is None:
self.norm1 = torch.nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
self.norm2 = torch.nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
else:
self.norm1 = CogVideoXSpatialNorm3D(in_channels, spatial_norm_dim, groups)
self.norm2 = CogVideoXSpatialNorm3D(out_channels, spatial_norm_dim, groups)
self.conv1 = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
self.conv2 = CachedConv3d(out_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
if in_channels != out_channels:
if use_conv_shortcut:
self.conv_shortcut = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
else:
self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1)
else:
self.conv_shortcut = lambda x: x
def forward(self, hidden_states, zq):
residual = hidden_states
hidden_states = self.norm1(hidden_states, zq) if isinstance(self.norm1, CogVideoXSpatialNorm3D) else self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states, zq) if isinstance(self.norm2, CogVideoXSpatialNorm3D) else self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = hidden_states + self.conv_shortcut(residual)
return hidden_states
class CachedConv3d(torch.nn.Conv3d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.cached_tensor = None
def clear_cache(self):
self.cached_tensor = None
def forward(self, input: torch.Tensor, use_cache = True) -> torch.Tensor:
if use_cache:
if self.cached_tensor is None:
self.cached_tensor = torch.concat([input[:, :, :1]] * 2, dim=2)
input = torch.concat([self.cached_tensor, input], dim=2)
self.cached_tensor = input[:, :, -2:]
return super().forward(input)
class CogVAEDecoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.scaling_factor = 0.7
self.conv_in = CachedConv3d(16, 512, kernel_size=3, stride=1, padding=(0, 1, 1))
self.blocks = torch.nn.ModuleList([
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Upsample3D(512, 512, compress_time=True),
Resnet3DBlock(512, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Upsample3D(256, 256, compress_time=True),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Upsample3D(256, 256, compress_time=False),
Resnet3DBlock(256, 128, 16, 32),
Resnet3DBlock(128, 128, 16, 32),
Resnet3DBlock(128, 128, 16, 32),
Resnet3DBlock(128, 128, 16, 32),
])
self.norm_out = CogVideoXSpatialNorm3D(128, 16, 32)
self.conv_act = torch.nn.SiLU()
self.conv_out = CachedConv3d(128, 3, kernel_size=3, stride=1, padding=(0, 1, 1))
def forward(self, sample):
sample = sample / self.scaling_factor
hidden_states = self.conv_in(sample)
for block in self.blocks:
hidden_states = block(hidden_states, sample)
hidden_states = self.norm_out(hidden_states, sample)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
def decode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
if tiled:
B, C, T, H, W = sample.shape
return TileWorker2Dto3D().tiled_forward(
forward_fn=lambda x: self.decode_small_video(x),
model_input=sample,
tile_size=tile_size, tile_stride=tile_stride,
tile_device=sample.device, tile_dtype=sample.dtype,
computation_device=sample.device, computation_dtype=sample.dtype,
scales=(3/16, (T//2*8+T%2)/T, 8, 8),
progress_bar=progress_bar
)
else:
return self.decode_small_video(sample)
def decode_small_video(self, sample):
B, C, T, H, W = sample.shape
computation_device = self.conv_in.weight.device
computation_dtype = self.conv_in.weight.dtype
value = []
for i in range(T//2):
tl = i*2 + T%2 - (T%2 and i==0)
tr = i*2 + 2 + T%2
model_input = sample[:, :, tl: tr, :, :].to(dtype=computation_dtype, device=computation_device)
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
value.append(model_output)
value = torch.concat(value, dim=2)
for name, module in self.named_modules():
if isinstance(module, CachedConv3d):
module.clear_cache()
return value
@staticmethod
def state_dict_converter():
return CogVAEDecoderStateDictConverter()
class CogVAEEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.scaling_factor = 0.7
self.conv_in = CachedConv3d(3, 128, kernel_size=3, stride=1, padding=(0, 1, 1))
self.blocks = torch.nn.ModuleList([
Resnet3DBlock(128, 128, None, 32),
Resnet3DBlock(128, 128, None, 32),
Resnet3DBlock(128, 128, None, 32),
Downsample3D(128, 128, compress_time=True),
Resnet3DBlock(128, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Downsample3D(256, 256, compress_time=True),
Resnet3DBlock(256, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Downsample3D(256, 256, compress_time=False),
Resnet3DBlock(256, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
])
self.norm_out = torch.nn.GroupNorm(32, 512, eps=1e-06, affine=True)
self.conv_act = torch.nn.SiLU()
self.conv_out = CachedConv3d(512, 32, kernel_size=3, stride=1, padding=(0, 1, 1))
def forward(self, sample):
hidden_states = self.conv_in(sample)
for block in self.blocks:
hidden_states = block(hidden_states, sample)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)[:, :16]
hidden_states = hidden_states * self.scaling_factor
return hidden_states
def encode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
if tiled:
B, C, T, H, W = sample.shape
return TileWorker2Dto3D().tiled_forward(
forward_fn=lambda x: self.encode_small_video(x),
model_input=sample,
tile_size=(i * 8 for i in tile_size), tile_stride=(i * 8 for i in tile_stride),
tile_device=sample.device, tile_dtype=sample.dtype,
computation_device=sample.device, computation_dtype=sample.dtype,
scales=(16/3, (T//4+T%2)/T, 1/8, 1/8),
progress_bar=progress_bar
)
else:
return self.encode_small_video(sample)
def encode_small_video(self, sample):
B, C, T, H, W = sample.shape
computation_device = self.conv_in.weight.device
computation_dtype = self.conv_in.weight.dtype
value = []
for i in range(T//8):
t = i*8 + T%2 - (T%2 and i==0)
t_ = i*8 + 8 + T%2
model_input = sample[:, :, t: t_, :, :].to(dtype=computation_dtype, device=computation_device)
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
value.append(model_output)
value = torch.concat(value, dim=2)
for name, module in self.named_modules():
if isinstance(module, CachedConv3d):
module.clear_cache()
return value
@staticmethod
def state_dict_converter():
return CogVAEEncoderStateDictConverter()
class CogVAEEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"encoder.conv_in.conv.weight": "conv_in.weight",
"encoder.conv_in.conv.bias": "conv_in.bias",
"encoder.down_blocks.0.downsamplers.0.conv.weight": "blocks.3.conv.weight",
"encoder.down_blocks.0.downsamplers.0.conv.bias": "blocks.3.conv.bias",
"encoder.down_blocks.1.downsamplers.0.conv.weight": "blocks.7.conv.weight",
"encoder.down_blocks.1.downsamplers.0.conv.bias": "blocks.7.conv.bias",
"encoder.down_blocks.2.downsamplers.0.conv.weight": "blocks.11.conv.weight",
"encoder.down_blocks.2.downsamplers.0.conv.bias": "blocks.11.conv.bias",
"encoder.norm_out.weight": "norm_out.weight",
"encoder.norm_out.bias": "norm_out.bias",
"encoder.conv_out.conv.weight": "conv_out.weight",
"encoder.conv_out.conv.bias": "conv_out.bias",
}
prefix_dict = {
"encoder.down_blocks.0.resnets.0.": "blocks.0.",
"encoder.down_blocks.0.resnets.1.": "blocks.1.",
"encoder.down_blocks.0.resnets.2.": "blocks.2.",
"encoder.down_blocks.1.resnets.0.": "blocks.4.",
"encoder.down_blocks.1.resnets.1.": "blocks.5.",
"encoder.down_blocks.1.resnets.2.": "blocks.6.",
"encoder.down_blocks.2.resnets.0.": "blocks.8.",
"encoder.down_blocks.2.resnets.1.": "blocks.9.",
"encoder.down_blocks.2.resnets.2.": "blocks.10.",
"encoder.down_blocks.3.resnets.0.": "blocks.12.",
"encoder.down_blocks.3.resnets.1.": "blocks.13.",
"encoder.down_blocks.3.resnets.2.": "blocks.14.",
"encoder.mid_block.resnets.0.": "blocks.15.",
"encoder.mid_block.resnets.1.": "blocks.16.",
}
suffix_dict = {
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
"conv1.conv.weight": "conv1.weight",
"conv1.conv.bias": "conv1.bias",
"conv2.conv.weight": "conv2.weight",
"conv2.conv.bias": "conv2.bias",
"conv_shortcut.weight": "conv_shortcut.weight",
"conv_shortcut.bias": "conv_shortcut.bias",
"norm1.weight": "norm1.weight",
"norm1.bias": "norm1.bias",
"norm2.weight": "norm2.weight",
"norm2.bias": "norm2.bias",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
for prefix in prefix_dict:
if name.startswith(prefix):
suffix = name[len(prefix):]
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
class CogVAEDecoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"decoder.conv_in.conv.weight": "conv_in.weight",
"decoder.conv_in.conv.bias": "conv_in.bias",
"decoder.up_blocks.0.upsamplers.0.conv.weight": "blocks.6.conv.weight",
"decoder.up_blocks.0.upsamplers.0.conv.bias": "blocks.6.conv.bias",
"decoder.up_blocks.1.upsamplers.0.conv.weight": "blocks.11.conv.weight",
"decoder.up_blocks.1.upsamplers.0.conv.bias": "blocks.11.conv.bias",
"decoder.up_blocks.2.upsamplers.0.conv.weight": "blocks.16.conv.weight",
"decoder.up_blocks.2.upsamplers.0.conv.bias": "blocks.16.conv.bias",
"decoder.norm_out.norm_layer.weight": "norm_out.norm_layer.weight",
"decoder.norm_out.norm_layer.bias": "norm_out.norm_layer.bias",
"decoder.norm_out.conv_y.conv.weight": "norm_out.conv_y.weight",
"decoder.norm_out.conv_y.conv.bias": "norm_out.conv_y.bias",
"decoder.norm_out.conv_b.conv.weight": "norm_out.conv_b.weight",
"decoder.norm_out.conv_b.conv.bias": "norm_out.conv_b.bias",
"decoder.conv_out.conv.weight": "conv_out.weight",
"decoder.conv_out.conv.bias": "conv_out.bias"
}
prefix_dict = {
"decoder.mid_block.resnets.0.": "blocks.0.",
"decoder.mid_block.resnets.1.": "blocks.1.",
"decoder.up_blocks.0.resnets.0.": "blocks.2.",
"decoder.up_blocks.0.resnets.1.": "blocks.3.",
"decoder.up_blocks.0.resnets.2.": "blocks.4.",
"decoder.up_blocks.0.resnets.3.": "blocks.5.",
"decoder.up_blocks.1.resnets.0.": "blocks.7.",
"decoder.up_blocks.1.resnets.1.": "blocks.8.",
"decoder.up_blocks.1.resnets.2.": "blocks.9.",
"decoder.up_blocks.1.resnets.3.": "blocks.10.",
"decoder.up_blocks.2.resnets.0.": "blocks.12.",
"decoder.up_blocks.2.resnets.1.": "blocks.13.",
"decoder.up_blocks.2.resnets.2.": "blocks.14.",
"decoder.up_blocks.2.resnets.3.": "blocks.15.",
"decoder.up_blocks.3.resnets.0.": "blocks.17.",
"decoder.up_blocks.3.resnets.1.": "blocks.18.",
"decoder.up_blocks.3.resnets.2.": "blocks.19.",
"decoder.up_blocks.3.resnets.3.": "blocks.20.",
}
suffix_dict = {
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
"conv1.conv.weight": "conv1.weight",
"conv1.conv.bias": "conv1.bias",
"conv2.conv.weight": "conv2.weight",
"conv2.conv.bias": "conv2.bias",
"conv_shortcut.weight": "conv_shortcut.weight",
"conv_shortcut.bias": "conv_shortcut.bias",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
for prefix in prefix_dict:
if name.startswith(prefix):
suffix = name[len(prefix):]
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -8,11 +8,11 @@ from ..configs.model_config import preset_models_on_huggingface, preset_models_o
def download_from_modelscope(model_id, origin_file_path, local_dir): def download_from_modelscope(model_id, origin_file_path, local_dir):
os.makedirs(local_dir, exist_ok=True) os.makedirs(local_dir, exist_ok=True)
if os.path.basename(origin_file_path) in os.listdir(local_dir): file_name = os.path.basename(origin_file_path)
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.") if file_name in os.listdir(local_dir):
return print(f" {file_name} has been already in {local_dir}.")
else: else:
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}") print(f" Start downloading {os.path.join(local_dir, file_name)}")
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir) snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
downloaded_file_path = os.path.join(local_dir, origin_file_path) downloaded_file_path = os.path.join(local_dir, origin_file_path)
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1]) target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
@@ -23,12 +23,17 @@ def download_from_modelscope(model_id, origin_file_path, local_dir):
def download_from_huggingface(model_id, origin_file_path, local_dir): def download_from_huggingface(model_id, origin_file_path, local_dir):
os.makedirs(local_dir, exist_ok=True) os.makedirs(local_dir, exist_ok=True)
if os.path.basename(origin_file_path) in os.listdir(local_dir): file_name = os.path.basename(origin_file_path)
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.") if file_name in os.listdir(local_dir):
return print(f" {file_name} has been already in {local_dir}.")
else: else:
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}") print(f" Start downloading {os.path.join(local_dir, file_name)}")
hf_hub_download(model_id, origin_file_path, local_dir=local_dir) hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
downloaded_file_path = os.path.join(local_dir, origin_file_path)
target_file_path = os.path.join(local_dir, file_name)
if downloaded_file_path != target_file_path:
shutil.move(downloaded_file_path, target_file_path)
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
Preset_model_website: TypeAlias = Literal[ Preset_model_website: TypeAlias = Literal[
@@ -45,16 +50,14 @@ website_to_download_fn = {
} }
def download_models( def download_customized_models(
model_id_list: List[Preset_model_id] = [], model_id,
origin_file_path,
local_dir,
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"], downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
): ):
print(f"Downloading models: {model_id_list}")
downloaded_files = [] downloaded_files = []
for model_id in model_id_list:
for website in downloading_priority: for website in downloading_priority:
if model_id in website_to_preset_models[website]:
for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
# Check if the file is downloaded. # Check if the file is downloaded.
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path)) file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
if file_to_download in downloaded_files: if file_to_download in downloaded_files:
@@ -64,3 +67,45 @@ def download_models(
if os.path.basename(origin_file_path) in os.listdir(local_dir): if os.path.basename(origin_file_path) in os.listdir(local_dir):
downloaded_files.append(file_to_download) downloaded_files.append(file_to_download)
return downloaded_files return downloaded_files
def download_models(
model_id_list: List[Preset_model_id] = [],
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
print(f"Downloading models: {model_id_list}")
downloaded_files = []
load_files = []
for model_id in model_id_list:
for website in downloading_priority:
if model_id in website_to_preset_models[website]:
# Parse model metadata
model_metadata = website_to_preset_models[website][model_id]
if isinstance(model_metadata, list):
file_data = model_metadata
else:
file_data = model_metadata.get("file_list", [])
# Try downloading the model from this website.
model_files = []
for model_id, origin_file_path, local_dir in file_data:
# Check if the file is downloaded.
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
if file_to_download in downloaded_files:
continue
# Download
website_to_download_fn[website](model_id, origin_file_path, local_dir)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
downloaded_files.append(file_to_download)
model_files.append(file_to_download)
# If the model is successfully downloaded, break.
if len(model_files) > 0:
if isinstance(model_metadata, dict) and "load_path" in model_metadata:
model_files = model_metadata["load_path"]
load_files.extend(model_files)
break
return load_files

View File

@@ -0,0 +1,327 @@
import torch
from einops import rearrange, repeat
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
from .utils import hash_state_dict_keys, init_weights_on_device
class FluxControlNet(torch.nn.Module):
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(64, 3072)
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])
self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])
self.mode_dict = mode_dict
self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)
def prepare_image_ids(self, latents):
batch_size, _, height, width = latents.shape
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
return latent_image_ids
def patchify(self, hidden_states):
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
return hidden_states
def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
if len(res_stack) == 0:
return [torch.zeros_like(hidden_states)] * num_blocks
interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
return aligned_res_stack
def forward(
self,
hidden_states,
controlnet_conditioning,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
processor_id=None,
tiled=False, tile_size=128, tile_stride=64,
**kwargs
):
if image_ids is None:
image_ids = self.prepare_image_ids(hidden_states)
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
if self.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
if self.controlnet_mode_embedder is not None: # Different from FluxDiT
processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states)
controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT
controlnet_res_stack = []
for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
controlnet_res_stack.append(controlnet_block(hidden_states))
controlnet_single_res_stack = []
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))
controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])
return controlnet_res_stack, controlnet_single_res_stack
@staticmethod
def state_dict_converter():
return FluxControlNetStateDictConverter()
def quantize(self):
def cast_to(weight, dtype=None, device=None, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
def cast_weight(s, input=None, dtype=None, device=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if device is None:
device = input.device
weight = cast_to(s.weight, dtype, device)
return weight
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
bias = None
weight = cast_to(s.weight, dtype, device)
bias = cast_to(s.bias, bias_dtype, device)
return weight, bias
class quantized_layer:
class QLinear(torch.nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self,input,**kwargs):
weight,bias= cast_bias_weight(self,input)
return torch.nn.functional.linear(input,weight,bias)
class QRMSNorm(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self,hidden_states,**kwargs):
weight= cast_weight(self.module,hidden_states)
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
hidden_states = hidden_states.to(input_dtype) * weight
return hidden_states
class QEmbedding(torch.nn.Embedding):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self,input,**kwargs):
weight= cast_weight(self,input)
return torch.nn.functional.embedding(
input, weight, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)
def replace_layer(model):
for name, module in model.named_children():
if isinstance(module,quantized_layer.QRMSNorm):
continue
if isinstance(module, torch.nn.Linear):
with init_weights_on_device():
new_layer = quantized_layer.QLinear(module.in_features,module.out_features)
new_layer.weight = module.weight
if module.bias is not None:
new_layer.bias = module.bias
setattr(model, name, new_layer)
elif isinstance(module, RMSNorm):
if hasattr(module,"quantized"):
continue
module.quantized= True
new_layer = quantized_layer.QRMSNorm(module)
setattr(model, name, new_layer)
elif isinstance(module,torch.nn.Embedding):
rows, cols = module.weight.shape
new_layer = quantized_layer.QEmbedding(
num_embeddings=rows,
embedding_dim=cols,
_weight=module.weight,
# _freeze=module.freeze,
padding_idx=module.padding_idx,
max_norm=module.max_norm,
norm_type=module.norm_type,
scale_grad_by_freq=module.scale_grad_by_freq,
sparse=module.sparse)
setattr(model, name, new_layer)
else:
replace_layer(module)
replace_layer(self)
class FluxControlNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
hash_value = hash_state_dict_keys(state_dict)
global_rename_dict = {
"context_embedder": "context_embedder",
"x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "final_norm_out.linear",
"proj_out": "final_proj_out",
}
rename_dict = {
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
}
rename_dict_single = {
"attn.to_q": "a_to_q",
"attn.to_k": "a_to_k",
"attn.to_v": "a_to_v",
"attn.norm_q": "norm_q_a",
"attn.norm_k": "norm_k_a",
"norm.linear": "norm.linear",
"proj_mlp": "proj_in_besides_attn",
"proj_out": "proj_out",
}
state_dict_ = {}
for name, param in state_dict.items():
if name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in global_rename_dict:
state_dict_[global_rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
elif prefix.startswith("single_transformer_blocks."):
names = prefix.split(".")
names[0] = "single_blocks"
middle = ".".join(names[2:])
if middle in rename_dict_single:
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
state_dict_[name_] = param
else:
state_dict_[name] = param
else:
state_dict_[name] = param
for name in list(state_dict_.keys()):
if ".proj_in_besides_attn." in name:
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
param = torch.concat([
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
state_dict_[name],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
state_dict_.pop(name)
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
if hash_value == "78d18b9101345ff695f312e7e62538c0":
extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
elif hash_value == "b001c89139b5f053c715fe772362dd2a":
extra_kwargs = {"num_single_blocks": 0}
elif hash_value == "52357cb26250681367488a8954c271e8":
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
else:
extra_kwargs = {}
return state_dict_, extra_kwargs
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -0,0 +1,739 @@
import torch
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm
from einops import rearrange
from .tiler import TileWorker
from .utils import init_weights_on_device
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
batch_size, num_tokens = hidden_states.shape[0:2]
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1)
hidden_states = hidden_states + scale * ip_hidden_states
return hidden_states
class RoPEEmbedding(torch.nn.Module):
def __init__(self, dim, theta, axes_dim):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()
def forward(self, ids):
n_axes = ids.shape[-1]
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
return emb.unsqueeze(1)
class FluxJointAttention(torch.nn.Module):
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.only_out_a = only_out_a
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
if not only_out_a:
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
def apply_rope(self, xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
batch_size = hidden_states_a.shape[0]
# Part A
qkv_a = self.a_to_qkv(hidden_states_a)
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
# Part B
qkv_b = self.b_to_qkv(hidden_states_b)
qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)
q = torch.concat([q_b, q_a], dim=2)
k = torch.concat([k_b, k_a], dim=2)
v = torch.concat([v_b, v_a], dim=2)
q, k = self.apply_rope(q, k, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
if ipadapter_kwargs_list is not None:
hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
hidden_states_a = self.a_to_out(hidden_states_a)
if self.only_out_a:
return hidden_states_a
else:
hidden_states_b = self.b_to_out(hidden_states_b)
return hidden_states_a, hidden_states_b
class FluxJointTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads):
super().__init__()
self.norm1_a = AdaLayerNorm(dim)
self.norm1_b = AdaLayerNorm(dim)
self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_a = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_b = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
# Attention
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
# Part A
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
# Part B
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
return hidden_states_a, hidden_states_b
class FluxSingleAttention(torch.nn.Module):
def __init__(self, dim_a, dim_b, num_heads, head_dim):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
def apply_rope(self, xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def forward(self, hidden_states, image_rotary_emb):
batch_size = hidden_states.shape[0]
qkv_a = self.a_to_qkv(hidden_states)
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q_a, k_a, v = qkv_a.chunk(3, dim=1)
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
q, k = self.apply_rope(q_a, k_a, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
return hidden_states
class AdaLayerNormSingle(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = torch.nn.SiLU()
self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa
class FluxSingleTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads):
super().__init__()
self.num_heads = num_attention_heads
self.head_dim = dim // num_attention_heads
self.dim = dim
self.norm = AdaLayerNormSingle(dim)
self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4))
self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)
self.proj_out = torch.nn.Linear(dim * 5, dim)
def apply_rope(self, xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
batch_size = hidden_states.shape[0]
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q, k, v = qkv.chunk(3, dim=1)
q, k = self.norm_q_a(q), self.norm_k_a(k)
q, k = self.apply_rope(q, k, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
if ipadapter_kwargs_list is not None:
hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
return hidden_states
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
residual = hidden_states_a
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
hidden_states_a = residual + hidden_states_a
return hidden_states_a, hidden_states_b
class AdaLayerNormContinuous(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = torch.nn.SiLU()
self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
def forward(self, x, conditioning):
emb = self.linear(self.silu(conditioning))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
return x
class FluxDiT(torch.nn.Module):
def __init__(self, disable_guidance_embedder=False):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(64, 3072)
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
self.final_norm_out = AdaLayerNormContinuous(3072)
self.final_proj_out = torch.nn.Linear(3072, 64)
def patchify(self, hidden_states):
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
return hidden_states
def unpatchify(self, hidden_states, height, width):
hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
return hidden_states
def prepare_image_ids(self, latents):
batch_size, _, height, width = latents.shape
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
return latent_image_ids
def tiled_forward(
self,
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
tile_size=128, tile_stride=64,
**kwargs
):
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None),
hidden_states,
tile_size,
tile_stride,
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
return hidden_states
def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
N = len(entity_masks)
batch_size = entity_masks[0].shape[0]
total_seq_len = N * prompt_seq_len + image_seq_len
patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
image_start = N * prompt_seq_len
image_end = N * prompt_seq_len + image_seq_len
# prompt-image mask
for i in range(N):
prompt_start = i * prompt_seq_len
prompt_end = (i + 1) * prompt_seq_len
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
# prompt update with image
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
# image update with prompt
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
# prompt-prompt mask
for i in range(N):
for j in range(N):
if i != j:
prompt_start_i = i * prompt_seq_len
prompt_end_i = (i + 1) * prompt_seq_len
prompt_start_j = j * prompt_seq_len
prompt_end_j = (j + 1) * prompt_seq_len
attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False
attention_mask = attention_mask.float()
attention_mask[attention_mask == 0] = float('-inf')
attention_mask[attention_mask == 1] = 0
return attention_mask
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
repeat_dim = hidden_states.shape[1]
max_masks = 0
attention_mask = None
prompt_embs = [prompt_emb]
if entity_masks is not None:
# entity_masks
batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
# global mask
global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
entity_masks = entity_masks + [global_mask] # append global to last
# attention mask
attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
attention_mask = attention_mask.unsqueeze(1)
# embds: n_masks * b * seq * d
local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
prompt_embs = local_embs + prompt_embs # append global to last
prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
prompt_emb = torch.cat(prompt_embs, dim=1)
# positional embedding
text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
return prompt_emb, image_rotary_emb, attention_mask
def forward(
self,
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
use_gradient_checkpointing=False,
**kwargs
):
if tiled:
return self.tiled_forward(
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
tile_size=tile_size, tile_stride=tile_stride,
**kwargs
)
if image_ids is None:
image_ids = self.prepare_image_ids(hidden_states)
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
if self.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
height, width = hidden_states.shape[-2:]
hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states)
if entity_prompt_emb is not None and entity_masks is not None:
prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
else:
prompt_emb = self.context_embedder(prompt_emb)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
attention_mask = None
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block in self.single_blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = self.final_norm_out(hidden_states, conditioning)
hidden_states = self.final_proj_out(hidden_states)
hidden_states = self.unpatchify(hidden_states, height, width)
return hidden_states
def quantize(self):
def cast_to(weight, dtype=None, device=None, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
def cast_weight(s, input=None, dtype=None, device=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if device is None:
device = input.device
weight = cast_to(s.weight, dtype, device)
return weight
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
bias = None
weight = cast_to(s.weight, dtype, device)
bias = cast_to(s.bias, bias_dtype, device)
return weight, bias
class quantized_layer:
class Linear(torch.nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self,input,**kwargs):
weight,bias= cast_bias_weight(self,input)
return torch.nn.functional.linear(input,weight,bias)
class RMSNorm(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self,hidden_states,**kwargs):
weight= cast_weight(self.module,hidden_states)
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
hidden_states = hidden_states.to(input_dtype) * weight
return hidden_states
def replace_layer(model):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
with init_weights_on_device():
new_layer = quantized_layer.Linear(module.in_features,module.out_features)
new_layer.weight = module.weight
if module.bias is not None:
new_layer.bias = module.bias
# del module
setattr(model, name, new_layer)
elif isinstance(module, RMSNorm):
if hasattr(module,"quantized"):
continue
module.quantized= True
new_layer = quantized_layer.RMSNorm(module)
setattr(model, name, new_layer)
else:
replace_layer(module)
replace_layer(self)
@staticmethod
def state_dict_converter():
return FluxDiTStateDictConverter()
class FluxDiTStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
global_rename_dict = {
"context_embedder": "context_embedder",
"x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "final_norm_out.linear",
"proj_out": "final_proj_out",
}
rename_dict = {
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
}
rename_dict_single = {
"attn.to_q": "a_to_q",
"attn.to_k": "a_to_k",
"attn.to_v": "a_to_v",
"attn.norm_q": "norm_q_a",
"attn.norm_k": "norm_k_a",
"norm.linear": "norm.linear",
"proj_mlp": "proj_in_besides_attn",
"proj_out": "proj_out",
}
state_dict_ = {}
for name, param in state_dict.items():
if name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in global_rename_dict:
state_dict_[global_rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
elif prefix.startswith("single_transformer_blocks."):
names = prefix.split(".")
names[0] = "single_blocks"
middle = ".".join(names[2:])
if middle in rename_dict_single:
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
state_dict_[name_] = param
else:
pass
else:
pass
for name in list(state_dict_.keys()):
if ".proj_in_besides_attn." in name:
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
param = torch.concat([
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
state_dict_[name],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
state_dict_.pop(name)
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
"time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias",
"time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight",
"txt_in.bias": "context_embedder.bias",
"txt_in.weight": "context_embedder.weight",
"vector_in.in_layer.bias": "pooled_text_embedder.0.bias",
"vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
"vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
"vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
"final_layer.linear.bias": "final_proj_out.bias",
"final_layer.linear.weight": "final_proj_out.weight",
"guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
"guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
"guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
"guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
"img_in.bias": "x_embedder.bias",
"img_in.weight": "x_embedder.weight",
"final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight",
"final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias",
}
suffix_rename_dict = {
"img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
"img_attn.norm.query_norm.scale": "attn.norm_q_a.weight",
"img_attn.proj.bias": "attn.a_to_out.bias",
"img_attn.proj.weight": "attn.a_to_out.weight",
"img_attn.qkv.bias": "attn.a_to_qkv.bias",
"img_attn.qkv.weight": "attn.a_to_qkv.weight",
"img_mlp.0.bias": "ff_a.0.bias",
"img_mlp.0.weight": "ff_a.0.weight",
"img_mlp.2.bias": "ff_a.2.bias",
"img_mlp.2.weight": "ff_a.2.weight",
"img_mod.lin.bias": "norm1_a.linear.bias",
"img_mod.lin.weight": "norm1_a.linear.weight",
"txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight",
"txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight",
"txt_attn.proj.bias": "attn.b_to_out.bias",
"txt_attn.proj.weight": "attn.b_to_out.weight",
"txt_attn.qkv.bias": "attn.b_to_qkv.bias",
"txt_attn.qkv.weight": "attn.b_to_qkv.weight",
"txt_mlp.0.bias": "ff_b.0.bias",
"txt_mlp.0.weight": "ff_b.0.weight",
"txt_mlp.2.bias": "ff_b.2.bias",
"txt_mlp.2.weight": "ff_b.2.weight",
"txt_mod.lin.bias": "norm1_b.linear.bias",
"txt_mod.lin.weight": "norm1_b.linear.weight",
"linear1.bias": "to_qkv_mlp.bias",
"linear1.weight": "to_qkv_mlp.weight",
"linear2.bias": "proj_out.bias",
"linear2.weight": "proj_out.weight",
"modulation.lin.bias": "norm.linear.bias",
"modulation.lin.weight": "norm.linear.weight",
"norm.key_norm.scale": "norm_k_a.weight",
"norm.query_norm.scale": "norm_q_a.weight",
}
state_dict_ = {}
for name, param in state_dict.items():
if name.startswith("model.diffusion_model."):
name = name[len("model.diffusion_model."):]
names = name.split(".")
if name in rename_dict:
rename = rename_dict[name]
if name.startswith("final_layer.adaLN_modulation.1."):
param = torch.concat([param[3072:], param[:3072]], dim=0)
state_dict_[rename] = param
elif names[0] == "double_blocks":
rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
state_dict_[rename] = param
elif names[0] == "single_blocks":
if ".".join(names[2:]) in suffix_rename_dict:
rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
state_dict_[rename] = param
else:
pass
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
return state_dict_, {"disable_guidance_embedder": True}
else:
return state_dict_

View File

@@ -0,0 +1,94 @@
from .svd_image_encoder import SVDImageEncoder
from .sd3_dit import RMSNorm
from transformers import CLIPImageProcessor
import torch
class MLPProjModel(torch.nn.Module):
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.num_tokens = num_tokens
self.proj = torch.nn.Sequential(
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
torch.nn.GELU(),
torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
def forward(self, id_embeds):
x = self.proj(id_embeds)
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
x = self.norm(x)
return x
class IpAdapterModule(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, input_dim):
super().__init__()
self.num_heads = num_attention_heads
self.head_dim = attention_head_dim
output_dim = num_attention_heads * attention_head_dim
self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False)
def forward(self, hidden_states):
batch_size = hidden_states.shape[0]
# ip_k
ip_k = self.to_k_ip(hidden_states)
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
ip_k = self.norm_added_k(ip_k)
# ip_v
ip_v = self.to_v_ip(hidden_states)
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
return ip_k, ip_v
class FluxIpAdapter(torch.nn.Module):
def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):
super().__init__()
self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)])
self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens)
self.set_adapter()
def set_adapter(self):
self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))}
def forward(self, hidden_states, scale=1.0):
hidden_states = self.image_proj(hidden_states)
hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
ip_kv_dict = {}
for block_id in self.call_block_id:
ipadapter_id = self.call_block_id[block_id]
ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
ip_kv_dict[block_id] = {
"ip_k": ip_k,
"ip_v": ip_v,
"scale": scale
}
return ip_kv_dict
@staticmethod
def state_dict_converter():
return FluxIpAdapterStateDictConverter()
class FluxIpAdapterStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {}
for name in state_dict["ip_adapter"]:
name_ = 'ipadapter_modules.' + name
state_dict_[name_] = state_dict["ip_adapter"][name]
for name in state_dict["image_proj"]:
name_ = "image_proj." + name
state_dict_[name_] = state_dict["image_proj"][name]
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -0,0 +1,32 @@
import torch
from transformers import T5EncoderModel, T5Config
from .sd_text_encoder import SDTextEncoder
class FluxTextEncoder2(T5EncoderModel):
def __init__(self, config):
super().__init__(config)
self.eval()
def forward(self, input_ids):
outputs = super().forward(input_ids=input_ids)
prompt_emb = outputs.last_hidden_state
return prompt_emb
@staticmethod
def state_dict_converter():
return FluxTextEncoder2StateDictConverter()
class FluxTextEncoder2StateDictConverter():
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = state_dict
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -0,0 +1,303 @@
from .sd3_vae_encoder import SD3VAEEncoder, SDVAEEncoderStateDictConverter
from .sd3_vae_decoder import SD3VAEDecoder, SDVAEDecoderStateDictConverter
class FluxVAEEncoder(SD3VAEEncoder):
def __init__(self):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
@staticmethod
def state_dict_converter():
return FluxVAEEncoderStateDictConverter()
class FluxVAEDecoder(SD3VAEDecoder):
def __init__(self):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
@staticmethod
def state_dict_converter():
return FluxVAEDecoderStateDictConverter()
class FluxVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
def __init__(self):
pass
def from_civitai(self, state_dict):
rename_dict = {
"encoder.conv_in.bias": "conv_in.bias",
"encoder.conv_in.weight": "conv_in.weight",
"encoder.conv_out.bias": "conv_out.bias",
"encoder.conv_out.weight": "conv_out.weight",
"encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
"encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
"encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
"encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
"encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
"encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
"encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
"encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
"encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
"encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
"encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
"encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
"encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
"encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
"encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
"encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
"encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
"encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
"encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
"encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
"encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
"encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
"encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
"encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
"encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
"encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
"encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
"encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
"encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
"encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
"encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
"encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
"encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
"encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
"encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
"encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
"encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
"encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
"encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
"encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
"encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
"encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
"encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
"encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
"encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
"encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
"encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
"encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
"encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
"encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
"encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
"encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
"encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
"encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
"encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
"encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
"encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
"encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
"encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
"encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
"encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
"encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
"encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
"encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
"encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
"encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
"encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
"encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
"encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
"encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
"encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
"encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
"encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
"encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
"encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
"encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
"encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
"encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
"encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
"encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
"encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
"encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
"encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
"encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
"encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
"encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
"encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
"encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
"encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
"encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
"encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
"encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
"encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
"encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
"encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
"encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
"encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
"encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
"encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
"encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
"encoder.norm_out.bias": "conv_norm_out.bias",
"encoder.norm_out.weight": "conv_norm_out.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if "transformer_blocks" in rename_dict[name]:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_
class FluxVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
def __init__(self):
pass
def from_civitai(self, state_dict):
rename_dict = {
"decoder.conv_in.bias": "conv_in.bias",
"decoder.conv_in.weight": "conv_in.weight",
"decoder.conv_out.bias": "conv_out.bias",
"decoder.conv_out.weight": "conv_out.weight",
"decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
"decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
"decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
"decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
"decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
"decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
"decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
"decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
"decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
"decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
"decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
"decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
"decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
"decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
"decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
"decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
"decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
"decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
"decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
"decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
"decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
"decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
"decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
"decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
"decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
"decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
"decoder.norm_out.bias": "conv_norm_out.bias",
"decoder.norm_out.weight": "conv_norm_out.weight",
"decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
"decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
"decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
"decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
"decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
"decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
"decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
"decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
"decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
"decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
"decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
"decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
"decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
"decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
"decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
"decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
"decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
"decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
"decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
"decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
"decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
"decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
"decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
"decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
"decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
"decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
"decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
"decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
"decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
"decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
"decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
"decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
"decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
"decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
"decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
"decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
"decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
"decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
"decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
"decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
"decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
"decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
"decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
"decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
"decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
"decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
"decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
"decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
"decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
"decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
"decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
"decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
"decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
"decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
"decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
"decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
"decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
"decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
"decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
"decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
"decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
"decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
"decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
"decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
"decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
"decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
"decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
"decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
"decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
"decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
"decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
"decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
"decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
"decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
"decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
"decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
"decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
"decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
"decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
"decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
"decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
"decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
"decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
"decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
"decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
"decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
"decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
"decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
"decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
"decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
"decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
"decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
"decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
"decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
"decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
"decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
"decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
"decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
"decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
"decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
"decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
"decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
"decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
"decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
"decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
"decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if "transformer_blocks" in rename_dict[name]:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_

View File

@@ -0,0 +1,885 @@
import torch
from .sd3_dit import TimestepEmbeddings, RMSNorm
from .utils import init_weights_on_device
from einops import rearrange, repeat
from tqdm import tqdm
from typing import Union, Tuple, List
def HunyuanVideoRope(latents):
def _to_tuple(x, dim=2):
if isinstance(x, int):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start, *args, dim=2):
"""
Get n-D meshgrid with start, stop and num.
Args:
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
n-tuples.
*args: See above.
dim (int): Dimension of the meshgrid. Defaults to 2.
Returns:
grid (np.ndarray): [dim, ...]
"""
if len(args) == 0:
# start is grid_size
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = [stop[i] - start[i] for i in range(dim)]
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
return grid
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.FloatTensor, int],
theta: float = 10000.0,
use_real: bool = False,
theta_rescale_factor: float = 1.0,
interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
Returns:
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
"""
if isinstance(pos, int):
pos = torch.arange(pos).float()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
) # [D/2]
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(
torch.ones_like(freqs), freqs
) # complex64 # [S, D/2]
return freqs_cis
def get_nd_rotary_pos_embed(
rope_dim_list,
start,
*args,
theta=10000.0,
use_real=False,
theta_rescale_factor: Union[float, List[float]] = 1.0,
interpolation_factor: Union[float, List[float]] = 1.0,
):
"""
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
Args:
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
sum(rope_dim_list) should equal to head_dim of attention layer.
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
*args: See above.
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
part and an imaginary part separately.
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
Returns:
pos_embed (torch.Tensor): [HW, D/2]
"""
grid = get_meshgrid_nd(
start, *args, dim=len(rope_dim_list)
) # [3, W, H, D] / [2, W, H]
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
assert len(theta_rescale_factor) == len(
rope_dim_list
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
assert len(interpolation_factor) == len(
rope_dim_list
), "len(interpolation_factor) should equal to len(rope_dim_list)"
# use 1/ndim of dimensions to encode grid_axis
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i],
grid[i].reshape(-1),
theta,
use_real=use_real,
theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
) # 2 x [WHD, rope_dim_list[i]]
embs.append(emb)
if use_real:
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
return cos, sin
else:
emb = torch.cat(embs, dim=1) # (WHD, D/2)
return emb
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
[16, 56, 56],
[latents.shape[2], latents.shape[3] // 2, latents.shape[4] // 2],
theta=256,
use_real=True,
theta_rescale_factor=1,
)
return freqs_cos, freqs_sin
class PatchEmbed(torch.nn.Module):
def __init__(self, patch_size=(1, 2, 2), in_channels=16, embed_dim=3072):
super().__init__()
self.proj = torch.nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
return x
class IndividualTokenRefinerBlock(torch.nn.Module):
def __init__(self, hidden_size=3072, num_heads=24):
super().__init__()
self.num_heads = num_heads
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.self_attn_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
self.self_attn_proj = torch.nn.Linear(hidden_size, hidden_size)
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.mlp = torch.nn.Sequential(
torch.nn.Linear(hidden_size, hidden_size * 4),
torch.nn.SiLU(),
torch.nn.Linear(hidden_size * 4, hidden_size)
)
self.adaLN_modulation = torch.nn.Sequential(
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size * 2, device="cuda", dtype=torch.bfloat16),
)
def forward(self, x, c, attn_mask=None):
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn_qkv(norm_x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
attn = rearrange(attn, "B H L D -> B L (H D)")
x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1)
x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
return x
class SingleTokenRefiner(torch.nn.Module):
def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
super().__init__()
self.input_embedder = torch.nn.Linear(in_channels, hidden_size, bias=True)
self.t_embedder = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
self.c_embedder = torch.nn.Sequential(
torch.nn.Linear(in_channels, hidden_size),
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size)
)
self.blocks = torch.nn.ModuleList([IndividualTokenRefinerBlock(hidden_size=hidden_size) for _ in range(depth)])
def forward(self, x, t, mask=None):
timestep_aware_representations = self.t_embedder(t, dtype=torch.float32)
mask_float = mask.float().unsqueeze(-1)
context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
context_aware_representations = self.c_embedder(context_aware_representations)
c = timestep_aware_representations + context_aware_representations
x = self.input_embedder(x)
mask = mask.to(device=x.device, dtype=torch.bool)
mask = repeat(mask, "B L -> B 1 D L", D=mask.shape[-1])
mask = mask & mask.transpose(2, 3)
mask[:, :, :, 0] = True
for block in self.blocks:
x = block(x, c, mask)
return x
class ModulateDiT(torch.nn.Module):
def __init__(self, hidden_size, factor=6):
super().__init__()
self.act = torch.nn.SiLU()
self.linear = torch.nn.Linear(hidden_size, factor * hidden_size)
def forward(self, x):
return self.linear(self.act(x))
def modulate(x, shift=None, scale=None):
if scale is None and shift is None:
return x
elif shift is None:
return x * (1 + scale.unsqueeze(1))
elif scale is None:
return x + shift.unsqueeze(1)
else:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def reshape_for_broadcast(
freqs_cis,
x: torch.Tensor,
head_first=False,
):
ndim = x.ndim
assert 0 <= 1 < ndim
if isinstance(freqs_cis, tuple):
# freqs_cis: (cos, sin) in real space
if head_first:
assert freqs_cis[0].shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
shape = [
d if i == ndim - 2 or i == ndim - 1 else 1
for i, d in enumerate(x.shape)
]
else:
assert freqs_cis[0].shape == (
x.shape[1],
x.shape[-1],
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
else:
# freqs_cis: values in complex space
if head_first:
assert freqs_cis.shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [
d if i == ndim - 2 or i == ndim - 1 else 1
for i, d in enumerate(x.shape)
]
else:
assert freqs_cis.shape == (
x.shape[1],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotate_half(x):
x_real, x_imag = (
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis,
head_first: bool = False,
):
xk_out = None
if isinstance(freqs_cis, tuple):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
cos, sin = cos.to(xq.device), sin.to(xq.device)
# real * cos - imag * sin
# imag * cos + real * sin
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
else:
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], -1, 2)
) # [B, S, H, D//2]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
xq.device
) # [S, D//2] --> [1, S, 1, D//2]
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], -1, 2)
) # [B, S, H, D//2]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
return xq_out, xk_out
def attention(q, k, v):
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = x.transpose(1, 2).flatten(2, 3)
return x
class MMDoubleStreamBlockComponent(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
super().__init__()
self.heads_num = heads_num
self.mod = ModulateDiT(hidden_size)
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff = torch.nn.Sequential(
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
)
def forward(self, hidden_states, conditioning, freqs_cis=None):
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale)
qkv = self.to_qkv(norm_hidden_states)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
q = self.norm_q(q)
k = self.norm_k(k)
if freqs_cis is not None:
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate)
def process_ff(self, hidden_states, attn_output, mod):
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1)
hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1)
return hidden_states
class MMDoubleStreamBlock(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
super().__init__()
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis):
(q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis)
(q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous()
attn_output_a = attention(q_a, k_a, v_a)
attn_output_b = attention(q_b, k_b, v_b)
attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1)
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a)
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
return hidden_states_a, hidden_states_b
class MMSingleStreamBlockOriginal(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
super().__init__()
self.hidden_size = hidden_size
self.heads_num = heads_num
self.mlp_hidden_dim = hidden_size * mlp_width_ratio
self.linear1 = torch.nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
self.linear2 = torch.nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.q_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.k_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.pre_norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = torch.nn.GELU(approximate="tanh")
self.modulation = ModulateDiT(hidden_size, factor=3)
def forward(self, x, vec, freqs_cis=None, txt_len=256):
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
q = self.q_norm(q)
k = self.k_norm(k)
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
q = torch.cat((q_a, q_b), dim=1)
k = torch.cat((k_a, k_b), dim=1)
attn_output_a = attention(q[:, :-185].contiguous(), k[:, :-185].contiguous(), v[:, :-185].contiguous())
attn_output_b = attention(q[:, -185:].contiguous(), k[:, -185:].contiguous(), v[:, -185:].contiguous())
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
return x + output * mod_gate.unsqueeze(1)
class MMSingleStreamBlock(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
super().__init__()
self.heads_num = heads_num
self.mod = ModulateDiT(hidden_size, factor=3)
self.norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
self.ff = torch.nn.Sequential(
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
)
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256):
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
norm_hidden_states = self.norm(hidden_states)
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale)
qkv = self.to_qkv(norm_hidden_states)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
q = self.norm_q(q)
k = self.norm_k(k)
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous()
attn_output_a = attention(q_a, k_a, v_a)
attn_output_b = attention(q_b, k_b, v_b)
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1)
hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1)
return hidden_states
class FinalLayer(torch.nn.Module):
def __init__(self, hidden_size=3072, patch_size=(1, 2, 2), out_channels=16):
super().__init__()
self.norm_final = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = torch.nn.Linear(hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels)
self.adaLN_modulation = torch.nn.Sequential(torch.nn.SiLU(), torch.nn.Linear(hidden_size, 2 * hidden_size))
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift=shift, scale=scale)
x = self.linear(x)
return x
class HunyuanVideoDiT(torch.nn.Module):
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40):
super().__init__()
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
self.time_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
self.vector_in = torch.nn.Sequential(
torch.nn.Linear(768, hidden_size),
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size)
)
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
self.final_layer = FinalLayer(hidden_size)
# TODO: remove these parameters
self.dtype = torch.bfloat16
self.patch_size = [1, 2, 2]
self.hidden_size = 3072
self.heads_num = 24
self.rope_dim_list = [16, 56, 56]
def unpatchify(self, x, T, H, W):
x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
return x
def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
self.warm_device = warm_device
self.cold_device = cold_device
self.to(self.cold_device)
def load_models_to_device(self, loadmodel_names=[], device="cpu"):
for model_name in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
model.to(device)
torch.cuda.empty_cache()
def prepare_freqs(self, latents):
return HunyuanVideoRope(latents)
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
prompt_emb: torch.Tensor = None,
text_mask: torch.Tensor = None,
pooled_prompt_emb: torch.Tensor = None,
freqs_cos: torch.Tensor = None,
freqs_sin: torch.Tensor = None,
guidance: torch.Tensor = None,
**kwargs
):
B, C, T, H, W = x.shape
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
img = self.img_in(x)
txt = self.txt_in(prompt_emb, t, text_mask)
for block in tqdm(self.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
x = torch.concat([img, txt], dim=1)
for block in tqdm(self.single_blocks, desc="Single stream blocks"):
x = block(x, vec, (freqs_cos, freqs_sin))
img = x[:, :-256]
img = self.final_layer(img, vec)
img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
return img
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
def cast_to(weight, dtype=None, device=None, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
def cast_weight(s, input=None, dtype=None, device=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if device is None:
device = input.device
weight = cast_to(s.weight, dtype, device)
return weight
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
weight = cast_to(s.weight, dtype, device)
bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None
return weight, bias
class quantized_layer:
class Linear(torch.nn.Linear):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs)
self.dtype = dtype
self.device = device
def block_forward_(self, x, i, j, dtype, device):
weight_ = cast_to(
self.weight[j * self.block_size: (j + 1) * self.block_size, i * self.block_size: (i + 1) * self.block_size],
dtype=dtype, device=device
)
if self.bias is None or i > 0:
bias_ = None
else:
bias_ = cast_to(self.bias[j * self.block_size: (j + 1) * self.block_size], dtype=dtype, device=device)
x_ = x[..., i * self.block_size: (i + 1) * self.block_size]
y_ = torch.nn.functional.linear(x_, weight_, bias_)
del x_, weight_, bias_
torch.cuda.empty_cache()
return y_
def block_forward(self, x, **kwargs):
# This feature can only reduce 2GB VRAM, so we disable it.
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
for i in range((self.in_features + self.block_size - 1) // self.block_size):
for j in range((self.out_features + self.block_size - 1) // self.block_size):
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
return y
def forward(self, x, **kwargs):
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.linear(x, weight, bias)
class RMSNorm(torch.nn.Module):
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
super().__init__()
self.module = module
self.dtype = dtype
self.device = device
def forward(self, hidden_states, **kwargs):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
hidden_states = hidden_states.to(input_dtype)
if self.module.weight is not None:
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
hidden_states = hidden_states * weight
return hidden_states
class Conv3d(torch.nn.Conv3d):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs)
self.dtype = dtype
self.device = device
def forward(self, x):
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
class LayerNorm(torch.nn.LayerNorm):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs)
self.dtype = dtype
self.device = device
def forward(self, x):
if self.weight is not None and self.bias is not None:
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
else:
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
with init_weights_on_device():
new_layer = quantized_layer.Linear(
module.in_features, module.out_features, bias=module.bias is not None,
dtype=dtype, device=device
)
new_layer.load_state_dict(module.state_dict(), assign=True)
setattr(model, name, new_layer)
elif isinstance(module, torch.nn.Conv3d):
with init_weights_on_device():
new_layer = quantized_layer.Conv3d(
module.in_channels, module.out_channels, kernel_size=module.kernel_size, stride=module.stride,
dtype=dtype, device=device
)
new_layer.load_state_dict(module.state_dict(), assign=True)
setattr(model, name, new_layer)
elif isinstance(module, RMSNorm):
new_layer = quantized_layer.RMSNorm(
module,
dtype=dtype, device=device
)
setattr(model, name, new_layer)
elif isinstance(module, torch.nn.LayerNorm):
with init_weights_on_device():
new_layer = quantized_layer.LayerNorm(
module.normalized_shape, elementwise_affine=module.elementwise_affine, eps=module.eps,
dtype=dtype, device=device
)
new_layer.load_state_dict(module.state_dict(), assign=True)
setattr(model, name, new_layer)
else:
replace_layer(module, dtype=dtype, device=device)
replace_layer(self, dtype=dtype, device=device)
@staticmethod
def state_dict_converter():
return HunyuanVideoDiTStateDictConverter()
class HunyuanVideoDiTStateDictConverter:
def __init__(self):
pass
def from_civitai(self, state_dict):
if "module" in state_dict:
state_dict = state_dict["module"]
direct_dict = {
"img_in.proj": "img_in.proj",
"time_in.mlp.0": "time_in.timestep_embedder.0",
"time_in.mlp.2": "time_in.timestep_embedder.2",
"vector_in.in_layer": "vector_in.0",
"vector_in.out_layer": "vector_in.2",
"guidance_in.mlp.0": "guidance_in.timestep_embedder.0",
"guidance_in.mlp.2": "guidance_in.timestep_embedder.2",
"txt_in.input_embedder": "txt_in.input_embedder",
"txt_in.t_embedder.mlp.0": "txt_in.t_embedder.timestep_embedder.0",
"txt_in.t_embedder.mlp.2": "txt_in.t_embedder.timestep_embedder.2",
"txt_in.c_embedder.linear_1": "txt_in.c_embedder.0",
"txt_in.c_embedder.linear_2": "txt_in.c_embedder.2",
"final_layer.linear": "final_layer.linear",
"final_layer.adaLN_modulation.1": "final_layer.adaLN_modulation.1",
}
txt_suffix_dict = {
"norm1": "norm1",
"self_attn_qkv": "self_attn_qkv",
"self_attn_proj": "self_attn_proj",
"norm2": "norm2",
"mlp.fc1": "mlp.0",
"mlp.fc2": "mlp.2",
"adaLN_modulation.1": "adaLN_modulation.1",
}
double_suffix_dict = {
"img_mod.linear": "component_a.mod.linear",
"img_attn_qkv": "component_a.to_qkv",
"img_attn_q_norm": "component_a.norm_q",
"img_attn_k_norm": "component_a.norm_k",
"img_attn_proj": "component_a.to_out",
"img_mlp.fc1": "component_a.ff.0",
"img_mlp.fc2": "component_a.ff.2",
"txt_mod.linear": "component_b.mod.linear",
"txt_attn_qkv": "component_b.to_qkv",
"txt_attn_q_norm": "component_b.norm_q",
"txt_attn_k_norm": "component_b.norm_k",
"txt_attn_proj": "component_b.to_out",
"txt_mlp.fc1": "component_b.ff.0",
"txt_mlp.fc2": "component_b.ff.2",
}
single_suffix_dict = {
"linear1": ["to_qkv", "ff.0"],
"linear2": ["to_out", "ff.2"],
"q_norm": "norm_q",
"k_norm": "norm_k",
"modulation.linear": "mod.linear",
}
# single_suffix_dict = {
# "linear1": "linear1",
# "linear2": "linear2",
# "q_norm": "q_norm",
# "k_norm": "k_norm",
# "modulation.linear": "modulation.linear",
# }
state_dict_ = {}
for name, param in state_dict.items():
names = name.split(".")
direct_name = ".".join(names[:-1])
if direct_name in direct_dict:
name_ = direct_dict[direct_name] + "." + names[-1]
state_dict_[name_] = param
elif names[0] == "double_blocks":
prefix = ".".join(names[:2])
suffix = ".".join(names[2:-1])
name_ = prefix + "." + double_suffix_dict[suffix] + "." + names[-1]
state_dict_[name_] = param
elif names[0] == "single_blocks":
prefix = ".".join(names[:2])
suffix = ".".join(names[2:-1])
if isinstance(single_suffix_dict[suffix], list):
if suffix == "linear1":
name_a, name_b = single_suffix_dict[suffix]
param_a, param_b = torch.split(param, (3072*3, 3072*4), dim=0)
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
elif suffix == "linear2":
if names[-1] == "weight":
name_a, name_b = single_suffix_dict[suffix]
param_a, param_b = torch.split(param, (3072*1, 3072*4), dim=-1)
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
else:
name_a, name_b = single_suffix_dict[suffix]
state_dict_[prefix + "." + name_a + "." + names[-1]] = param
else:
pass
else:
name_ = prefix + "." + single_suffix_dict[suffix] + "." + names[-1]
state_dict_[name_] = param
elif names[0] == "txt_in":
prefix = ".".join(names[:4]).replace(".individual_token_refiner.", ".")
suffix = ".".join(names[4:-1])
name_ = prefix + "." + txt_suffix_dict[suffix] + "." + names[-1]
state_dict_[name_] = param
else:
pass
return state_dict_

View File

@@ -0,0 +1,55 @@
from transformers import LlamaModel, LlamaConfig, DynamicCache
from copy import deepcopy
import torch
class HunyuanVideoLLMEncoder(LlamaModel):
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.auto_offload = False
def enable_auto_offload(self, **kwargs):
self.auto_offload = True
def forward(
self,
input_ids,
attention_mask,
hidden_state_skip_layer=2
):
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
inputs_embeds = embed_tokens(input_ids)
past_key_values = DynamicCache()
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, False)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
rotary_emb = deepcopy(self.rotary_emb).to(input_ids.device) if self.auto_offload else self.rotary_emb
position_embeddings = rotary_emb(hidden_states, position_ids)
# decoder layers
for layer_id, decoder_layer in enumerate(self.layers):
if self.auto_offload:
decoder_layer = deepcopy(decoder_layer).to(hidden_states.device)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=False,
use_cache=True,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if layer_id + hidden_state_skip_layer + 1 >= len(self.layers):
break
return hidden_states

View File

@@ -0,0 +1,507 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import numpy as np
from tqdm import tqdm
from einops import repeat
class CausalConv3d(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, stride=1, dilation=1, pad_mode='replicate', **kwargs):
super().__init__()
self.pad_mode = pad_mode
self.time_causal_padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0
) # W, H, T
self.conv = nn.Conv3d(in_channel, out_channel, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
class UpsampleCausal3D(nn.Module):
def __init__(self, channels, use_conv=False, out_channels=None, kernel_size=None, bias=True, upsample_factor=(2, 2, 2)):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.upsample_factor = upsample_factor
self.conv = None
if use_conv:
kernel_size = 3 if kernel_size is None else kernel_size
self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
def forward(self, hidden_states):
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# interpolate
B, C, T, H, W = hidden_states.shape
first_h, other_h = hidden_states.split((1, T - 1), dim=2)
if T > 1:
other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
first_h = F.interpolate(first_h.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest").unsqueeze(2)
hidden_states = torch.cat((first_h, other_h), dim=2) if T > 1 else first_h
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
if self.conv:
hidden_states = self.conv(hidden_states)
return hidden_states
class ResnetBlockCausal3D(nn.Module):
def __init__(self, in_channels, out_channels=None, dropout=0.0, groups=32, eps=1e-6, conv_shortcut_bias=True):
super().__init__()
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1)
self.dropout = nn.Dropout(dropout)
self.nonlinearity = nn.SiLU()
self.conv_shortcut = None
if in_channels != out_channels:
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, bias=conv_shortcut_bias)
def forward(self, input_tensor):
hidden_states = input_tensor
# conv1
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
# conv2
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
# shortcut
if self.conv_shortcut is not None:
input_tensor = (self.conv_shortcut(input_tensor))
# shortcut and scale
output_tensor = input_tensor + hidden_states
return output_tensor
def prepare_causal_attention_mask(n_frame, n_hw, dtype, device, batch_size=None):
seq_len = n_frame * n_hw
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
for i in range(seq_len):
i_frame = i // n_hw
mask[i, :(i_frame + 1) * n_hw] = 0
if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask
class Attention(nn.Module):
def __init__(self,
in_channels,
num_heads,
head_dim,
num_groups=32,
dropout=0.0,
eps=1e-6,
bias=True,
residual_connection=True):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.residual_connection = residual_connection
dim_inner = head_dim * num_heads
self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
self.to_q = nn.Linear(in_channels, dim_inner, bias=bias)
self.to_k = nn.Linear(in_channels, dim_inner, bias=bias)
self.to_v = nn.Linear(in_channels, dim_inner, bias=bias)
self.to_out = nn.Sequential(nn.Linear(dim_inner, in_channels, bias=bias), nn.Dropout(dropout))
def forward(self, input_tensor, attn_mask=None):
hidden_states = self.group_norm(input_tensor.transpose(1, 2)).transpose(1, 2)
batch_size = hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(hidden_states)
v = self.to_v(hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if attn_mask is not None:
attn_mask = attn_mask.view(batch_size, self.num_heads, -1, attn_mask.shape[-1])
hidden_states = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = self.to_out(hidden_states)
if self.residual_connection:
output_tensor = input_tensor + hidden_states
return output_tensor
class UNetMidBlockCausal3D(nn.Module):
def __init__(self, in_channels, dropout=0.0, num_layers=1, eps=1e-6, num_groups=32, attention_head_dim=None):
super().__init__()
resnets = [
ResnetBlockCausal3D(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
groups=num_groups,
eps=eps,
)
]
attentions = []
attention_head_dim = attention_head_dim or in_channels
for _ in range(num_layers):
attentions.append(
Attention(
in_channels,
num_heads=in_channels // attention_head_dim,
head_dim=attention_head_dim,
num_groups=num_groups,
dropout=dropout,
eps=eps,
bias=True,
residual_connection=True,
))
resnets.append(
ResnetBlockCausal3D(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
groups=num_groups,
eps=eps,
))
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states):
hidden_states = self.resnets[0](hidden_states)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
B, C, T, H, W = hidden_states.shape
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
attn_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
hidden_states = attn(hidden_states, attn_mask=attn_mask)
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
hidden_states = resnet(hidden_states)
return hidden_states
class UpDecoderBlockCausal3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
dropout=0.0,
num_layers=1,
eps=1e-6,
num_groups=32,
add_upsample=True,
upsample_scale_factor=(2, 2, 2),
):
super().__init__()
resnets = []
for i in range(num_layers):
cur_in_channel = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlockCausal3D(
in_channels=cur_in_channel,
out_channels=out_channels,
groups=num_groups,
dropout=dropout,
eps=eps,
))
self.resnets = nn.ModuleList(resnets)
self.upsamplers = None
if add_upsample:
self.upsamplers = nn.ModuleList([
UpsampleCausal3D(
out_channels,
use_conv=True,
out_channels=out_channels,
upsample_factor=upsample_scale_factor,
)
])
def forward(self, hidden_states):
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class DecoderCausal3D(nn.Module):
def __init__(
self,
in_channels=16,
out_channels=3,
eps=1e-6,
dropout=0.0,
block_out_channels=[128, 256, 512, 512],
layers_per_block=2,
num_groups=32,
time_compression_ratio=4,
spatial_compression_ratio=8,
gradient_checkpointing=False,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
self.up_blocks = nn.ModuleList([])
# mid
self.mid_block = UNetMidBlockCausal3D(
in_channels=block_out_channels[-1],
dropout=dropout,
eps=eps,
num_groups=num_groups,
attention_head_dim=block_out_channels[-1],
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i in range(len(block_out_channels)):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
num_time_upsample_layers = int(np.log2(time_compression_ratio))
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
up_block = UpDecoderBlockCausal3D(
in_channels=prev_output_channel,
out_channels=output_channel,
dropout=dropout,
num_layers=layers_per_block + 1,
eps=eps,
num_groups=num_groups,
add_upsample=bool(add_spatial_upsample or add_time_upsample),
upsample_scale_factor=upsample_scale_factor,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups, eps=eps)
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
self.gradient_checkpointing = gradient_checkpointing
def forward(self, hidden_states):
hidden_states = self.conv_in(hidden_states)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# middle
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
hidden_states,
use_reentrant=False,
)
# up
for up_block in self.up_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
hidden_states,
use_reentrant=False,
)
else:
# middle
hidden_states = self.mid_block(hidden_states)
# up
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states)
# post-process
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class HunyuanVideoVAEDecoder(nn.Module):
def __init__(
self,
in_channels=16,
out_channels=3,
eps=1e-6,
dropout=0.0,
block_out_channels=[128, 256, 512, 512],
layers_per_block=2,
num_groups=32,
time_compression_ratio=4,
spatial_compression_ratio=8,
gradient_checkpointing=False,
):
super().__init__()
self.decoder = DecoderCausal3D(
in_channels=in_channels,
out_channels=out_channels,
eps=eps,
dropout=dropout,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
num_groups=num_groups,
time_compression_ratio=time_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
gradient_checkpointing=gradient_checkpointing,
)
self.post_quant_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.scaling_factor = 0.476986
def forward(self, latents):
latents = latents / self.scaling_factor
latents = self.post_quant_conv(latents)
dec = self.decoder(latents)
return dec
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
if not right_bound:
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
return x
def build_mask(self, data, is_bound, border_width):
_, _, T, H, W = data.shape
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
mask = torch.stack([t, h, w]).min(dim=0).values
mask = rearrange(mask, "T H W -> 1 1 T H W")
return mask
def tile_forward(self, hidden_states, tile_size, tile_stride):
B, C, T, H, W = hidden_states.shape
size_t, size_h, size_w = tile_size
stride_t, stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for t in range(0, T, stride_t):
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
t_, h_, w_ = t + size_t, h + size_h, w + size_w
tasks.append((t, t_, h, h_, w, w_))
# Run
torch_dtype = self.post_quant_conv.weight.dtype
data_device = hidden_states.device
computation_device = self.post_quant_conv.weight.device
weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
if t > 0:
hidden_states_batch = hidden_states_batch[:, :, 1:]
mask = self.build_mask(
hidden_states_batch,
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
border_width=((size_t - stride_t) * 4, (size_h - stride_h) * 8, (size_w - stride_w) * 8)
).to(dtype=torch_dtype, device=data_device)
target_t = 0 if t==0 else t * 4 + 1
target_h = h * 8
target_w = w * 8
values[
:,
:,
target_t: target_t + hidden_states_batch.shape[2],
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
target_t: target_t + hidden_states_batch.shape[2],
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
return values / weight
def decode_video(self, latents, tile_size=(17, 32, 32), tile_stride=(12, 24, 24)):
latents = latents.to(self.post_quant_conv.weight.dtype)
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
@staticmethod
def state_dict_converter():
return HunyuanVideoVAEDecoderStateDictConverter()
class HunyuanVideoVAEDecoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith('decoder.') or name.startswith('post_quant_conv.'):
state_dict_[name] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,307 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np
from tqdm import tqdm
from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D
class DownsampleCausal3D(nn.Module):
def __init__(self, channels, out_channels, kernel_size=3, bias=True, stride=2):
super().__init__()
self.conv = CausalConv3d(channels, out_channels, kernel_size, stride=stride, bias=bias)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
return hidden_states
class DownEncoderBlockCausal3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
dropout=0.0,
num_layers=1,
eps=1e-6,
num_groups=32,
add_downsample=True,
downsample_stride=2,
):
super().__init__()
resnets = []
for i in range(num_layers):
cur_in_channel = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlockCausal3D(
in_channels=cur_in_channel,
out_channels=out_channels,
groups=num_groups,
dropout=dropout,
eps=eps,
))
self.resnets = nn.ModuleList(resnets)
self.downsamplers = None
if add_downsample:
self.downsamplers = nn.ModuleList([DownsampleCausal3D(
out_channels,
out_channels,
stride=downsample_stride,
)])
def forward(self, hidden_states):
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class EncoderCausal3D(nn.Module):
def __init__(
self,
in_channels: int = 3,
out_channels: int = 16,
eps=1e-6,
dropout=0.0,
block_out_channels=[128, 256, 512, 512],
layers_per_block=2,
num_groups=32,
time_compression_ratio: int = 4,
spatial_compression_ratio: int = 8,
gradient_checkpointing=False,
):
super().__init__()
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
self.down_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
num_time_downsample_layers = int(np.log2(time_compression_ratio))
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
downsample_stride_T = (2,) if add_time_downsample else (1,)
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
down_block = DownEncoderBlockCausal3D(
in_channels=input_channel,
out_channels=output_channel,
dropout=dropout,
num_layers=layers_per_block,
eps=eps,
num_groups=num_groups,
add_downsample=bool(add_spatial_downsample or add_time_downsample),
downsample_stride=downsample_stride,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlockCausal3D(
in_channels=block_out_channels[-1],
dropout=dropout,
eps=eps,
num_groups=num_groups,
attention_head_dim=block_out_channels[-1],
)
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups, eps=eps)
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3)
self.gradient_checkpointing = gradient_checkpointing
def forward(self, hidden_states):
hidden_states = self.conv_in(hidden_states)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# down
for down_block in self.down_blocks:
torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block),
hidden_states,
use_reentrant=False,
)
# middle
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
hidden_states,
use_reentrant=False,
)
else:
# down
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
# middle
hidden_states = self.mid_block(hidden_states)
# post-process
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class HunyuanVideoVAEEncoder(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=16,
eps=1e-6,
dropout=0.0,
block_out_channels=[128, 256, 512, 512],
layers_per_block=2,
num_groups=32,
time_compression_ratio=4,
spatial_compression_ratio=8,
gradient_checkpointing=False,
):
super().__init__()
self.encoder = EncoderCausal3D(
in_channels=in_channels,
out_channels=out_channels,
eps=eps,
dropout=dropout,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
num_groups=num_groups,
time_compression_ratio=time_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
gradient_checkpointing=gradient_checkpointing,
)
self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
self.scaling_factor = 0.476986
def forward(self, images):
latents = self.encoder(images)
latents = self.quant_conv(latents)
latents = latents[:, :16]
latents = latents * self.scaling_factor
return latents
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
if not right_bound:
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
return x
def build_mask(self, data, is_bound, border_width):
_, _, T, H, W = data.shape
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
mask = torch.stack([t, h, w]).min(dim=0).values
mask = rearrange(mask, "T H W -> 1 1 T H W")
return mask
def tile_forward(self, hidden_states, tile_size, tile_stride):
B, C, T, H, W = hidden_states.shape
size_t, size_h, size_w = tile_size
stride_t, stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for t in range(0, T, stride_t):
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
t_, h_, w_ = t + size_t, h + size_h, w + size_w
tasks.append((t, t_, h, h_, w, w_))
# Run
torch_dtype = self.quant_conv.weight.dtype
data_device = hidden_states.device
computation_device = self.quant_conv.weight.device
weight = torch.zeros((1, 1, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
values = torch.zeros((B, 16, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
if t > 0:
hidden_states_batch = hidden_states_batch[:, :, 1:]
mask = self.build_mask(
hidden_states_batch,
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
border_width=((size_t - stride_t) // 4, (size_h - stride_h) // 8, (size_w - stride_w) // 8)
).to(dtype=torch_dtype, device=data_device)
target_t = 0 if t==0 else t // 4 + 1
target_h = h // 8
target_w = w // 8
values[
:,
:,
target_t: target_t + hidden_states_batch.shape[2],
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
target_t: target_t + hidden_states_batch.shape[2],
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
return values / weight
def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=(48, 192, 192)):
latents = latents.to(self.quant_conv.weight.dtype)
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
@staticmethod
def state_dict_converter():
return HunyuanVideoVAEEncoderStateDictConverter()
class HunyuanVideoVAEEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith('encoder.') or name.startswith('quant_conv.'):
state_dict_[name] = state_dict[name]
return state_dict_

File diff suppressed because one or more lines are too long

View File

@@ -4,7 +4,10 @@ from .sdxl_unet import SDXLUNet
from .sd_text_encoder import SDTextEncoder from .sd_text_encoder import SDTextEncoder
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2 from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sd3_dit import SD3DiT from .sd3_dit import SD3DiT
from .flux_dit import FluxDiT
from .hunyuan_dit import HunyuanDiT from .hunyuan_dit import HunyuanDiT
from .cog_dit import CogDiT
from .hunyuan_video_dit import HunyuanVideoDiT
@@ -17,6 +20,13 @@ class LoRAFromCivitai:
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0): def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
for key in state_dict:
if ".lora_up" in key:
return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "") renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
state_dict_ = {} state_dict_ = {}
for key in state_dict: for key in state_dict:
@@ -39,6 +49,29 @@ class LoRAFromCivitai:
return state_dict_ return state_dict_
def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
state_dict_ = {}
for key in state_dict:
if ".lora_B." not in key:
continue
if not key.startswith(lora_prefix):
continue
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
keys = key.split(".")
keys.pop(keys.index("lora_B"))
target_name = ".".join(keys)
target_name = target_name[len(lora_prefix):]
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None): def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
state_dict_model = model.state_dict() state_dict_model = model.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha) state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
@@ -46,11 +79,19 @@ class LoRAFromCivitai:
state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora) state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
elif model_resource == "civitai": elif model_resource == "civitai":
state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora) state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
if isinstance(state_dict_lora, tuple):
state_dict_lora = state_dict_lora[0]
if len(state_dict_lora) > 0: if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.") print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora: for name in state_dict_lora:
fp8=False
if state_dict_model[name].dtype == torch.float8_e4m3fn:
state_dict_model[name]= state_dict_model[name].to(state_dict_lora[name].dtype)
fp8=True
state_dict_model[name] += state_dict_lora[name].to( state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device) dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
if fp8:
state_dict_model[name] = state_dict_model[name].to(torch.float8_e4m3fn)
model.load_state_dict(state_dict_model) model.load_state_dict(state_dict_model)
@@ -65,6 +106,8 @@ class LoRAFromCivitai:
converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \ converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
else model.__class__.state_dict_converter().from_civitai else model.__class__.state_dict_converter().from_civitai
state_dict_lora_ = converter_fn(state_dict_lora_) state_dict_lora_ = converter_fn(state_dict_lora_)
if isinstance(state_dict_lora_, tuple):
state_dict_lora_ = state_dict_lora_[0]
if len(state_dict_lora_) == 0: if len(state_dict_lora_) == 0:
continue continue
for name in state_dict_lora_: for name in state_dict_lora_:
@@ -134,13 +177,39 @@ class SDXLLoRAFromCivitai(LoRAFromCivitai):
} }
class FluxLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [FluxDiT, FluxDiT]
self.lora_prefix = ["lora_unet_", "transformer."]
self.renamed_lora_prefix = {}
self.special_keys = {
"single.blocks": "single_blocks",
"double.blocks": "double_blocks",
"img.attn": "img_attn",
"img.mlp": "img_mlp",
"img.mod": "img_mod",
"txt.attn": "txt_attn",
"txt.mlp": "txt_mlp",
"txt.mod": "txt_mod",
}
class GeneralLoRAFromPeft: class GeneralLoRAFromPeft:
def __init__(self): def __init__(self):
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT] self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT]
def convert_state_dict(self, state_dict, alpha=1.0, device="cuda", torch_dtype=torch.float16): def fetch_device_dtype_from_state_dict(self, state_dict):
device, torch_dtype = None, None
for name, param in state_dict.items():
device, torch_dtype = param.device, param.dtype
break
return device, torch_dtype
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
state_dict_ = {} state_dict_ = {}
for key in state_dict: for key in state_dict:
if ".lora_B." not in key: if ".lora_B." not in key:
@@ -154,25 +223,26 @@ class GeneralLoRAFromPeft:
else: else:
lora_weight = alpha * torch.mm(weight_up, weight_down) lora_weight = alpha * torch.mm(weight_up, weight_down)
keys = key.split(".") keys = key.split(".")
if len(keys) > keys.index("lora_B") + 2:
keys.pop(keys.index("lora_B") + 1) keys.pop(keys.index("lora_B") + 1)
keys.pop(keys.index("lora_B")) keys.pop(keys.index("lora_B"))
target_name = ".".join(keys) target_name = ".".join(keys)
if target_name not in target_state_dict:
return {}
state_dict_[target_name] = lora_weight.cpu() state_dict_[target_name] = lora_weight.cpu()
return state_dict_ return state_dict_
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""): def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
state_dict_model = model.state_dict() state_dict_model = model.state_dict()
for name, param in state_dict_model.items(): state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
torch_dtype = param.dtype
device = param.device
break
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, device=device, torch_dtype=torch_dtype)
if len(state_dict_lora) > 0: if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.") print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora: for name in state_dict_lora:
state_dict_model[name] += state_dict_lora[name].to( state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device) dtype=state_dict_model[name].dtype,
device=state_dict_model[name].device
)
model.load_state_dict(state_dict_model) model.load_state_dict(state_dict_model)
@@ -182,14 +252,116 @@ class GeneralLoRAFromPeft:
continue continue
state_dict_model = model.state_dict() state_dict_model = model.state_dict()
try: try:
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0) state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
if len(state_dict_lora_) == 0: if len(state_dict_lora_) > 0:
continue
for name in state_dict_lora_:
if name not in state_dict_model:
break
else:
return "", "" return "", ""
except: except:
pass pass
return None return None
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [HunyuanVideoDiT, HunyuanVideoDiT]
self.lora_prefix = ["diffusion_model.", "transformer."]
self.special_keys = {}
class FluxLoRAConverter:
def __init__(self):
pass
@staticmethod
def align_to_opensource_format(state_dict, alpha=1.0):
prefix_rename_dict = {
"single_blocks": "lora_unet_single_blocks",
"blocks": "lora_unet_double_blocks",
}
middle_rename_dict = {
"norm.linear": "modulation_lin",
"to_qkv_mlp": "linear1",
"proj_out": "linear2",
"norm1_a.linear": "img_mod_lin",
"norm1_b.linear": "txt_mod_lin",
"attn.a_to_qkv": "img_attn_qkv",
"attn.b_to_qkv": "txt_attn_qkv",
"attn.a_to_out": "img_attn_proj",
"attn.b_to_out": "txt_attn_proj",
"ff_a.0": "img_mlp_0",
"ff_a.2": "img_mlp_2",
"ff_b.0": "txt_mlp_0",
"ff_b.2": "txt_mlp_2",
}
suffix_rename_dict = {
"lora_B.weight": "lora_up.weight",
"lora_A.weight": "lora_down.weight",
}
state_dict_ = {}
for name, param in state_dict.items():
names = name.split(".")
if names[-2] != "lora_A" and names[-2] != "lora_B":
names.pop(-2)
prefix = names[0]
middle = ".".join(names[2:-2])
suffix = ".".join(names[-2:])
block_id = names[1]
if middle not in middle_rename_dict:
continue
rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix]
state_dict_[rename] = param
if rename.endswith("lora_up.weight"):
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
return state_dict_
@staticmethod
def align_to_diffsynth_format(state_dict):
rename_dict = {
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
}
def guess_block_id(name):
names = name.split("_")
for i in names:
if i.isdigit():
return i, name.replace(f"_{i}_", "_blockid_")
return None, None
state_dict_ = {}
for name, param in state_dict.items():
block_id, source_name = guess_block_id(name)
if source_name in rename_dict:
target_name = rename_dict[source_name]
target_name = target_name.replace(".blockid.", f".{block_id}.")
state_dict_[target_name] = param
else:
state_dict_[name] = param
return state_dict_
def get_lora_loaders():
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]

View File

@@ -1,16 +1,13 @@
import os, torch, hashlib, json, importlib import os, torch, json, importlib
from safetensors import safe_open
from torch import Tensor
from typing_extensions import Literal, TypeAlias
from typing import List from typing import List
from .downloader import download_models, Preset_model_id, Preset_model_website from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
from .sd_text_encoder import SDTextEncoder from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder from .sd_vae_decoder import SDVAEDecoder
from .lora import SDLoRAFromCivitai, SDXLLoRAFromCivitai, GeneralLoRAFromPeft from .lora import get_lora_loaders
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2 from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet from .sdxl_unet import SDXLUNet
@@ -23,6 +20,7 @@ from .sd3_vae_decoder import SD3VAEDecoder
from .sd3_vae_encoder import SD3VAEEncoder from .sd3_vae_encoder import SD3VAEEncoder
from .sd_controlnet import SDControlNet from .sd_controlnet import SDControlNet
from .sdxl_controlnet import SDXLControlNetUnion
from .sd_motion import SDMotionModel from .sd_motion import SDMotionModel
from .sdxl_motion import SDXLMotionModel from .sdxl_motion import SDXLMotionModel
@@ -37,129 +35,22 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT from .hunyuan_dit import HunyuanDiT
from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
from .hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
from .flux_dit import FluxDiT
from .flux_text_encoder import FluxTextEncoder2
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
from .flux_ipadapter import FluxIpAdapter
from .cog_vae import CogVAEEncoder, CogVAEDecoder
from .cog_dit import CogDiT
from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu")
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def search_for_embeddings(state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += search_for_embeddings(state_dict[k])
return embeddings
def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-6:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
return name
return None
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
matched_keys = set()
with torch.no_grad():
for name in source_state_dict:
rename = search_parameter(source_state_dict[name], target_state_dict)
if rename is not None:
print(f'"{name}": "{rename}",')
matched_keys.add(rename)
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
length = source_state_dict[name].shape[0] // 3
rename = []
for i in range(3):
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
if None not in rename:
print(f'"{name}": {rename},')
for rename_ in rename:
matched_keys.add(rename_)
for name in target_state_dict:
if name not in matched_keys:
print("Cannot find", name, target_state_dict[name].shape)
def search_for_files(folder, extensions):
files = []
if os.path.isdir(folder):
for file in sorted(os.listdir(folder)):
files += search_for_files(os.path.join(folder, file), extensions)
elif os.path.isfile(folder):
for extension in extensions:
if folder.endswith(extension):
files.append(folder)
break
return files
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device): def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
@@ -177,8 +68,10 @@ def load_model_from_single_file(state_dict, model_names, model_classes, model_re
else: else:
model_state_dict, extra_kwargs = state_dict_results, {} model_state_dict, extra_kwargs = state_dict_results, {}
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
model = model_class(**extra_kwargs).to(dtype=torch_dtype, device=device) with init_weights_on_device():
model.load_state_dict(model_state_dict) model= model_class(**extra_kwargs)
model.load_state_dict(model_state_dict, assign=True)
model = model.to(dtype=torch_dtype, device=device)
loaded_model_names.append(model_name) loaded_model_names.append(model_name)
loaded_models.append(model) loaded_models.append(model)
return loaded_model_names, loaded_models return loaded_model_names, loaded_models
@@ -187,10 +80,16 @@ def load_model_from_single_file(state_dict, model_names, model_classes, model_re
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device): def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
loaded_model_names, loaded_models = [], [] loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes): for model_name, model_class in zip(model_names, model_classes):
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval() model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
else:
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
if torch_dtype == torch.float16 and hasattr(model, "half"): if torch_dtype == torch.float16 and hasattr(model, "half"):
model = model.half() model = model.half()
try:
model = model.to(device=device) model = model.to(device=device)
except:
pass
loaded_model_names.append(model_name) loaded_model_names.append(model_name)
loaded_models.append(model) loaded_models.append(model)
return loaded_model_names, loaded_models return loaded_model_names, loaded_models
@@ -259,7 +158,7 @@ class ModelDetectorFromSingleFile:
def match(self, file_path="", state_dict={}): def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path): if isinstance(file_path, str) and os.path.isdir(file_path):
return False return False
if len(state_dict) == 0: if len(state_dict) == 0:
state_dict = load_state_dict(file_path) state_dict = load_state_dict(file_path)
@@ -301,7 +200,7 @@ class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
def match(self, file_path="", state_dict={}): def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path): if isinstance(file_path, str) and os.path.isdir(file_path):
return False return False
if len(state_dict) == 0: if len(state_dict) == 0:
state_dict = load_state_dict(file_path) state_dict = load_state_dict(file_path)
@@ -339,19 +238,19 @@ class ModelDetectorFromHuggingfaceFolder:
self.add_model_metadata(*metadata) self.add_model_metadata(*metadata)
def add_model_metadata(self, architecture, huggingface_lib, model_name): def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
self.architecture_dict[architecture] = (huggingface_lib, model_name) self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
def match(self, file_path="", state_dict={}): def match(self, file_path="", state_dict={}):
if os.path.isfile(file_path): if not isinstance(file_path, str) or os.path.isfile(file_path):
return False return False
file_list = os.listdir(file_path) file_list = os.listdir(file_path)
if "config.json" not in file_list: if "config.json" not in file_list:
return False return False
with open(os.path.join(file_path, "config.json"), "r") as f: with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f) config = json.load(f)
if "architectures" not in config: if "architectures" not in config and "_class_name" not in config:
return False return False
return True return True
@@ -360,8 +259,11 @@ class ModelDetectorFromHuggingfaceFolder:
with open(os.path.join(file_path, "config.json"), "r") as f: with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f) config = json.load(f)
loaded_model_names, loaded_models = [], [] loaded_model_names, loaded_models = [], []
for architecture in config["architectures"]: architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
huggingface_lib, model_name = self.architecture_dict[architecture] for architecture in architectures:
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
if redirected_architecture is not None:
architecture = redirected_architecture
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture) model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device) loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
loaded_model_names += loaded_model_names_ loaded_model_names += loaded_model_names_
@@ -382,7 +284,7 @@ class ModelDetectorFromPatchedSingleFile:
def match(self, file_path="", state_dict={}): def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path): if not isinstance(file_path, str) or os.path.isdir(file_path):
return False return False
if len(state_dict) == 0: if len(state_dict) == 0:
state_dict = load_state_dict(file_path) state_dict = load_state_dict(file_path)
@@ -467,11 +369,15 @@ class ModelManager:
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0): def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
if isinstance(file_path, list):
for file_path_ in file_path:
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
else:
print(f"Loading LoRA models from file: {file_path}") print(f"Loading LoRA models from file: {file_path}")
if len(state_dict) == 0: if len(state_dict) == 0:
state_dict = load_state_dict(file_path) state_dict = load_state_dict(file_path)
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path): for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
for lora in [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft()]: for lora in get_lora_loaders():
match_results = lora.match(model, state_dict) match_results = lora.match(model, state_dict)
if match_results is not None: if match_results is not None:
print(f" Adding LoRA to {model_name} ({model_path}).") print(f" Adding LoRA to {model_name} ({model_path}).")
@@ -480,9 +386,15 @@ class ModelManager:
break break
def load_model(self, file_path, model_names=None): def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
print(f"Loading models from: {file_path}") print(f"Loading models from: {file_path}")
if os.path.isfile(file_path): if device is None: device = self.device
if torch_dtype is None: torch_dtype = self.torch_dtype
if isinstance(file_path, list):
state_dict = {}
for path in file_path:
state_dict.update(load_state_dict(path))
elif os.path.isfile(file_path):
state_dict = load_state_dict(file_path) state_dict = load_state_dict(file_path)
else: else:
state_dict = None state_dict = None
@@ -490,7 +402,7 @@ class ModelManager:
if model_detector.match(file_path, state_dict): if model_detector.match(file_path, state_dict):
model_names, models = model_detector.load( model_names, models = model_detector.load(
file_path, state_dict, file_path, state_dict,
device=self.device, torch_dtype=self.torch_dtype, device=device, torch_dtype=torch_dtype,
allowed_model_names=model_names, model_manager=self allowed_model_names=model_names, model_manager=self
) )
for model_name, model in zip(model_names, models): for model_name, model in zip(model_names, models):
@@ -503,9 +415,9 @@ class ModelManager:
print(f" We cannot detect the model type. No models are loaded.") print(f" We cannot detect the model type. No models are loaded.")
def load_models(self, file_path_list, model_names=None): def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
for file_path in file_path_list: for file_path in file_path_list:
self.load_model(file_path, model_names) self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
def fetch_model(self, model_name, file_path=None, require_model_path=False): def fetch_model(self, model_name, file_path=None, require_model_path=False):

803
diffsynth/models/omnigen.py Normal file
View File

@@ -0,0 +1,803 @@
# The code is revised from DiT
import os
import torch
import torch.nn as nn
import numpy as np
import math
from safetensors.torch import load_file
from typing import List, Optional, Tuple, Union
import torch.utils.checkpoint
from huggingface_hub import snapshot_download
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers import Phi3Config, Phi3Model
from transformers.cache_utils import Cache, DynamicCache
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Phi3Transformer(Phi3Model):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
We only modified the attention mask
Args:
config: Phi3Config
"""
def prefetch_layer(self, layer_idx: int, device: torch.device):
"Starts prefetching the next layer cache"
with torch.cuda.stream(self.prefetch_stream):
# Prefetch next layer tensors to GPU
for name, param in self.layers[layer_idx].named_parameters():
param.data = param.data.to(device, non_blocking=True)
def evict_previous_layer(self, layer_idx: int):
"Moves the previous layer cache to the CPU"
prev_layer_idx = layer_idx - 1
for name, param in self.layers[prev_layer_idx].named_parameters():
param.data = param.data.to("cpu", non_blocking=True)
def get_offlaod_layer(self, layer_idx: int, device: torch.device):
# init stream
if not hasattr(self, "prefetch_stream"):
self.prefetch_stream = torch.cuda.Stream()
# delete previous layer
torch.cuda.current_stream().synchronize()
self.evict_previous_layer(layer_idx)
# make sure the current layer is ready
torch.cuda.synchronize(self.prefetch_stream)
# load next layer
self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
offload_model: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)
# if inputs_embeds is None:
# inputs_embeds = self.embed_tokens(input_ids)
# if cache_position is None:
# past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
# cache_position = torch.arange(
# past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
# )
# if position_ids is None:
# position_ids = cache_position.unsqueeze(0)
if attention_mask is not None and attention_mask.dim() == 3:
dtype = inputs_embeds.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = (1 - attention_mask) * min_dtype
attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
else:
raise Exception("attention_mask parameter was unavailable or invalid")
# causal_mask = self._update_causal_mask(
# attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
# )
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
layer_idx = -1
for decoder_layer in self.layers:
layer_idx += 1
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
if offload_model and not self.training:
self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
print('************')
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t, dtype=torch.float32):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
"""
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, int):
grid_size = (grid_size, grid_size)
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class PatchEmbedMR(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(
self,
patch_size: int = 2,
in_chans: int = 4,
embed_dim: int = 768,
bias: bool = True,
):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
def forward(self, x):
x = self.proj(x)
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
return x
class OmniGenOriginalModel(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
transformer_config: Phi3Config,
patch_size=2,
in_channels=4,
pe_interpolation: float = 1.0,
pos_embed_max_size: int = 192,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.pos_embed_max_size = pos_embed_max_size
hidden_size = transformer_config.hidden_size
self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
self.time_token = TimestepEmbedder(hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
self.pe_interpolation = pe_interpolation
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
self.initialize_weights()
self.llm = Phi3Transformer(config=transformer_config)
self.llm.config.use_cache = False
@classmethod
def from_pretrained(cls, model_name):
if not os.path.exists(model_name):
cache_folder = os.getenv('HF_HUB_CACHE')
model_name = snapshot_download(repo_id=model_name,
cache_dir=cache_folder,
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
config = Phi3Config.from_pretrained(model_name)
model = cls(config)
if os.path.exists(os.path.join(model_name, 'model.safetensors')):
print("Loading safetensors")
ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
else:
ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
model.load_state_dict(ckpt)
return model
def initialize_weights(self):
assert not hasattr(self, "llama")
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
w = self.input_x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h, w))
return imgs
def cropped_pos_embed(self, height, width):
"""Crops positional embeddings for SD3 compatibility."""
if self.pos_embed_max_size is None:
raise ValueError("`pos_embed_max_size` must be set for cropping.")
height = height // self.patch_size
width = width // self.patch_size
if height > self.pos_embed_max_size:
raise ValueError(
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
)
if width > self.pos_embed_max_size:
raise ValueError(
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
)
top = (self.pos_embed_max_size - height) // 2
left = (self.pos_embed_max_size - width) // 2
spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
# print(top, top + height, left, left + width, spatial_pos_embed.size())
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
return spatial_pos_embed
def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
if isinstance(latents, list):
return_list = False
if padding_latent is None:
padding_latent = [None] * len(latents)
return_list = True
patched_latents, num_tokens, shapes = [], [], []
for latent, padding in zip(latents, padding_latent):
height, width = latent.shape[-2:]
if is_input_images:
latent = self.input_x_embedder(latent)
else:
latent = self.x_embedder(latent)
pos_embed = self.cropped_pos_embed(height, width)
latent = latent + pos_embed
if padding is not None:
latent = torch.cat([latent, padding], dim=-2)
patched_latents.append(latent)
num_tokens.append(pos_embed.size(1))
shapes.append([height, width])
if not return_list:
latents = torch.cat(patched_latents, dim=0)
else:
latents = patched_latents
else:
height, width = latents.shape[-2:]
if is_input_images:
latents = self.input_x_embedder(latents)
else:
latents = self.x_embedder(latents)
pos_embed = self.cropped_pos_embed(height, width)
latents = latents + pos_embed
num_tokens = latents.size(1)
shapes = [height, width]
return latents, num_tokens, shapes
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
"""
"""
input_is_list = isinstance(x, list)
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
if input_img_latents is not None:
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
if input_ids is not None:
condition_embeds = self.llm.embed_tokens(input_ids).clone()
input_img_inx = 0
for b_inx in input_image_sizes.keys():
for start_inx, end_inx in input_image_sizes[b_inx]:
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
input_img_inx += 1
if input_img_latents is not None:
assert input_img_inx == len(input_latents)
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
else:
input_emb = torch.cat([time_token, x], dim=1)
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
output, past_key_values = output.last_hidden_state, output.past_key_values
if input_is_list:
image_embedding = output[:, -max(num_tokens):]
time_emb = self.t_embedder(timestep, dtype=x.dtype)
x = self.final_layer(image_embedding, time_emb)
latents = []
for i in range(x.size(0)):
latent = x[i:i+1, :num_tokens[i]]
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
latents.append(latent)
else:
image_embedding = output[:, -num_tokens:]
time_emb = self.t_embedder(timestep, dtype=x.dtype)
x = self.final_layer(image_embedding, time_emb)
latents = self.unpatchify(x, shapes[0], shapes[1])
if return_past_key_values:
return latents, past_key_values
return latents
@torch.no_grad()
def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
self.llm.config.use_cache = use_kv_cache
model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True, offload_model=offload_model)
if use_img_cfg:
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
model_out = [cond, cond, cond]
else:
cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
cond = uncond + cfg_scale * (cond - uncond)
model_out = [cond, cond]
return torch.cat(model_out, dim=0), past_key_values
@torch.no_grad()
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
self.llm.config.use_cache = use_kv_cache
if past_key_values is None:
past_key_values = [None] * len(attention_mask)
x = torch.split(x, len(x) // len(attention_mask), dim=0)
timestep = timestep.to(x[0].dtype)
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
model_out, pask_key_values = [], []
for i in range(len(input_ids)):
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
model_out.append(temp_out)
pask_key_values.append(temp_pask_key_values)
if len(model_out) == 3:
cond, uncond, img_cond = model_out
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
model_out = [cond, cond, cond]
elif len(model_out) == 2:
cond, uncond = model_out
cond = uncond + cfg_scale * (cond - uncond)
model_out = [cond, cond]
else:
return model_out[0]
return torch.cat(model_out, dim=0), pask_key_values
class OmniGenTransformer(OmniGenOriginalModel):
def __init__(self):
config = {
"_name_or_path": "Phi-3-vision-128k-instruct",
"architectures": [
"Phi3ForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 3072,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 131072,
"model_type": "phi3",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 32,
"original_max_position_embeddings": 4096,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"long_factor": [
1.0299999713897705,
1.0499999523162842,
1.0499999523162842,
1.0799999237060547,
1.2299998998641968,
1.2299998998641968,
1.2999999523162842,
1.4499999284744263,
1.5999999046325684,
1.6499998569488525,
1.8999998569488525,
2.859999895095825,
3.68999981880188,
5.419999599456787,
5.489999771118164,
5.489999771118164,
9.09000015258789,
11.579999923706055,
15.65999984741211,
15.769999504089355,
15.789999961853027,
18.360000610351562,
21.989999771118164,
23.079999923706055,
30.009998321533203,
32.35000228881836,
32.590003967285156,
35.56000518798828,
39.95000457763672,
53.840003967285156,
56.20000457763672,
57.95000457763672,
59.29000473022461,
59.77000427246094,
59.920005798339844,
61.190006256103516,
61.96000671386719,
62.50000762939453,
63.3700065612793,
63.48000717163086,
63.48000717163086,
63.66000747680664,
63.850006103515625,
64.08000946044922,
64.760009765625,
64.80001068115234,
64.81001281738281,
64.81001281738281
],
"short_factor": [
1.05,
1.05,
1.05,
1.1,
1.1,
1.1,
1.2500000000000002,
1.2500000000000002,
1.4000000000000004,
1.4500000000000004,
1.5500000000000005,
1.8500000000000008,
1.9000000000000008,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.1000000000000005,
2.1000000000000005,
2.2,
2.3499999999999996,
2.3499999999999996,
2.3499999999999996,
2.3499999999999996,
2.3999999999999995,
2.3999999999999995,
2.6499999999999986,
2.6999999999999984,
2.8999999999999977,
2.9499999999999975,
3.049999999999997,
3.049999999999997,
3.049999999999997
],
"type": "su"
},
"rope_theta": 10000.0,
"sliding_window": 131072,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"transformers_version": "4.38.1",
"use_cache": True,
"vocab_size": 32064,
"_attn_implementation": "sdpa"
}
config = Phi3Config(**config)
super().__init__(config)
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
input_is_list = isinstance(x, list)
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
if input_img_latents is not None:
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
if input_ids is not None:
condition_embeds = self.llm.embed_tokens(input_ids).clone()
input_img_inx = 0
for b_inx in input_image_sizes.keys():
for start_inx, end_inx in input_image_sizes[b_inx]:
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
input_img_inx += 1
if input_img_latents is not None:
assert input_img_inx == len(input_latents)
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
else:
input_emb = torch.cat([time_token, x], dim=1)
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
output, past_key_values = output.last_hidden_state, output.past_key_values
if input_is_list:
image_embedding = output[:, -max(num_tokens):]
time_emb = self.t_embedder(timestep, dtype=x.dtype)
x = self.final_layer(image_embedding, time_emb)
latents = []
for i in range(x.size(0)):
latent = x[i:i+1, :num_tokens[i]]
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
latents.append(latent)
else:
image_embedding = output[:, -num_tokens:]
time_emb = self.t_embedder(timestep, dtype=x.dtype)
x = self.final_layer(image_embedding, time_emb)
latents = self.unpatchify(x, shapes[0], shapes[1])
if return_past_key_values:
return latents, past_key_values
return latents
@torch.no_grad()
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
self.llm.config.use_cache = use_kv_cache
if past_key_values is None:
past_key_values = [None] * len(attention_mask)
x = torch.split(x, len(x) // len(attention_mask), dim=0)
timestep = timestep.to(x[0].dtype)
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
model_out, pask_key_values = [], []
for i in range(len(input_ids)):
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
model_out.append(temp_out)
pask_key_values.append(temp_pask_key_values)
if len(model_out) == 3:
cond, uncond, img_cond = model_out
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
model_out = [cond, cond, cond]
elif len(model_out) == 2:
cond, uncond = model_out
cond = uncond + cfg_scale * (cond - uncond)
model_out = [cond, cond]
else:
return model_out[0]
return torch.cat(model_out, dim=0), pask_key_values
@staticmethod
def state_dict_converter():
return OmniGenTransformerStateDictConverter()
class OmniGenTransformerStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict

View File

@@ -5,6 +5,26 @@ from .tiler import TileWorker
class RMSNorm(torch.nn.Module):
def __init__(self, dim, eps, elementwise_affine=True):
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = torch.nn.Parameter(torch.ones((dim,)))
else:
self.weight = None
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = hidden_states.to(input_dtype)
if self.weight is not None:
hidden_states = hidden_states * self.weight
return hidden_states
class PatchEmbed(torch.nn.Module): class PatchEmbed(torch.nn.Module):
def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192): def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
super().__init__() super().__init__()
@@ -12,7 +32,7 @@ class PatchEmbed(torch.nn.Module):
self.patch_size = patch_size self.patch_size = patch_size
self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size) self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size)
self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, 1536)) self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, embed_dim))
def cropped_pos_embed(self, height, width): def cropped_pos_embed(self, height, width):
height = height // self.patch_size height = height // self.patch_size
@@ -32,9 +52,9 @@ class PatchEmbed(torch.nn.Module):
class TimestepEmbeddings(torch.nn.Module): class TimestepEmbeddings(torch.nn.Module):
def __init__(self, dim_in, dim_out): def __init__(self, dim_in, dim_out, computation_device=None):
super().__init__() super().__init__()
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0) self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
self.timestep_embedder = torch.nn.Sequential( self.timestep_embedder = torch.nn.Sequential(
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out) torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
) )
@@ -47,10 +67,11 @@ class TimestepEmbeddings(torch.nn.Module):
class AdaLayerNorm(torch.nn.Module): class AdaLayerNorm(torch.nn.Module):
def __init__(self, dim, single=False): def __init__(self, dim, single=False, dual=False):
super().__init__() super().__init__()
self.single = single self.single = single
self.linear = torch.nn.Linear(dim, dim * (2 if single else 6)) self.dual = dual
self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb): def forward(self, x, emb):
@@ -59,6 +80,12 @@ class AdaLayerNorm(torch.nn.Module):
scale, shift = emb.unsqueeze(1).chunk(2, dim=2) scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
x = self.norm(x) * (1 + scale) + shift x = self.norm(x) * (1 + scale) + shift
return x return x
elif self.dual:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
norm_x = self.norm(x)
x = norm_x * (1 + scale_msa) + shift_msa
norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
else: else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
x = self.norm(x) * (1 + scale_msa) + shift_msa x = self.norm(x) * (1 + scale_msa) + shift_msa
@@ -67,7 +94,7 @@ class AdaLayerNorm(torch.nn.Module):
class JointAttention(torch.nn.Module): class JointAttention(torch.nn.Module):
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False): def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False, use_rms_norm=False):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
self.head_dim = head_dim self.head_dim = head_dim
@@ -80,12 +107,38 @@ class JointAttention(torch.nn.Module):
if not only_out_a: if not only_out_a:
self.b_to_out = torch.nn.Linear(dim_b, dim_b) self.b_to_out = torch.nn.Linear(dim_b, dim_b)
if use_rms_norm:
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
else:
self.norm_q_a = None
self.norm_k_a = None
self.norm_q_b = None
self.norm_k_b = None
def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
batch_size = hidden_states.shape[0]
qkv = to_qkv(hidden_states)
qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q, k, v = qkv.chunk(3, dim=1)
if norm_q is not None:
q = norm_q(q)
if norm_k is not None:
k = norm_k(k)
return q, k, v
def forward(self, hidden_states_a, hidden_states_b): def forward(self, hidden_states_a, hidden_states_b):
batch_size = hidden_states_a.shape[0] batch_size = hidden_states_a.shape[0]
qkv = torch.concat([self.a_to_qkv(hidden_states_a), self.b_to_qkv(hidden_states_b)], dim=1) qa, ka, va = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)
qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) qb, kb, vb = self.process_qkv(hidden_states_b, self.b_to_qkv, self.norm_q_b, self.norm_k_b)
q, k, v = qkv.chunk(3, dim=1) q = torch.concat([qa, qb], dim=2)
k = torch.concat([ka, kb], dim=2)
v = torch.concat([va, vb], dim=2)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v) hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
@@ -100,13 +153,55 @@ class JointAttention(torch.nn.Module):
class JointTransformerBlock(torch.nn.Module): class SingleAttention(torch.nn.Module):
def __init__(self, dim, num_attention_heads): def __init__(self, dim_a, num_heads, head_dim, use_rms_norm=False):
super().__init__() super().__init__()
self.norm1_a = AdaLayerNorm(dim) self.num_heads = num_heads
self.head_dim = head_dim
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
if use_rms_norm:
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
else:
self.norm_q_a = None
self.norm_k_a = None
def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
batch_size = hidden_states.shape[0]
qkv = to_qkv(hidden_states)
qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q, k, v = qkv.chunk(3, dim=1)
if norm_q is not None:
q = norm_q(q)
if norm_k is not None:
k = norm_k(k)
return q, k, v
def forward(self, hidden_states_a):
batch_size = hidden_states_a.shape[0]
q, k, v = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.a_to_out(hidden_states)
return hidden_states
class DualTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads, use_rms_norm=False):
super().__init__()
self.norm1_a = AdaLayerNorm(dim, dual=True)
self.norm1_b = AdaLayerNorm(dim) self.norm1_b = AdaLayerNorm(dim)
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads) self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
self.attn2 = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_a = torch.nn.Sequential( self.ff_a = torch.nn.Sequential(
@@ -124,6 +219,56 @@ class JointTransformerBlock(torch.nn.Module):
def forward(self, hidden_states_a, hidden_states_b, temb): def forward(self, hidden_states_a, hidden_states_b, temb):
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
# Attention
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
# Part A
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
# Part B
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
return hidden_states_a, hidden_states_b
class JointTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads, use_rms_norm=False, dual=False):
super().__init__()
self.norm1_a = AdaLayerNorm(dim, dual=dual)
self.norm1_b = AdaLayerNorm(dim)
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
if dual:
self.attn2 = SingleAttention(dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_a = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_b = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
def forward(self, hidden_states_a, hidden_states_b, temb):
if self.norm1_a.dual:
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
else:
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb) norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb) norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
@@ -132,6 +277,8 @@ class JointTransformerBlock(torch.nn.Module):
# Part A # Part A
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
if self.norm1_a.dual:
hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a) hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
@@ -145,12 +292,12 @@ class JointTransformerBlock(torch.nn.Module):
class JointTransformerFinalBlock(torch.nn.Module): class JointTransformerFinalBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads): def __init__(self, dim, num_attention_heads, use_rms_norm=False):
super().__init__() super().__init__()
self.norm1_a = AdaLayerNorm(dim) self.norm1_a = AdaLayerNorm(dim)
self.norm1_b = AdaLayerNorm(dim, single=True) self.norm1_b = AdaLayerNorm(dim, single=True)
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True) self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True, use_rms_norm=use_rms_norm)
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_a = torch.nn.Sequential( self.ff_a = torch.nn.Sequential(
@@ -177,15 +324,17 @@ class JointTransformerFinalBlock(torch.nn.Module):
class SD3DiT(torch.nn.Module): class SD3DiT(torch.nn.Module):
def __init__(self): def __init__(self, embed_dim=1536, num_layers=24, use_rms_norm=False, num_dual_blocks=0, pos_embed_max_size=192):
super().__init__() super().__init__()
self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192) self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=embed_dim, pos_embed_max_size=pos_embed_max_size)
self.time_embedder = TimestepEmbeddings(256, 1536) self.time_embedder = TimestepEmbeddings(256, embed_dim)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, 1536), torch.nn.SiLU(), torch.nn.Linear(1536, 1536)) self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, embed_dim), torch.nn.SiLU(), torch.nn.Linear(embed_dim, embed_dim))
self.context_embedder = torch.nn.Linear(4096, 1536) self.context_embedder = torch.nn.Linear(4096, embed_dim)
self.blocks = torch.nn.ModuleList([JointTransformerBlock(1536, 24) for _ in range(23)] + [JointTransformerFinalBlock(1536, 24)]) self.blocks = torch.nn.ModuleList([JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm, dual=True) for _ in range(num_dual_blocks)]
self.norm_out = AdaLayerNorm(1536, single=True) + [JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm) for _ in range(num_layers-1-num_dual_blocks)]
self.proj_out = torch.nn.Linear(1536, 64) + [JointTransformerFinalBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm)])
self.norm_out = AdaLayerNorm(embed_dim, single=True)
self.proj_out = torch.nn.Linear(embed_dim, 64)
def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64): def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
# Due to the global positional embedding, we cannot implement layer-wise tiled forward. # Due to the global positional embedding, we cannot implement layer-wise tiled forward.
@@ -238,6 +387,24 @@ class SD3DiTStateDictConverter:
def __init__(self): def __init__(self):
pass pass
def infer_architecture(self, state_dict):
embed_dim = state_dict["blocks.0.ff_a.0.weight"].shape[1]
num_layers = 100
while num_layers > 0 and f"blocks.{num_layers-1}.ff_a.0.bias" not in state_dict:
num_layers -= 1
use_rms_norm = "blocks.0.attn.norm_q_a.weight" in state_dict
num_dual_blocks = 0
while f"blocks.{num_dual_blocks}.attn2.a_to_out.bias" in state_dict:
num_dual_blocks += 1
pos_embed_max_size = state_dict["pos_embedder.pos_embed"].shape[1]
return {
"embed_dim": embed_dim,
"num_layers": num_layers,
"use_rms_norm": use_rms_norm,
"num_dual_blocks": num_dual_blocks,
"pos_embed_max_size": pos_embed_max_size
}
def from_diffusers(self, state_dict): def from_diffusers(self, state_dict):
rename_dict = { rename_dict = {
"context_embedder": "context_embedder", "context_embedder": "context_embedder",
@@ -264,12 +431,17 @@ class SD3DiTStateDictConverter:
"ff.net.2": "ff_a.2", "ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0", "ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2", "ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
} }
state_dict_ = {} state_dict_ = {}
for name, param in state_dict.items(): for name, param in state_dict.items():
if name in rename_dict: if name in rename_dict:
if name == "pos_embed.pos_embed": if name == "pos_embed.pos_embed":
param = param.reshape((1, 192, 192, 1536)) param = param.reshape((1, 192, 192, param.shape[-1]))
state_dict_[rename_dict[name]] = param state_dict_[rename_dict[name]] = param
elif name.endswith(".weight") or name.endswith(".bias"): elif name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias" suffix = ".weight" if name.endswith(".weight") else ".bias"
@@ -283,7 +455,19 @@ class SD3DiTStateDictConverter:
if middle in rename_dict: if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]]) name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param state_dict_[name_] = param
return state_dict_ merged_keys = [name for name in state_dict_ if ".a_to_q." in name or ".b_to_q." in name]
for key in merged_keys:
param = torch.concat([
state_dict_[key.replace("to_q", "to_q")],
state_dict_[key.replace("to_q", "to_k")],
state_dict_[key.replace("to_q", "to_v")],
], dim=0)
name = key.replace("to_q", "to_qkv")
state_dict_.pop(key.replace("to_q", "to_q"))
state_dict_.pop(key.replace("to_q", "to_k"))
state_dict_.pop(key.replace("to_q", "to_v"))
state_dict_[name] = param
return state_dict_, self.infer_architecture(state_dict_)
def from_civitai(self, state_dict): def from_civitai(self, state_dict):
rename_dict = { rename_dict = {
@@ -291,478 +475,7 @@ class SD3DiTStateDictConverter:
"model.diffusion_model.context_embedder.weight": "context_embedder.weight", "model.diffusion_model.context_embedder.weight": "context_embedder.weight",
"model.diffusion_model.final_layer.linear.bias": "proj_out.bias", "model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
"model.diffusion_model.final_layer.linear.weight": "proj_out.weight", "model.diffusion_model.final_layer.linear.weight": "proj_out.weight",
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias": "blocks.0.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.weight": "blocks.0.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.bias": "blocks.0.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.weight": "blocks.0.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.bias": ['blocks.0.attn.b_to_q.bias', 'blocks.0.attn.b_to_k.bias', 'blocks.0.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.weight": ['blocks.0.attn.b_to_q.weight', 'blocks.0.attn.b_to_k.weight', 'blocks.0.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.bias": "blocks.0.ff_b.0.bias",
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.weight": "blocks.0.ff_b.0.weight",
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.bias": "blocks.0.ff_b.2.bias",
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.weight": "blocks.0.ff_b.2.weight",
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.bias": "blocks.0.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.weight": "blocks.0.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.bias": "blocks.0.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.weight": "blocks.0.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.bias": ['blocks.0.attn.a_to_q.bias', 'blocks.0.attn.a_to_k.bias', 'blocks.0.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.weight": ['blocks.0.attn.a_to_q.weight', 'blocks.0.attn.a_to_k.weight', 'blocks.0.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.bias": "blocks.0.ff_a.0.bias",
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.weight": "blocks.0.ff_a.0.weight",
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.bias": "blocks.0.ff_a.2.bias",
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.weight": "blocks.0.ff_a.2.weight",
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.bias": "blocks.1.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.weight": "blocks.1.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.bias": "blocks.1.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.weight": "blocks.1.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.bias": ['blocks.1.attn.b_to_q.bias', 'blocks.1.attn.b_to_k.bias', 'blocks.1.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.weight": ['blocks.1.attn.b_to_q.weight', 'blocks.1.attn.b_to_k.weight', 'blocks.1.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.bias": "blocks.1.ff_b.0.bias",
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.weight": "blocks.1.ff_b.0.weight",
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.bias": "blocks.1.ff_b.2.bias",
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.weight": "blocks.1.ff_b.2.weight",
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.bias": "blocks.1.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.weight": "blocks.1.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.bias": "blocks.1.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.weight": "blocks.1.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.bias": ['blocks.1.attn.a_to_q.bias', 'blocks.1.attn.a_to_k.bias', 'blocks.1.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.weight": ['blocks.1.attn.a_to_q.weight', 'blocks.1.attn.a_to_k.weight', 'blocks.1.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.bias": "blocks.1.ff_a.0.bias",
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.weight": "blocks.1.ff_a.0.weight",
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.bias": "blocks.1.ff_a.2.bias",
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.weight": "blocks.1.ff_a.2.weight",
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.bias": "blocks.10.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.weight": "blocks.10.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.bias": "blocks.10.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.weight": "blocks.10.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.bias": ['blocks.10.attn.b_to_q.bias', 'blocks.10.attn.b_to_k.bias', 'blocks.10.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.weight": ['blocks.10.attn.b_to_q.weight', 'blocks.10.attn.b_to_k.weight', 'blocks.10.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.bias": "blocks.10.ff_b.0.bias",
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.weight": "blocks.10.ff_b.0.weight",
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.bias": "blocks.10.ff_b.2.bias",
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.weight": "blocks.10.ff_b.2.weight",
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.bias": "blocks.10.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.weight": "blocks.10.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.bias": "blocks.10.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.weight": "blocks.10.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.bias": ['blocks.10.attn.a_to_q.bias', 'blocks.10.attn.a_to_k.bias', 'blocks.10.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.weight": ['blocks.10.attn.a_to_q.weight', 'blocks.10.attn.a_to_k.weight', 'blocks.10.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.bias": "blocks.10.ff_a.0.bias",
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.weight": "blocks.10.ff_a.0.weight",
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.bias": "blocks.10.ff_a.2.bias",
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.weight": "blocks.10.ff_a.2.weight",
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.bias": "blocks.11.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.weight": "blocks.11.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.bias": "blocks.11.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.weight": "blocks.11.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.bias": ['blocks.11.attn.b_to_q.bias', 'blocks.11.attn.b_to_k.bias', 'blocks.11.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.weight": ['blocks.11.attn.b_to_q.weight', 'blocks.11.attn.b_to_k.weight', 'blocks.11.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.bias": "blocks.11.ff_b.0.bias",
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.weight": "blocks.11.ff_b.0.weight",
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.bias": "blocks.11.ff_b.2.bias",
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.weight": "blocks.11.ff_b.2.weight",
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.bias": "blocks.11.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.weight": "blocks.11.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.bias": "blocks.11.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.weight": "blocks.11.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.bias": ['blocks.11.attn.a_to_q.bias', 'blocks.11.attn.a_to_k.bias', 'blocks.11.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.weight": ['blocks.11.attn.a_to_q.weight', 'blocks.11.attn.a_to_k.weight', 'blocks.11.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.bias": "blocks.11.ff_a.0.bias",
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.weight": "blocks.11.ff_a.0.weight",
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.bias": "blocks.11.ff_a.2.bias",
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.weight": "blocks.11.ff_a.2.weight",
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.bias": "blocks.12.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.weight": "blocks.12.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.bias": "blocks.12.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.weight": "blocks.12.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.bias": ['blocks.12.attn.b_to_q.bias', 'blocks.12.attn.b_to_k.bias', 'blocks.12.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.weight": ['blocks.12.attn.b_to_q.weight', 'blocks.12.attn.b_to_k.weight', 'blocks.12.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.bias": "blocks.12.ff_b.0.bias",
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.weight": "blocks.12.ff_b.0.weight",
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.bias": "blocks.12.ff_b.2.bias",
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.weight": "blocks.12.ff_b.2.weight",
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.bias": "blocks.12.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.weight": "blocks.12.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.bias": "blocks.12.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.weight": "blocks.12.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.bias": ['blocks.12.attn.a_to_q.bias', 'blocks.12.attn.a_to_k.bias', 'blocks.12.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.weight": ['blocks.12.attn.a_to_q.weight', 'blocks.12.attn.a_to_k.weight', 'blocks.12.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.bias": "blocks.12.ff_a.0.bias",
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.weight": "blocks.12.ff_a.0.weight",
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.bias": "blocks.12.ff_a.2.bias",
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.weight": "blocks.12.ff_a.2.weight",
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.bias": "blocks.13.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.weight": "blocks.13.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.bias": "blocks.13.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.weight": "blocks.13.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.bias": ['blocks.13.attn.b_to_q.bias', 'blocks.13.attn.b_to_k.bias', 'blocks.13.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.weight": ['blocks.13.attn.b_to_q.weight', 'blocks.13.attn.b_to_k.weight', 'blocks.13.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.bias": "blocks.13.ff_b.0.bias",
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.weight": "blocks.13.ff_b.0.weight",
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.bias": "blocks.13.ff_b.2.bias",
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.weight": "blocks.13.ff_b.2.weight",
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.bias": "blocks.13.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.weight": "blocks.13.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.bias": "blocks.13.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.weight": "blocks.13.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.bias": ['blocks.13.attn.a_to_q.bias', 'blocks.13.attn.a_to_k.bias', 'blocks.13.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.weight": ['blocks.13.attn.a_to_q.weight', 'blocks.13.attn.a_to_k.weight', 'blocks.13.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.bias": "blocks.13.ff_a.0.bias",
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.weight": "blocks.13.ff_a.0.weight",
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.bias": "blocks.13.ff_a.2.bias",
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.weight": "blocks.13.ff_a.2.weight",
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.bias": "blocks.14.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.weight": "blocks.14.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.bias": "blocks.14.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.weight": "blocks.14.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.bias": ['blocks.14.attn.b_to_q.bias', 'blocks.14.attn.b_to_k.bias', 'blocks.14.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.weight": ['blocks.14.attn.b_to_q.weight', 'blocks.14.attn.b_to_k.weight', 'blocks.14.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.bias": "blocks.14.ff_b.0.bias",
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.weight": "blocks.14.ff_b.0.weight",
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.bias": "blocks.14.ff_b.2.bias",
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.weight": "blocks.14.ff_b.2.weight",
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.bias": "blocks.14.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.weight": "blocks.14.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.bias": "blocks.14.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.weight": "blocks.14.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.bias": ['blocks.14.attn.a_to_q.bias', 'blocks.14.attn.a_to_k.bias', 'blocks.14.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.weight": ['blocks.14.attn.a_to_q.weight', 'blocks.14.attn.a_to_k.weight', 'blocks.14.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.bias": "blocks.14.ff_a.0.bias",
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.weight": "blocks.14.ff_a.0.weight",
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.bias": "blocks.14.ff_a.2.bias",
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.weight": "blocks.14.ff_a.2.weight",
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.bias": "blocks.15.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.weight": "blocks.15.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.bias": "blocks.15.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.weight": "blocks.15.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.bias": ['blocks.15.attn.b_to_q.bias', 'blocks.15.attn.b_to_k.bias', 'blocks.15.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.weight": ['blocks.15.attn.b_to_q.weight', 'blocks.15.attn.b_to_k.weight', 'blocks.15.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.bias": "blocks.15.ff_b.0.bias",
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.weight": "blocks.15.ff_b.0.weight",
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.bias": "blocks.15.ff_b.2.bias",
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.weight": "blocks.15.ff_b.2.weight",
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.bias": "blocks.15.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.weight": "blocks.15.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.bias": "blocks.15.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.weight": "blocks.15.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.bias": ['blocks.15.attn.a_to_q.bias', 'blocks.15.attn.a_to_k.bias', 'blocks.15.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.weight": ['blocks.15.attn.a_to_q.weight', 'blocks.15.attn.a_to_k.weight', 'blocks.15.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.bias": "blocks.15.ff_a.0.bias",
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.weight": "blocks.15.ff_a.0.weight",
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.bias": "blocks.15.ff_a.2.bias",
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.weight": "blocks.15.ff_a.2.weight",
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.bias": "blocks.16.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.weight": "blocks.16.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.bias": "blocks.16.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.weight": "blocks.16.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.bias": ['blocks.16.attn.b_to_q.bias', 'blocks.16.attn.b_to_k.bias', 'blocks.16.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.weight": ['blocks.16.attn.b_to_q.weight', 'blocks.16.attn.b_to_k.weight', 'blocks.16.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.bias": "blocks.16.ff_b.0.bias",
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.weight": "blocks.16.ff_b.0.weight",
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.bias": "blocks.16.ff_b.2.bias",
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.weight": "blocks.16.ff_b.2.weight",
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.bias": "blocks.16.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.weight": "blocks.16.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.bias": "blocks.16.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.weight": "blocks.16.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.bias": ['blocks.16.attn.a_to_q.bias', 'blocks.16.attn.a_to_k.bias', 'blocks.16.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.weight": ['blocks.16.attn.a_to_q.weight', 'blocks.16.attn.a_to_k.weight', 'blocks.16.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.bias": "blocks.16.ff_a.0.bias",
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.weight": "blocks.16.ff_a.0.weight",
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.bias": "blocks.16.ff_a.2.bias",
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.weight": "blocks.16.ff_a.2.weight",
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.bias": "blocks.17.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.weight": "blocks.17.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.bias": "blocks.17.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.weight": "blocks.17.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.bias": ['blocks.17.attn.b_to_q.bias', 'blocks.17.attn.b_to_k.bias', 'blocks.17.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.weight": ['blocks.17.attn.b_to_q.weight', 'blocks.17.attn.b_to_k.weight', 'blocks.17.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.bias": "blocks.17.ff_b.0.bias",
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.weight": "blocks.17.ff_b.0.weight",
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.bias": "blocks.17.ff_b.2.bias",
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.weight": "blocks.17.ff_b.2.weight",
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.bias": "blocks.17.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.weight": "blocks.17.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.bias": "blocks.17.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.weight": "blocks.17.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.bias": ['blocks.17.attn.a_to_q.bias', 'blocks.17.attn.a_to_k.bias', 'blocks.17.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.weight": ['blocks.17.attn.a_to_q.weight', 'blocks.17.attn.a_to_k.weight', 'blocks.17.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.bias": "blocks.17.ff_a.0.bias",
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.weight": "blocks.17.ff_a.0.weight",
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.bias": "blocks.17.ff_a.2.bias",
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.weight": "blocks.17.ff_a.2.weight",
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.bias": "blocks.18.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.weight": "blocks.18.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.bias": "blocks.18.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.weight": "blocks.18.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.bias": ['blocks.18.attn.b_to_q.bias', 'blocks.18.attn.b_to_k.bias', 'blocks.18.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.weight": ['blocks.18.attn.b_to_q.weight', 'blocks.18.attn.b_to_k.weight', 'blocks.18.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.bias": "blocks.18.ff_b.0.bias",
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.weight": "blocks.18.ff_b.0.weight",
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.bias": "blocks.18.ff_b.2.bias",
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.weight": "blocks.18.ff_b.2.weight",
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.bias": "blocks.18.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.weight": "blocks.18.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.bias": "blocks.18.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.weight": "blocks.18.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.bias": ['blocks.18.attn.a_to_q.bias', 'blocks.18.attn.a_to_k.bias', 'blocks.18.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.weight": ['blocks.18.attn.a_to_q.weight', 'blocks.18.attn.a_to_k.weight', 'blocks.18.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.bias": "blocks.18.ff_a.0.bias",
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.weight": "blocks.18.ff_a.0.weight",
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.bias": "blocks.18.ff_a.2.bias",
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.weight": "blocks.18.ff_a.2.weight",
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.bias": "blocks.19.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.weight": "blocks.19.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.bias": "blocks.19.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.weight": "blocks.19.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.bias": ['blocks.19.attn.b_to_q.bias', 'blocks.19.attn.b_to_k.bias', 'blocks.19.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.weight": ['blocks.19.attn.b_to_q.weight', 'blocks.19.attn.b_to_k.weight', 'blocks.19.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.bias": "blocks.19.ff_b.0.bias",
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.weight": "blocks.19.ff_b.0.weight",
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.bias": "blocks.19.ff_b.2.bias",
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.weight": "blocks.19.ff_b.2.weight",
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.bias": "blocks.19.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.weight": "blocks.19.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.bias": "blocks.19.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.weight": "blocks.19.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.bias": ['blocks.19.attn.a_to_q.bias', 'blocks.19.attn.a_to_k.bias', 'blocks.19.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.weight": ['blocks.19.attn.a_to_q.weight', 'blocks.19.attn.a_to_k.weight', 'blocks.19.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.bias": "blocks.19.ff_a.0.bias",
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.weight": "blocks.19.ff_a.0.weight",
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.bias": "blocks.19.ff_a.2.bias",
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.weight": "blocks.19.ff_a.2.weight",
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.bias": "blocks.2.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.weight": "blocks.2.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.bias": "blocks.2.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.weight": "blocks.2.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.bias": ['blocks.2.attn.b_to_q.bias', 'blocks.2.attn.b_to_k.bias', 'blocks.2.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.weight": ['blocks.2.attn.b_to_q.weight', 'blocks.2.attn.b_to_k.weight', 'blocks.2.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.bias": "blocks.2.ff_b.0.bias",
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.weight": "blocks.2.ff_b.0.weight",
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.bias": "blocks.2.ff_b.2.bias",
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.weight": "blocks.2.ff_b.2.weight",
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.bias": "blocks.2.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.weight": "blocks.2.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.bias": "blocks.2.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.weight": "blocks.2.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.bias": ['blocks.2.attn.a_to_q.bias', 'blocks.2.attn.a_to_k.bias', 'blocks.2.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.weight": ['blocks.2.attn.a_to_q.weight', 'blocks.2.attn.a_to_k.weight', 'blocks.2.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.bias": "blocks.2.ff_a.0.bias",
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.weight": "blocks.2.ff_a.0.weight",
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.bias": "blocks.2.ff_a.2.bias",
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.weight": "blocks.2.ff_a.2.weight",
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.bias": "blocks.20.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.weight": "blocks.20.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.bias": "blocks.20.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.weight": "blocks.20.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.bias": ['blocks.20.attn.b_to_q.bias', 'blocks.20.attn.b_to_k.bias', 'blocks.20.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.weight": ['blocks.20.attn.b_to_q.weight', 'blocks.20.attn.b_to_k.weight', 'blocks.20.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.bias": "blocks.20.ff_b.0.bias",
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.weight": "blocks.20.ff_b.0.weight",
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.bias": "blocks.20.ff_b.2.bias",
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.weight": "blocks.20.ff_b.2.weight",
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.bias": "blocks.20.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.weight": "blocks.20.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.bias": "blocks.20.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.weight": "blocks.20.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.bias": ['blocks.20.attn.a_to_q.bias', 'blocks.20.attn.a_to_k.bias', 'blocks.20.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.weight": ['blocks.20.attn.a_to_q.weight', 'blocks.20.attn.a_to_k.weight', 'blocks.20.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.bias": "blocks.20.ff_a.0.bias",
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.weight": "blocks.20.ff_a.0.weight",
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.bias": "blocks.20.ff_a.2.bias",
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.weight": "blocks.20.ff_a.2.weight",
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.bias": "blocks.21.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.weight": "blocks.21.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.bias": "blocks.21.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.weight": "blocks.21.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.bias": ['blocks.21.attn.b_to_q.bias', 'blocks.21.attn.b_to_k.bias', 'blocks.21.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.weight": ['blocks.21.attn.b_to_q.weight', 'blocks.21.attn.b_to_k.weight', 'blocks.21.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.bias": "blocks.21.ff_b.0.bias",
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.weight": "blocks.21.ff_b.0.weight",
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.bias": "blocks.21.ff_b.2.bias",
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.weight": "blocks.21.ff_b.2.weight",
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.bias": "blocks.21.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.weight": "blocks.21.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.bias": "blocks.21.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.weight": "blocks.21.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.bias": ['blocks.21.attn.a_to_q.bias', 'blocks.21.attn.a_to_k.bias', 'blocks.21.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.weight": ['blocks.21.attn.a_to_q.weight', 'blocks.21.attn.a_to_k.weight', 'blocks.21.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.bias": "blocks.21.ff_a.0.bias",
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.weight": "blocks.21.ff_a.0.weight",
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.bias": "blocks.21.ff_a.2.bias",
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.weight": "blocks.21.ff_a.2.weight",
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.bias": "blocks.22.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.weight": "blocks.22.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.bias": "blocks.22.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.weight": "blocks.22.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.bias": ['blocks.22.attn.b_to_q.bias', 'blocks.22.attn.b_to_k.bias', 'blocks.22.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.weight": ['blocks.22.attn.b_to_q.weight', 'blocks.22.attn.b_to_k.weight', 'blocks.22.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.bias": "blocks.22.ff_b.0.bias",
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.weight": "blocks.22.ff_b.0.weight",
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.bias": "blocks.22.ff_b.2.bias",
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.weight": "blocks.22.ff_b.2.weight",
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.bias": "blocks.22.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.weight": "blocks.22.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.bias": "blocks.22.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.weight": "blocks.22.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.bias": ['blocks.22.attn.a_to_q.bias', 'blocks.22.attn.a_to_k.bias', 'blocks.22.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.weight": ['blocks.22.attn.a_to_q.weight', 'blocks.22.attn.a_to_k.weight', 'blocks.22.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.bias": "blocks.22.ff_a.0.bias",
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.weight": "blocks.22.ff_a.0.weight",
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.bias": "blocks.22.ff_a.2.bias",
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.weight": "blocks.22.ff_a.2.weight",
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.bias": ['blocks.23.attn.b_to_q.bias', 'blocks.23.attn.b_to_k.bias', 'blocks.23.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.weight": ['blocks.23.attn.b_to_q.weight', 'blocks.23.attn.b_to_k.weight', 'blocks.23.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.bias": "blocks.23.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.weight": "blocks.23.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.bias": "blocks.23.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.weight": "blocks.23.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.bias": ['blocks.23.attn.a_to_q.bias', 'blocks.23.attn.a_to_k.bias', 'blocks.23.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.weight": ['blocks.23.attn.a_to_q.weight', 'blocks.23.attn.a_to_k.weight', 'blocks.23.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.bias": "blocks.23.ff_a.0.bias",
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.weight": "blocks.23.ff_a.0.weight",
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.bias": "blocks.23.ff_a.2.bias",
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.weight": "blocks.23.ff_a.2.weight",
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.bias": "blocks.3.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.weight": "blocks.3.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.bias": "blocks.3.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.weight": "blocks.3.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.bias": ['blocks.3.attn.b_to_q.bias', 'blocks.3.attn.b_to_k.bias', 'blocks.3.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.weight": ['blocks.3.attn.b_to_q.weight', 'blocks.3.attn.b_to_k.weight', 'blocks.3.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.bias": "blocks.3.ff_b.0.bias",
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.weight": "blocks.3.ff_b.0.weight",
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.bias": "blocks.3.ff_b.2.bias",
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.weight": "blocks.3.ff_b.2.weight",
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.bias": "blocks.3.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.weight": "blocks.3.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.bias": "blocks.3.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.weight": "blocks.3.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.bias": ['blocks.3.attn.a_to_q.bias', 'blocks.3.attn.a_to_k.bias', 'blocks.3.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.weight": ['blocks.3.attn.a_to_q.weight', 'blocks.3.attn.a_to_k.weight', 'blocks.3.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.bias": "blocks.3.ff_a.0.bias",
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.weight": "blocks.3.ff_a.0.weight",
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.bias": "blocks.3.ff_a.2.bias",
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.weight": "blocks.3.ff_a.2.weight",
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.bias": "blocks.4.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.weight": "blocks.4.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.bias": "blocks.4.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.weight": "blocks.4.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.bias": ['blocks.4.attn.b_to_q.bias', 'blocks.4.attn.b_to_k.bias', 'blocks.4.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.weight": ['blocks.4.attn.b_to_q.weight', 'blocks.4.attn.b_to_k.weight', 'blocks.4.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.bias": "blocks.4.ff_b.0.bias",
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.weight": "blocks.4.ff_b.0.weight",
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.bias": "blocks.4.ff_b.2.bias",
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.weight": "blocks.4.ff_b.2.weight",
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.bias": "blocks.4.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.weight": "blocks.4.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.bias": "blocks.4.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.weight": "blocks.4.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.bias": ['blocks.4.attn.a_to_q.bias', 'blocks.4.attn.a_to_k.bias', 'blocks.4.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.weight": ['blocks.4.attn.a_to_q.weight', 'blocks.4.attn.a_to_k.weight', 'blocks.4.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.bias": "blocks.4.ff_a.0.bias",
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.weight": "blocks.4.ff_a.0.weight",
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.bias": "blocks.4.ff_a.2.bias",
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.weight": "blocks.4.ff_a.2.weight",
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.bias": "blocks.5.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.weight": "blocks.5.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.bias": "blocks.5.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.weight": "blocks.5.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.bias": ['blocks.5.attn.b_to_q.bias', 'blocks.5.attn.b_to_k.bias', 'blocks.5.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.weight": ['blocks.5.attn.b_to_q.weight', 'blocks.5.attn.b_to_k.weight', 'blocks.5.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.bias": "blocks.5.ff_b.0.bias",
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.weight": "blocks.5.ff_b.0.weight",
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.bias": "blocks.5.ff_b.2.bias",
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.weight": "blocks.5.ff_b.2.weight",
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.bias": "blocks.5.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.weight": "blocks.5.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.bias": "blocks.5.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.weight": "blocks.5.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.bias": ['blocks.5.attn.a_to_q.bias', 'blocks.5.attn.a_to_k.bias', 'blocks.5.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.weight": ['blocks.5.attn.a_to_q.weight', 'blocks.5.attn.a_to_k.weight', 'blocks.5.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.bias": "blocks.5.ff_a.0.bias",
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.weight": "blocks.5.ff_a.0.weight",
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.bias": "blocks.5.ff_a.2.bias",
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.weight": "blocks.5.ff_a.2.weight",
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.bias": "blocks.6.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.weight": "blocks.6.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.bias": "blocks.6.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.weight": "blocks.6.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.bias": ['blocks.6.attn.b_to_q.bias', 'blocks.6.attn.b_to_k.bias', 'blocks.6.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.weight": ['blocks.6.attn.b_to_q.weight', 'blocks.6.attn.b_to_k.weight', 'blocks.6.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.bias": "blocks.6.ff_b.0.bias",
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.weight": "blocks.6.ff_b.0.weight",
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.bias": "blocks.6.ff_b.2.bias",
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.weight": "blocks.6.ff_b.2.weight",
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.bias": "blocks.6.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.weight": "blocks.6.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.bias": "blocks.6.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.weight": "blocks.6.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.bias": ['blocks.6.attn.a_to_q.bias', 'blocks.6.attn.a_to_k.bias', 'blocks.6.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.weight": ['blocks.6.attn.a_to_q.weight', 'blocks.6.attn.a_to_k.weight', 'blocks.6.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.bias": "blocks.6.ff_a.0.bias",
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.weight": "blocks.6.ff_a.0.weight",
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.bias": "blocks.6.ff_a.2.bias",
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.weight": "blocks.6.ff_a.2.weight",
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.bias": "blocks.7.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.weight": "blocks.7.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.bias": "blocks.7.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.weight": "blocks.7.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.bias": ['blocks.7.attn.b_to_q.bias', 'blocks.7.attn.b_to_k.bias', 'blocks.7.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.weight": ['blocks.7.attn.b_to_q.weight', 'blocks.7.attn.b_to_k.weight', 'blocks.7.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.bias": "blocks.7.ff_b.0.bias",
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.weight": "blocks.7.ff_b.0.weight",
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.bias": "blocks.7.ff_b.2.bias",
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.weight": "blocks.7.ff_b.2.weight",
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.bias": "blocks.7.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.weight": "blocks.7.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.bias": "blocks.7.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.weight": "blocks.7.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.bias": ['blocks.7.attn.a_to_q.bias', 'blocks.7.attn.a_to_k.bias', 'blocks.7.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.weight": ['blocks.7.attn.a_to_q.weight', 'blocks.7.attn.a_to_k.weight', 'blocks.7.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.bias": "blocks.7.ff_a.0.bias",
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.weight": "blocks.7.ff_a.0.weight",
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.bias": "blocks.7.ff_a.2.bias",
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.weight": "blocks.7.ff_a.2.weight",
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.bias": "blocks.8.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.weight": "blocks.8.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.bias": "blocks.8.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.weight": "blocks.8.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.bias": ['blocks.8.attn.b_to_q.bias', 'blocks.8.attn.b_to_k.bias', 'blocks.8.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.weight": ['blocks.8.attn.b_to_q.weight', 'blocks.8.attn.b_to_k.weight', 'blocks.8.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.bias": "blocks.8.ff_b.0.bias",
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.weight": "blocks.8.ff_b.0.weight",
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.bias": "blocks.8.ff_b.2.bias",
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.weight": "blocks.8.ff_b.2.weight",
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.bias": "blocks.8.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.weight": "blocks.8.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.bias": "blocks.8.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.weight": "blocks.8.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.bias": ['blocks.8.attn.a_to_q.bias', 'blocks.8.attn.a_to_k.bias', 'blocks.8.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.weight": ['blocks.8.attn.a_to_q.weight', 'blocks.8.attn.a_to_k.weight', 'blocks.8.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.bias": "blocks.8.ff_a.0.bias",
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.weight": "blocks.8.ff_a.0.weight",
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.bias": "blocks.8.ff_a.2.bias",
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.weight": "blocks.8.ff_a.2.weight",
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.bias": "blocks.9.norm1_b.linear.bias",
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.weight": "blocks.9.norm1_b.linear.weight",
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.bias": "blocks.9.attn.b_to_out.bias",
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.weight": "blocks.9.attn.b_to_out.weight",
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.bias": ['blocks.9.attn.b_to_q.bias', 'blocks.9.attn.b_to_k.bias', 'blocks.9.attn.b_to_v.bias'],
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.weight": ['blocks.9.attn.b_to_q.weight', 'blocks.9.attn.b_to_k.weight', 'blocks.9.attn.b_to_v.weight'],
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.bias": "blocks.9.ff_b.0.bias",
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.weight": "blocks.9.ff_b.0.weight",
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.bias": "blocks.9.ff_b.2.bias",
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.weight": "blocks.9.ff_b.2.weight",
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.bias": "blocks.9.norm1_a.linear.bias",
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.weight": "blocks.9.norm1_a.linear.weight",
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.bias": "blocks.9.attn.a_to_out.bias",
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.weight": "blocks.9.attn.a_to_out.weight",
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.bias": ['blocks.9.attn.a_to_q.bias', 'blocks.9.attn.a_to_k.bias', 'blocks.9.attn.a_to_v.bias'],
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.weight": ['blocks.9.attn.a_to_q.weight', 'blocks.9.attn.a_to_k.weight', 'blocks.9.attn.a_to_v.weight'],
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.bias": "blocks.9.ff_a.0.bias",
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.weight": "blocks.9.ff_a.0.weight",
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.bias": "blocks.9.ff_a.2.bias",
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight": "blocks.9.ff_a.2.weight",
"model.diffusion_model.pos_embed": "pos_embedder.pos_embed", "model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
"model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias", "model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
"model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight", "model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
@@ -780,19 +493,59 @@ class SD3DiTStateDictConverter:
"model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight", "model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
"model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias", "model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
} }
for i in range(40):
rename_dict.update({
f"model.diffusion_model.joint_blocks.{i}.context_block.adaLN_modulation.1.bias": f"blocks.{i}.norm1_b.linear.bias",
f"model.diffusion_model.joint_blocks.{i}.context_block.adaLN_modulation.1.weight": f"blocks.{i}.norm1_b.linear.weight",
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.proj.bias": f"blocks.{i}.attn.b_to_out.bias",
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.proj.weight": f"blocks.{i}.attn.b_to_out.weight",
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.qkv.bias": [f'blocks.{i}.attn.b_to_q.bias', f'blocks.{i}.attn.b_to_k.bias', f'blocks.{i}.attn.b_to_v.bias'],
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.qkv.weight": [f'blocks.{i}.attn.b_to_q.weight', f'blocks.{i}.attn.b_to_k.weight', f'blocks.{i}.attn.b_to_v.weight'],
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc1.bias": f"blocks.{i}.ff_b.0.bias",
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc1.weight": f"blocks.{i}.ff_b.0.weight",
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc2.bias": f"blocks.{i}.ff_b.2.bias",
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc2.weight": f"blocks.{i}.ff_b.2.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.adaLN_modulation.1.bias": f"blocks.{i}.norm1_a.linear.bias",
f"model.diffusion_model.joint_blocks.{i}.x_block.adaLN_modulation.1.weight": f"blocks.{i}.norm1_a.linear.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.proj.bias": f"blocks.{i}.attn.a_to_out.bias",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.proj.weight": f"blocks.{i}.attn.a_to_out.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.qkv.bias": [f'blocks.{i}.attn.a_to_q.bias', f'blocks.{i}.attn.a_to_k.bias', f'blocks.{i}.attn.a_to_v.bias'],
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.qkv.weight": [f'blocks.{i}.attn.a_to_q.weight', f'blocks.{i}.attn.a_to_k.weight', f'blocks.{i}.attn.a_to_v.weight'],
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc1.bias": f"blocks.{i}.ff_a.0.bias",
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc1.weight": f"blocks.{i}.ff_a.0.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc2.bias": f"blocks.{i}.ff_a.2.bias",
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc2.weight": f"blocks.{i}.ff_a.2.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_a.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_a.weight",
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_b.weight",
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_b.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_q.weight": f"blocks.{i}.attn2.norm_q_a.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_k.weight": f"blocks.{i}.attn2.norm_k_a.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.weight": f"blocks.{i}.attn2.a_to_qkv.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.bias": f"blocks.{i}.attn2.a_to_qkv.bias",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.weight": f"blocks.{i}.attn2.a_to_out.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.bias": f"blocks.{i}.attn2.a_to_out.bias",
})
state_dict_ = {} state_dict_ = {}
for name in state_dict: for name in state_dict:
if name in rename_dict: if name in rename_dict:
param = state_dict[name] param = state_dict[name]
if name.startswith("model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1."): if name == "model.diffusion_model.pos_embed":
param = torch.concat([param[1536:], param[:1536]], axis=0) pos_embed_max_size = int(param.shape[1] ** 0.5 + 0.4)
elif name.startswith("model.diffusion_model.final_layer.adaLN_modulation.1."): param = param.reshape((1, pos_embed_max_size, pos_embed_max_size, param.shape[-1]))
param = torch.concat([param[1536:], param[:1536]], axis=0)
elif name == "model.diffusion_model.pos_embed":
param = param.reshape((1, 192, 192, 1536))
if isinstance(rename_dict[name], str): if isinstance(rename_dict[name], str):
state_dict_[rename_dict[name]] = param state_dict_[rename_dict[name]] = param
else: else:
name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.") name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
state_dict_[name_] = param state_dict_[name_] = param
return state_dict_ extra_kwargs = self.infer_architecture(state_dict_)
num_layers = extra_kwargs["num_layers"]
for name in [
f"blocks.{num_layers-1}.norm1_b.linear.weight", f"blocks.{num_layers-1}.norm1_b.linear.bias", "norm_out.linear.weight", "norm_out.linear.bias",
]:
param = state_dict_[name]
dim = param.shape[0] // 2
param = torch.concat([param[dim:], param[:dim]], axis=0)
state_dict_[name] = param
return state_dict_, self.infer_architecture(state_dict_)

View File

@@ -8,9 +8,12 @@ class SD3TextEncoder1(SDTextEncoder):
def __init__(self, vocab_size=49408): def __init__(self, vocab_size=49408):
super().__init__(vocab_size=vocab_size) super().__init__(vocab_size=vocab_size)
def forward(self, input_ids, clip_skip=2): def forward(self, input_ids, clip_skip=2, extra_mask=None):
embeds = self.token_embedding(input_ids) + self.position_embeds embeds = self.token_embedding(input_ids)
embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
if extra_mask is not None:
attn_mask[:, extra_mask[0]==0] = float("-inf")
for encoder_id, encoder in enumerate(self.encoders): for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask) embeds = encoder(embeds, attn_mask=attn_mask)
if encoder_id + clip_skip == len(self.encoders): if encoder_id + clip_skip == len(self.encoders):
@@ -322,6 +325,11 @@ class SD3TextEncoder1StateDictConverter:
if name == "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight": if name == "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1])) param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param state_dict_[rename_dict[name]] = param
elif ("text_encoders.clip_l.transformer." + name) in rename_dict:
param = state_dict[name]
if name == "text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict["text_encoders.clip_l.transformer." + name]] = param
return state_dict_ return state_dict_
@@ -860,6 +868,11 @@ class SD3TextEncoder2StateDictConverter(SDXLTextEncoder2StateDictConverter):
if name == "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight": if name == "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1])) param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param state_dict_[rename_dict[name]] = param
elif ("text_encoders.clip_g.transformer." + name) in rename_dict:
param = state_dict[name]
if name == "text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict["text_encoders.clip_g.transformer." + name]] = param
return state_dict_ return state_dict_

View File

@@ -97,6 +97,7 @@ class SDControlNet(torch.nn.Module):
self, self,
sample, timestep, encoder_hidden_states, conditioning, sample, timestep, encoder_hidden_states, conditioning,
tiled=False, tile_size=64, tile_stride=32, tiled=False, tile_size=64, tile_stride=32,
**kwargs
): ):
# 1. time # 1. time
time_emb = self.time_proj(timestep).to(sample.dtype) time_emb = self.time_proj(timestep).to(sample.dtype)

View File

@@ -0,0 +1,318 @@
import torch
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
from .sdxl_unet import SDXLUNet
from .tiler import TileWorker
from .sd_controlnet import ControlNetConditioningLayer
from collections import OrderedDict
class QuickGELU(torch.nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(torch.nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()
self.attn = torch.nn.MultiheadAttention(d_model, n_head)
self.ln_1 = torch.nn.LayerNorm(d_model)
self.mlp = torch.nn.Sequential(OrderedDict([
("c_fc", torch.nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", torch.nn.Linear(d_model * 4, d_model))
]))
self.ln_2 = torch.nn.LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class SDXLControlNetUnion(torch.nn.Module):
def __init__(self, global_pool=False):
super().__init__()
self.time_proj = Timesteps(320)
self.time_embedding = torch.nn.Sequential(
torch.nn.Linear(320, 1280),
torch.nn.SiLU(),
torch.nn.Linear(1280, 1280)
)
self.add_time_proj = Timesteps(256)
self.add_time_embedding = torch.nn.Sequential(
torch.nn.Linear(2816, 1280),
torch.nn.SiLU(),
torch.nn.Linear(1280, 1280)
)
self.control_type_proj = Timesteps(256)
self.control_type_embedding = torch.nn.Sequential(
torch.nn.Linear(256 * 8, 1280),
torch.nn.SiLU(),
torch.nn.Linear(1280, 1280)
)
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
self.controlnet_transformer = ResidualAttentionBlock(320, 8)
self.task_embedding = torch.nn.Parameter(torch.randn(8, 320))
self.spatial_ch_projs = torch.nn.Linear(320, 320)
self.blocks = torch.nn.ModuleList([
# DownBlock2D
ResnetBlock(320, 320, 1280),
PushBlock(),
ResnetBlock(320, 320, 1280),
PushBlock(),
DownSampler(320),
PushBlock(),
# CrossAttnDownBlock2D
ResnetBlock(320, 640, 1280),
AttentionBlock(10, 64, 640, 2, 2048),
PushBlock(),
ResnetBlock(640, 640, 1280),
AttentionBlock(10, 64, 640, 2, 2048),
PushBlock(),
DownSampler(640),
PushBlock(),
# CrossAttnDownBlock2D
ResnetBlock(640, 1280, 1280),
AttentionBlock(20, 64, 1280, 10, 2048),
PushBlock(),
ResnetBlock(1280, 1280, 1280),
AttentionBlock(20, 64, 1280, 10, 2048),
PushBlock(),
# UNetMidBlock2DCrossAttn
ResnetBlock(1280, 1280, 1280),
AttentionBlock(20, 64, 1280, 10, 2048),
ResnetBlock(1280, 1280, 1280),
PushBlock()
])
self.controlnet_blocks = torch.nn.ModuleList([
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
])
self.global_pool = global_pool
# 0 -- openpose
# 1 -- depth
# 2 -- hed/pidi/scribble/ted
# 3 -- canny/lineart/anime_lineart/mlsd
# 4 -- normal
# 5 -- segment
# 6 -- tile
# 7 -- repaint
self.task_id = {
"openpose": 0,
"depth": 1,
"softedge": 2,
"canny": 3,
"lineart": 3,
"lineart_anime": 3,
"tile": 6,
"inpaint": 7
}
def fuse_condition_to_input(self, hidden_states, task_id, conditioning):
controlnet_cond = self.controlnet_conv_in(conditioning)
feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
feat_seq = feat_seq + self.task_embedding[task_id]
x = torch.stack([feat_seq, torch.mean(hidden_states, dim=(2, 3))], dim=1)
x = self.controlnet_transformer(x)
alpha = self.spatial_ch_projs(x[:,0]).unsqueeze(-1).unsqueeze(-1)
controlnet_cond_fuser = controlnet_cond + alpha
hidden_states = hidden_states + controlnet_cond_fuser
return hidden_states
def forward(
self,
sample, timestep, encoder_hidden_states,
conditioning, processor_id, add_time_id, add_text_embeds,
tiled=False, tile_size=64, tile_stride=32,
unet:SDXLUNet=None,
**kwargs
):
task_id = self.task_id[processor_id]
# 1. time
t_emb = self.time_proj(timestep).to(sample.dtype)
t_emb = self.time_embedding(t_emb)
time_embeds = self.add_time_proj(add_time_id)
time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(sample.dtype)
if unet is not None and unet.is_kolors:
add_embeds = unet.add_time_embedding(add_embeds)
else:
add_embeds = self.add_time_embedding(add_embeds)
control_type = torch.zeros((sample.shape[0], 8), dtype=sample.dtype, device=sample.device)
control_type[:, task_id] = 1
control_embeds = self.control_type_proj(control_type.flatten())
control_embeds = control_embeds.reshape((sample.shape[0], -1))
control_embeds = control_embeds.to(sample.dtype)
control_embeds = self.control_type_embedding(control_embeds)
time_emb = t_emb + add_embeds + control_embeds
# 2. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = self.conv_in(sample)
hidden_states = self.fuse_condition_to_input(hidden_states, task_id, conditioning)
text_emb = encoder_hidden_states
if unet is not None and unet.is_kolors:
text_emb = unet.text_intermediate_proj(text_emb)
res_stack = [hidden_states]
# 3. blocks
for i, block in enumerate(self.blocks):
if tiled and not isinstance(block, PushBlock):
_, _, inter_height, _ = hidden_states.shape
resize_scale = inter_height / height
hidden_states = TileWorker().tiled_forward(
lambda x: block(x, time_emb, text_emb, res_stack)[0],
hidden_states,
int(tile_size * resize_scale),
int(tile_stride * resize_scale),
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
else:
hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
# 4. ControlNet blocks
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
# pool
if self.global_pool:
controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
return controlnet_res_stack
@staticmethod
def state_dict_converter():
return SDXLControlNetUnionStateDictConverter()
class SDXLControlNetUnionStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
# architecture
block_types = [
"ResnetBlock", "PushBlock", "ResnetBlock", "PushBlock", "DownSampler", "PushBlock",
"ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock", "DownSampler", "PushBlock",
"ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock",
"ResnetBlock", "AttentionBlock", "ResnetBlock", "PushBlock"
]
# controlnet_rename_dict
controlnet_rename_dict = {
"controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
"controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
"controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
"controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
"controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
"controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
"controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
"controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
"controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
"controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
"controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
"controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
"controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
"controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
"controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
"controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
"control_add_embedding.linear_1.weight": "control_type_embedding.0.weight",
"control_add_embedding.linear_1.bias": "control_type_embedding.0.bias",
"control_add_embedding.linear_2.weight": "control_type_embedding.2.weight",
"control_add_embedding.linear_2.bias": "control_type_embedding.2.bias",
}
# Rename each parameter
name_list = sorted([name for name in state_dict])
rename_dict = {}
block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
for name in name_list:
names = name.split(".")
if names[0] in ["conv_in", "conv_norm_out", "conv_out", "task_embedding", "spatial_ch_projs"]:
pass
elif name in controlnet_rename_dict:
names = controlnet_rename_dict[name].split(".")
elif names[0] == "controlnet_down_blocks":
names[0] = "controlnet_blocks"
elif names[0] == "controlnet_mid_block":
names = ["controlnet_blocks", "9", names[-1]]
elif names[0] in ["time_embedding", "add_embedding"]:
if names[0] == "add_embedding":
names[0] = "add_time_embedding"
names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
elif names[0] == "control_add_embedding":
names[0] = "control_type_embedding"
elif names[0] == "transformer_layes":
names[0] = "controlnet_transformer"
names.pop(1)
elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
if names[0] == "mid_block":
names.insert(1, "0")
block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
block_type_with_id = ".".join(names[:4])
if block_type_with_id != last_block_type_with_id[block_type]:
block_id[block_type] += 1
last_block_type_with_id[block_type] = block_type_with_id
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
block_id[block_type] += 1
block_type_with_id = ".".join(names[:4])
names = ["blocks", str(block_id[block_type])] + names[4:]
if "ff" in names:
ff_index = names.index("ff")
component = ".".join(names[ff_index:ff_index+3])
component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
names = names[:ff_index] + [component] + names[ff_index+3:]
if "to_out" in names:
names.pop(names.index("to_out") + 1)
else:
print(name, state_dict[name].shape)
# raise ValueError(f"Unknown parameters: {name}")
rename_dict[name] = ".".join(names)
# Convert state_dict
state_dict_ = {}
for name, param in state_dict.items():
if name not in rename_dict:
continue
if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -83,6 +83,8 @@ class SDXLUNet(torch.nn.Module):
self.conv_act = torch.nn.SiLU() self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1) self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1)
self.is_kolors = is_kolors
def forward( def forward(
self, self,
sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds, sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds,

View File

@@ -0,0 +1,940 @@
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
from typing import Dict, Optional, Tuple
import torch, math
from torch import nn
from einops import rearrange, repeat
from tqdm import tqdm
class RMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
ACTIVATION_FUNCTIONS = {
"swish": nn.SiLU(),
"silu": nn.SiLU(),
"mish": nn.Mish(),
"gelu": nn.GELU(),
"relu": nn.ReLU(),
}
def get_activation(act_fn: str) -> nn.Module:
"""Helper function to get activation function from string.
Args:
act_fn (str): Name of activation function.
Returns:
nn.Module: Activation function.
"""
act_fn = act_fn.lower()
if act_fn in ACTIVATION_FUNCTIONS:
return ACTIVATION_FUNCTIONS[act_fn]
else:
raise ValueError(f"Unsupported activation function: {act_fn}")
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
embeddings. :return: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True
):
super().__init__()
linear_cls = nn.Linear
self.linear_1 = linear_cls(
in_channels,
time_embed_dim,
bias=sample_proj_bias,
)
if cond_proj_dim is not None:
self.cond_proj = linear_cls(
cond_proj_dim,
in_channels,
bias=False,
)
else:
self.cond_proj = None
self.act = get_activation(act_fn)
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = linear_cls(
time_embed_dim,
time_embed_dim_out,
bias=sample_proj_bias,
)
if post_act_fn is None:
self.post_act = None
else:
self.post_act = get_activation(post_act_fn)
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
super().__init__()
self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.use_additional_conditions = use_additional_conditions
if self.use_additional_conditions:
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
self.nframe_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.fps_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timestep, resolution=None, nframe=None, fps=None):
hidden_dtype = timestep.dtype
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
if self.use_additional_conditions:
batch_size = timestep.shape[0]
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
nframe_emb = self.additional_condition_proj(nframe.flatten()).to(hidden_dtype)
nframe_emb = self.nframe_embedder(nframe_emb).reshape(batch_size, -1)
conditioning = timesteps_emb + resolution_emb + nframe_emb
if fps is not None:
fps_emb = self.additional_condition_proj(fps.flatten()).to(hidden_dtype)
fps_emb = self.fps_embedder(fps_emb).reshape(batch_size, -1)
conditioning = conditioning + fps_emb
else:
conditioning = timesteps_emb
return conditioning
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, time_step_rescale=1000):
super().__init__()
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 2, use_additional_conditions=use_additional_conditions
)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
self.time_step_rescale = time_step_rescale ## timestep usually in [0, 1], we rescale it to [0,1000] for stability
def forward(
self,
timestep: torch.Tensor,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
embedded_timestep = self.emb(timestep*self.time_step_rescale, **added_cond_kwargs)
out = self.linear(self.silu(embedded_timestep))
return out, embedded_timestep
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size):
super().__init__()
self.linear_1 = nn.Linear(
in_features,
hidden_size,
bias=True,
)
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = nn.Linear(
hidden_size,
hidden_size,
bias=True,
)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class Attention(nn.Module):
def __init__(self):
super().__init__()
def attn_processor(self, attn_type):
if attn_type == 'torch':
return self.torch_attn_func
elif attn_type == 'parallel':
return self.parallel_attn_func
else:
raise Exception('Not supported attention type...')
def torch_attn_func(
self,
q,
k,
v,
attn_mask=None,
causal=False,
drop_rate=0.0,
**kwargs
):
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
if attn_mask is not None and attn_mask.ndim == 3: ## no head
n_heads = q.shape[2]
attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
q, k, v = map(lambda x: rearrange(x, 'b s h d -> b h s d'), (q, k, v))
if attn_mask is not None:
attn_mask = attn_mask.to(q.device)
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
)
x = rearrange(x, 'b h s d -> b s h d')
return x
class RoPE1D:
def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
self.base = freq
self.F0 = F0
self.scaling_factor = scaling_factor
self.cache = {}
def get_cos_sin(self, D, seq_len, device, dtype):
if (D, seq_len, device, dtype) not in self.cache:
inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
freqs = torch.cat((freqs, freqs), dim=-1)
cos = freqs.cos() # (Seq, Dim)
sin = freqs.sin()
self.cache[D, seq_len, device, dtype] = (cos, sin)
return self.cache[D, seq_len, device, dtype]
@staticmethod
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rope1d(self, tokens, pos1d, cos, sin):
assert pos1d.ndim == 2
cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :]
sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :]
return (tokens * cos) + (self.rotate_half(tokens) * sin)
def __call__(self, tokens, positions):
"""
input:
* tokens: batch_size x ntokens x nheads x dim
* positions: batch_size x ntokens (t position of each token)
output:
* tokens after appplying RoPE2D (batch_size x ntokens x nheads x dim)
"""
D = tokens.size(3)
assert positions.ndim == 2 # Batch, Seq
cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
tokens = self.apply_rope1d(tokens, positions, cos, sin)
return tokens
class RoPE3D(RoPE1D):
def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
super(RoPE3D, self).__init__(freq, F0, scaling_factor)
self.position_cache = {}
def get_mesh_3d(self, rope_positions, bsz):
f, h, w = rope_positions
if f"{f}-{h}-{w}" not in self.position_cache:
x = torch.arange(f, device='cpu')
y = torch.arange(h, device='cpu')
z = torch.arange(w, device='cpu')
self.position_cache[f"{f}-{h}-{w}"] = torch.cartesian_prod(x, y, z).view(1, f*h*w, 3).expand(bsz, -1, 3)
return self.position_cache[f"{f}-{h}-{w}"]
def __call__(self, tokens, rope_positions, ch_split, parallel=False):
"""
input:
* tokens: batch_size x ntokens x nheads x dim
* rope_positions: list of (f, h, w)
output:
* tokens after appplying RoPE2D (batch_size x ntokens x nheads x dim)
"""
assert sum(ch_split) == tokens.size(-1);
mesh_grid = self.get_mesh_3d(rope_positions, bsz=tokens.shape[0])
out = []
for i, (D, x) in enumerate(zip(ch_split, torch.split(tokens, ch_split, dim=-1))):
cos, sin = self.get_cos_sin(D, int(mesh_grid.max()) + 1, tokens.device, tokens.dtype)
if parallel:
pass
else:
mesh = mesh_grid[:, :, i].clone()
x = self.apply_rope1d(x, mesh.to(tokens.device), cos, sin)
out.append(x)
tokens = torch.cat(out, dim=-1)
return tokens
class SelfAttention(Attention):
def __init__(self, hidden_dim, head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type='torch'):
super().__init__()
self.head_dim = head_dim
self.n_heads = hidden_dim // head_dim
self.wqkv = nn.Linear(hidden_dim, hidden_dim*3, bias=bias)
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)
self.with_rope = with_rope
self.with_qk_norm = with_qk_norm
if self.with_qk_norm:
self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
self.k_norm = RMSNorm(head_dim, elementwise_affine=True)
if self.with_rope:
self.rope_3d = RoPE3D(freq=1e4, F0=1.0, scaling_factor=1.0)
self.rope_ch_split = [64, 32, 32]
self.core_attention = self.attn_processor(attn_type=attn_type)
self.parallel = attn_type=='parallel'
def apply_rope3d(self, x, fhw_positions, rope_ch_split, parallel=True):
x = self.rope_3d(x, fhw_positions, rope_ch_split, parallel)
return x
def forward(
self,
x,
cu_seqlens=None,
max_seqlen=None,
rope_positions=None,
attn_mask=None
):
xqkv = self.wqkv(x)
xqkv = xqkv.view(*x.shape[:-1], self.n_heads, 3*self.head_dim)
xq, xk, xv = torch.split(xqkv, [self.head_dim]*3, dim=-1) ## seq_len, n, dim
if self.with_qk_norm:
xq = self.q_norm(xq)
xk = self.k_norm(xk)
if self.with_rope:
xq = self.apply_rope3d(xq, rope_positions, self.rope_ch_split, parallel=self.parallel)
xk = self.apply_rope3d(xk, rope_positions, self.rope_ch_split, parallel=self.parallel)
output = self.core_attention(
xq,
xk,
xv,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
attn_mask=attn_mask
)
output = rearrange(output, 'b s h d -> b s (h d)')
output = self.wo(output)
return output
class CrossAttention(Attention):
def __init__(self, hidden_dim, head_dim, bias=False, with_qk_norm=True, attn_type='torch'):
super().__init__()
self.head_dim = head_dim
self.n_heads = hidden_dim // head_dim
self.wq = nn.Linear(hidden_dim, hidden_dim, bias=bias)
self.wkv = nn.Linear(hidden_dim, hidden_dim*2, bias=bias)
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)
self.with_qk_norm = with_qk_norm
if self.with_qk_norm:
self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
self.k_norm = RMSNorm(head_dim, elementwise_affine=True)
self.core_attention = self.attn_processor(attn_type=attn_type)
def forward(
self,
x: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attn_mask=None
):
xq = self.wq(x)
xq = xq.view(*xq.shape[:-1], self.n_heads, self.head_dim)
xkv = self.wkv(encoder_hidden_states)
xkv = xkv.view(*xkv.shape[:-1], self.n_heads, 2*self.head_dim)
xk, xv = torch.split(xkv, [self.head_dim]*2, dim=-1) ## seq_len, n, dim
if self.with_qk_norm:
xq = self.q_norm(xq)
xk = self.k_norm(xk)
output = self.core_attention(
xq,
xk,
xv,
attn_mask=attn_mask
)
output = rearrange(output, 'b s h d -> b s (h d)')
output = self.wo(output)
return output
class GELU(nn.Module):
r"""
GELU activation function with tanh approximation support with `approximate="tanh"`.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
self.approximate = approximate
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.gelu(gate, approximate=self.approximate)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = self.gelu(hidden_states)
return hidden_states
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
inner_dim: Optional[int] = None,
dim_out: Optional[int] = None,
mult: int = 4,
bias: bool = False,
):
super().__init__()
inner_dim = dim*mult if inner_dim is None else inner_dim
dim_out = dim if dim_out is None else dim_out
self.net = nn.ModuleList([
GELU(dim, inner_dim, approximate="tanh", bias=bias),
nn.Identity(),
nn.Linear(inner_dim, dim_out, bias=bias)
])
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
def modulate(x, scale, shift):
x = x * (1 + scale) + shift
return x
def gate(x, gate):
x = gate * x
return x
class StepVideoTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
attention_head_dim: int,
norm_eps: float = 1e-5,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = False,
attention_type: str = 'parallel'
):
super().__init__()
self.dim = dim
self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
self.attn1 = SelfAttention(dim, attention_head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type=attention_type)
self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
self.attn2 = CrossAttention(dim, attention_head_dim, bias=False, with_qk_norm=True, attn_type='torch')
self.ff = FeedForward(dim=dim, inner_dim=ff_inner_dim, dim_out=dim, bias=ff_bias)
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) /dim**0.5)
@torch.no_grad()
def forward(
self,
q: torch.Tensor,
kv: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
attn_mask = None,
rope_positions: list = None,
) -> torch.Tensor:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
torch.clone(chunk) for chunk in (self.scale_shift_table[None].to(dtype=q.dtype, device=q.device) + timestep.reshape(-1, 6, self.dim)).chunk(6, dim=1)
)
scale_shift_q = modulate(self.norm1(q), scale_msa, shift_msa)
attn_q = self.attn1(
scale_shift_q,
rope_positions=rope_positions
)
q = gate(attn_q, gate_msa) + q
attn_q = self.attn2(
q,
kv,
attn_mask
)
q = attn_q + q
scale_shift_q = modulate(self.norm2(q), scale_mlp, shift_mlp)
ff_output = self.ff(scale_shift_q)
q = gate(ff_output, gate_mlp) + q
return q
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding"""
def __init__(
self,
patch_size=64,
in_channels=3,
embed_dim=768,
layer_norm=False,
flatten=True,
bias=True,
):
super().__init__()
self.flatten = flatten
self.layer_norm = layer_norm
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
def forward(self, latent):
latent = self.proj(latent).to(latent.dtype)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
return latent
class StepVideoModel(torch.nn.Module):
def __init__(
self,
num_attention_heads: int = 48,
attention_head_dim: int = 128,
in_channels: int = 64,
out_channels: Optional[int] = 64,
num_layers: int = 48,
dropout: float = 0.0,
patch_size: int = 1,
norm_type: str = "ada_norm_single",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
use_additional_conditions: Optional[bool] = False,
caption_channels: Optional[int]|list|tuple = [6144, 1024],
attention_type: Optional[str] = "torch",
):
super().__init__()
# Set some common variables used across the board.
self.inner_dim = num_attention_heads * attention_head_dim
self.out_channels = in_channels if out_channels is None else out_channels
self.use_additional_conditions = use_additional_conditions
self.pos_embed = PatchEmbed(
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
)
self.transformer_blocks = nn.ModuleList(
[
StepVideoTransformerBlock(
dim=self.inner_dim,
attention_head_dim=attention_head_dim,
attention_type=attention_type
)
for _ in range(num_layers)
]
)
# 3. Output blocks.
self.norm_out = nn.LayerNorm(self.inner_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels)
self.patch_size = patch_size
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=self.use_additional_conditions
)
if isinstance(caption_channels, int):
caption_channel = caption_channels
else:
caption_channel, clip_channel = caption_channels
self.clip_projection = nn.Linear(clip_channel, self.inner_dim)
self.caption_norm = nn.LayerNorm(caption_channel, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.caption_projection = PixArtAlphaTextProjection(
in_features=caption_channel, hidden_size=self.inner_dim
)
self.parallel = attention_type=='parallel'
def patchfy(self, hidden_states):
hidden_states = rearrange(hidden_states, 'b f c h w -> (b f) c h w')
hidden_states = self.pos_embed(hidden_states)
return hidden_states
def prepare_attn_mask(self, encoder_attention_mask, encoder_hidden_states, q_seqlen):
kv_seqlens = encoder_attention_mask.sum(dim=1).int()
mask = torch.zeros([len(kv_seqlens), q_seqlen, max(kv_seqlens)], dtype=torch.bool, device=encoder_attention_mask.device)
encoder_hidden_states = encoder_hidden_states[:,: max(kv_seqlens)]
for i, kv_len in enumerate(kv_seqlens):
mask[i, :, :kv_len] = 1
return encoder_hidden_states, mask
def block_forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
rope_positions=None,
attn_mask=None,
parallel=True
):
for block in tqdm(self.transformer_blocks, desc="Transformer blocks"):
hidden_states = block(
hidden_states,
encoder_hidden_states,
timestep=timestep,
attn_mask=attn_mask,
rope_positions=rope_positions
)
return hidden_states
@torch.inference_mode()
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_hidden_states_2: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
fps: torch.Tensor=None,
return_dict: bool = False,
):
assert hidden_states.ndim==5; "hidden_states's shape should be (bsz, f, ch, h ,w)"
bsz, frame, _, height, width = hidden_states.shape
height, width = height // self.patch_size, width // self.patch_size
hidden_states = self.patchfy(hidden_states)
len_frame = hidden_states.shape[1]
if self.use_additional_conditions:
added_cond_kwargs = {
"resolution": torch.tensor([(height, width)]*bsz, device=hidden_states.device, dtype=hidden_states.dtype),
"nframe": torch.tensor([frame]*bsz, device=hidden_states.device, dtype=hidden_states.dtype),
"fps": fps
}
else:
added_cond_kwargs = {}
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs=added_cond_kwargs
)
encoder_hidden_states = self.caption_projection(self.caption_norm(encoder_hidden_states))
if encoder_hidden_states_2 is not None and hasattr(self, 'clip_projection'):
clip_embedding = self.clip_projection(encoder_hidden_states_2)
encoder_hidden_states = torch.cat([clip_embedding, encoder_hidden_states], dim=1)
hidden_states = rearrange(hidden_states, '(b f) l d-> b (f l) d', b=bsz, f=frame, l=len_frame).contiguous()
encoder_hidden_states, attn_mask = self.prepare_attn_mask(encoder_attention_mask, encoder_hidden_states, q_seqlen=frame*len_frame)
hidden_states = self.block_forward(
hidden_states,
encoder_hidden_states,
timestep=timestep,
rope_positions=[frame, height, width],
attn_mask=attn_mask,
parallel=self.parallel
)
hidden_states = rearrange(hidden_states, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame)
embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame).contiguous()
shift, scale = (self.scale_shift_table[None].to(dtype=embedded_timestep.dtype, device=embedded_timestep.device) + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
# unpatchify
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = rearrange(hidden_states, 'n h w p q c -> n c h p w q')
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
output = rearrange(output, '(b f) c h w -> b f c h w', f=frame)
if return_dict:
return {'x': output}
return output
@staticmethod
def state_dict_converter():
return StepVideoDiTStateDictConverter()
class StepVideoDiTStateDictConverter:
def __init__(self):
super().__init__()
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict

View File

@@ -0,0 +1,553 @@
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import os
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .stepvideo_dit import RMSNorm
from safetensors.torch import load_file
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from einops import rearrange
import json
from typing import List
from functools import wraps
import warnings
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
def __init__(self, device=None):
self.device = device
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if getattr(func, '__module__', None) == 'torch.nn.init':
if 'tensor' in kwargs:
return kwargs['tensor']
else:
return args[0]
if self.device is not None and func in torch.utils._device._device_constructors() and kwargs.get('device') is None:
kwargs['device'] = self.device
return func(*args, **kwargs)
def with_empty_init(func):
@wraps(func)
def wrapper(*args, **kwargs):
with EmptyInitOnDevice('cpu'):
return func(*args, **kwargs)
return wrapper
class LLaMaEmbedding(nn.Module):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
cfg,
):
super().__init__()
self.hidden_size = cfg.hidden_size
self.params_dtype = cfg.params_dtype
self.fp32_residual_connection = cfg.fp32_residual_connection
self.embedding_weights_in_fp32 = cfg.embedding_weights_in_fp32
self.word_embeddings = torch.nn.Embedding(
cfg.padded_vocab_size, self.hidden_size,
)
self.embedding_dropout = torch.nn.Dropout(cfg.hidden_dropout)
def forward(self, input_ids):
# Embeddings.
if self.embedding_weights_in_fp32:
self.word_embeddings = self.word_embeddings.to(torch.float32)
embeddings = self.word_embeddings(input_ids)
if self.embedding_weights_in_fp32:
embeddings = embeddings.to(self.params_dtype)
self.word_embeddings = self.word_embeddings.to(self.params_dtype)
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
# If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection:
embeddings = embeddings.float()
# Dropout.
embeddings = self.embedding_dropout(embeddings)
return embeddings
class StepChatTokenizer:
"""Step Chat Tokenizer"""
def __init__(
self, model_file, name="StepChatTokenizer",
bot_token="<|BOT|>", # Begin of Turn
eot_token="<|EOT|>", # End of Turn
call_start_token="<|CALL_START|>", # Call Start
call_end_token="<|CALL_END|>", # Call End
think_start_token="<|THINK_START|>", # Think Start
think_end_token="<|THINK_END|>", # Think End
mask_start_token="<|MASK_1e69f|>", # Mask start
mask_end_token="<|UNMASK_1e69f|>", # Mask end
):
import sentencepiece
self._tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
self._vocab = {}
self._inv_vocab = {}
self._special_tokens = {}
self._inv_special_tokens = {}
self._t5_tokens = []
for idx in range(self._tokenizer.get_piece_size()):
text = self._tokenizer.id_to_piece(idx)
self._inv_vocab[idx] = text
self._vocab[text] = idx
if self._tokenizer.is_control(idx) or self._tokenizer.is_unknown(idx):
self._special_tokens[text] = idx
self._inv_special_tokens[idx] = text
self._unk_id = self._tokenizer.unk_id()
self._bos_id = self._tokenizer.bos_id()
self._eos_id = self._tokenizer.eos_id()
for token in [
bot_token, eot_token, call_start_token, call_end_token,
think_start_token, think_end_token
]:
assert token in self._vocab, f"Token '{token}' not found in tokenizer"
assert token in self._special_tokens, f"Token '{token}' is not a special token"
for token in [mask_start_token, mask_end_token]:
assert token in self._vocab, f"Token '{token}' not found in tokenizer"
self._bot_id = self._tokenizer.piece_to_id(bot_token)
self._eot_id = self._tokenizer.piece_to_id(eot_token)
self._call_start_id = self._tokenizer.piece_to_id(call_start_token)
self._call_end_id = self._tokenizer.piece_to_id(call_end_token)
self._think_start_id = self._tokenizer.piece_to_id(think_start_token)
self._think_end_id = self._tokenizer.piece_to_id(think_end_token)
self._mask_start_id = self._tokenizer.piece_to_id(mask_start_token)
self._mask_end_id = self._tokenizer.piece_to_id(mask_end_token)
self._underline_id = self._tokenizer.piece_to_id("\u2581")
@property
def vocab(self):
return self._vocab
@property
def inv_vocab(self):
return self._inv_vocab
@property
def vocab_size(self):
return self._tokenizer.vocab_size()
def tokenize(self, text: str) -> List[int]:
return self._tokenizer.encode_as_ids(text)
def detokenize(self, token_ids: List[int]) -> str:
return self._tokenizer.decode_ids(token_ids)
class Tokens:
def __init__(self, input_ids, cu_input_ids, attention_mask, cu_seqlens, max_seq_len) -> None:
self.input_ids = input_ids
self.attention_mask = attention_mask
self.cu_input_ids = cu_input_ids
self.cu_seqlens = cu_seqlens
self.max_seq_len = max_seq_len
def to(self, device):
self.input_ids = self.input_ids.to(device)
self.attention_mask = self.attention_mask.to(device)
self.cu_input_ids = self.cu_input_ids.to(device)
self.cu_seqlens = self.cu_seqlens.to(device)
return self
class Wrapped_StepChatTokenizer(StepChatTokenizer):
def __call__(self, text, max_length=320, padding="max_length", truncation=True, return_tensors="pt"):
# [bos, ..., eos, pad, pad, ..., pad]
self.BOS = 1
self.EOS = 2
self.PAD = 2
out_tokens = []
attn_mask = []
if len(text) == 0:
part_tokens = [self.BOS] + [self.EOS]
valid_size = len(part_tokens)
if len(part_tokens) < max_length:
part_tokens += [self.PAD] * (max_length - valid_size)
out_tokens.append(part_tokens)
attn_mask.append([1]*valid_size+[0]*(max_length-valid_size))
else:
for part in text:
part_tokens = self.tokenize(part)
part_tokens = part_tokens[:(max_length - 2)] # leave 2 space for bos and eos
part_tokens = [self.BOS] + part_tokens + [self.EOS]
valid_size = len(part_tokens)
if len(part_tokens) < max_length:
part_tokens += [self.PAD] * (max_length - valid_size)
out_tokens.append(part_tokens)
attn_mask.append([1]*valid_size+[0]*(max_length-valid_size))
out_tokens = torch.tensor(out_tokens, dtype=torch.long)
attn_mask = torch.tensor(attn_mask, dtype=torch.long)
# padding y based on tp size
padded_len = 0
padded_flag = True if padded_len > 0 else False
if padded_flag:
pad_tokens = torch.tensor([[self.PAD] * max_length], device=out_tokens.device)
pad_attn_mask = torch.tensor([[1]*padded_len+[0]*(max_length-padded_len)], device=attn_mask.device)
out_tokens = torch.cat([out_tokens, pad_tokens], dim=0)
attn_mask = torch.cat([attn_mask, pad_attn_mask], dim=0)
# cu_seqlens
cu_out_tokens = out_tokens.masked_select(attn_mask != 0).unsqueeze(0)
seqlen = attn_mask.sum(dim=1).tolist()
cu_seqlens = torch.cumsum(torch.tensor([0]+seqlen), 0).to(device=out_tokens.device,dtype=torch.int32)
max_seq_len = max(seqlen)
return Tokens(out_tokens, cu_out_tokens, attn_mask, cu_seqlens, max_seq_len)
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True,
return_attn_probs=False, tp_group_rank=0, tp_group_size=1):
softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
if hasattr(torch.ops.Optimus, "fwd"):
results = torch.ops.Optimus.fwd(q, k, v, None, dropout_p, softmax_scale, causal, return_attn_probs, None, tp_group_rank, tp_group_size)[0]
else:
warnings.warn("Cannot load `torch.ops.Optimus.fwd`. Using `torch.nn.functional.scaled_dot_product_attention` instead.")
results = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True, scale=softmax_scale).transpose(1, 2)
return results
class FlashSelfAttention(torch.nn.Module):
def __init__(
self,
attention_dropout=0.0,
):
super().__init__()
self.dropout_p = attention_dropout
def forward(self, q, k, v, cu_seqlens=None, max_seq_len=None):
if cu_seqlens is None:
output = flash_attn_func(q, k, v, dropout_p=self.dropout_p)
else:
raise ValueError('cu_seqlens is not supported!')
return output
def safediv(n, d):
q, r = divmod(n, d)
assert r == 0
return q
class MultiQueryAttention(nn.Module):
def __init__(self, cfg, layer_id=None):
super().__init__()
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
self.max_seq_len = cfg.seq_length
self.use_flash_attention = cfg.use_flash_attn
assert self.use_flash_attention, 'FlashAttention is required!'
self.n_groups = cfg.num_attention_groups
self.tp_size = 1
self.n_local_heads = cfg.num_attention_heads
self.n_local_groups = self.n_groups
self.wqkv = nn.Linear(
cfg.hidden_size,
cfg.hidden_size + self.head_dim * 2 * self.n_groups,
bias=False,
)
self.wo = nn.Linear(
cfg.hidden_size,
cfg.hidden_size,
bias=False,
)
assert self.use_flash_attention, 'non-Flash attention not supported yet.'
self.core_attention = FlashSelfAttention(attention_dropout=cfg.attention_dropout)
self.layer_id = layer_id
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor],
cu_seqlens: Optional[torch.Tensor],
max_seq_len: Optional[torch.Tensor],
):
seqlen, bsz, dim = x.shape
xqkv = self.wqkv(x)
xq, xkv = torch.split(
xqkv,
(dim // self.tp_size,
self.head_dim*2*self.n_groups // self.tp_size
),
dim=-1,
)
# gather on 1st dimention
xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim)
xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim)
xk, xv = xkv.chunk(2, -1)
# rotary embedding + flash attn
xq = rearrange(xq, "s b h d -> b s h d")
xk = rearrange(xk, "s b h d -> b s h d")
xv = rearrange(xv, "s b h d -> b s h d")
q_per_kv = self.n_local_heads // self.n_local_groups
if q_per_kv > 1:
b, s, h, d = xk.size()
if h == 1:
xk = xk.expand(b, s, q_per_kv, d)
xv = xv.expand(b, s, q_per_kv, d)
else:
''' To cover the cases where h > 1, we have
the following implementation, which is equivalent to:
xk = xk.repeat_interleave(q_per_kv, dim=-2)
xv = xv.repeat_interleave(q_per_kv, dim=-2)
but can avoid calling aten::item() that involves cpu.
'''
idx = torch.arange(q_per_kv * h, device=xk.device).reshape(q_per_kv, -1).permute(1, 0).flatten()
xk = torch.index_select(xk.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
xv = torch.index_select(xv.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
if self.use_flash_attention:
output = self.core_attention(xq, xk, xv,
cu_seqlens=cu_seqlens,
max_seq_len=max_seq_len)
# reduce-scatter only support first dimention now
output = rearrange(output, "b s h d -> s b (h d)").contiguous()
else:
xq, xk, xv = [
rearrange(x, "b s ... -> s b ...").contiguous()
for x in (xq, xk, xv)
]
output = self.core_attention(xq, xk, xv, mask)
output = self.wo(output)
return output
class FeedForward(nn.Module):
def __init__(
self,
cfg,
dim: int,
hidden_dim: int,
layer_id: int,
multiple_of: int=256,
):
super().__init__()
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
def swiglu(x):
x = torch.chunk(x, 2, dim=-1)
return F.silu(x[0]) * x[1]
self.swiglu = swiglu
self.w1 = nn.Linear(
dim,
2 * hidden_dim,
bias=False,
)
self.w2 = nn.Linear(
hidden_dim,
dim,
bias=False,
)
def forward(self, x):
x = self.swiglu(self.w1(x))
output = self.w2(x)
return output
class TransformerBlock(nn.Module):
def __init__(
self, cfg, layer_id: int
):
super().__init__()
self.n_heads = cfg.num_attention_heads
self.dim = cfg.hidden_size
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
self.attention = MultiQueryAttention(
cfg,
layer_id=layer_id,
)
self.feed_forward = FeedForward(
cfg,
dim=cfg.hidden_size,
hidden_dim=cfg.ffn_hidden_size,
layer_id=layer_id,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(
cfg.hidden_size,
eps=cfg.layernorm_epsilon,
)
self.ffn_norm = RMSNorm(
cfg.hidden_size,
eps=cfg.layernorm_epsilon,
)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor],
cu_seqlens: Optional[torch.Tensor],
max_seq_len: Optional[torch.Tensor],
):
residual = self.attention.forward(
self.attention_norm(x), mask,
cu_seqlens, max_seq_len
)
h = x + residual
ffn_res = self.feed_forward.forward(self.ffn_norm(h))
out = h + ffn_res
return out
class Transformer(nn.Module):
def __init__(
self,
config,
max_seq_size=8192,
):
super().__init__()
self.num_layers = config.num_layers
self.layers = self._build_layers(config)
def _build_layers(self, config):
layers = torch.nn.ModuleList()
for layer_id in range(self.num_layers):
layers.append(
TransformerBlock(
config,
layer_id=layer_id + 1 ,
)
)
return layers
def forward(
self,
hidden_states,
attention_mask,
cu_seqlens=None,
max_seq_len=None,
):
if max_seq_len is not None and not isinstance(max_seq_len, torch.Tensor):
max_seq_len = torch.tensor(max_seq_len, dtype=torch.int32, device="cpu")
for lid, layer in enumerate(self.layers):
hidden_states = layer(
hidden_states,
attention_mask,
cu_seqlens,
max_seq_len,
)
return hidden_states
class Step1Model(PreTrainedModel):
config_class=PretrainedConfig
@with_empty_init
def __init__(
self,
config,
):
super().__init__(config)
self.tok_embeddings = LLaMaEmbedding(config)
self.transformer = Transformer(config)
def forward(
self,
input_ids=None,
attention_mask=None,
):
hidden_states = self.tok_embeddings(input_ids)
hidden_states = self.transformer(
hidden_states,
attention_mask,
)
return hidden_states
class STEP1TextEncoder(torch.nn.Module):
def __init__(self, model_dir, max_length=320):
super(STEP1TextEncoder, self).__init__()
self.max_length = max_length
self.text_tokenizer = Wrapped_StepChatTokenizer(os.path.join(model_dir, 'step1_chat_tokenizer.model'))
text_encoder = Step1Model.from_pretrained(model_dir)
self.text_encoder = text_encoder.eval().to(torch.bfloat16)
@staticmethod
def from_pretrained(path, torch_dtype=torch.bfloat16):
model = STEP1TextEncoder(path).to(torch_dtype)
return model
@torch.no_grad
def forward(self, prompts, with_mask=True, max_length=None, device="cuda"):
self.device = device
with torch.no_grad(), torch.amp.autocast(dtype=torch.bfloat16, device_type=device):
if type(prompts) is str:
prompts = [prompts]
txt_tokens = self.text_tokenizer(
prompts, max_length=max_length or self.max_length, padding="max_length", truncation=True, return_tensors="pt"
)
y = self.text_encoder(
txt_tokens.input_ids.to(self.device),
attention_mask=txt_tokens.attention_mask.to(self.device) if with_mask else None
)
y_mask = txt_tokens.attention_mask
return y.transpose(0,1), y_mask

File diff suppressed because it is too large Load Diff

View File

@@ -44,6 +44,7 @@ def get_timestep_embedding(
downscale_freq_shift: float = 1, downscale_freq_shift: float = 1,
scale: float = 1, scale: float = 1,
max_period: int = 10000, max_period: int = 10000,
computation_device = None,
): ):
""" """
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
@@ -57,11 +58,11 @@ def get_timestep_embedding(
half_dim = embedding_dim // 2 half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange( exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device
) )
exponent = exponent / (half_dim - downscale_freq_shift) exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent) emb = torch.exp(exponent).to(timesteps.device)
emb = timesteps[:, None].float() * emb[None, :] emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings # scale embeddings
@@ -81,11 +82,12 @@ def get_timestep_embedding(
class TemporalTimesteps(torch.nn.Module): class TemporalTimesteps(torch.nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None):
super().__init__() super().__init__()
self.num_channels = num_channels self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift self.downscale_freq_shift = downscale_freq_shift
self.computation_device = computation_device
def forward(self, timesteps): def forward(self, timesteps):
t_emb = get_timestep_embedding( t_emb = get_timestep_embedding(
@@ -93,6 +95,7 @@ class TemporalTimesteps(torch.nn.Module):
self.num_channels, self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos, flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift, downscale_freq_shift=self.downscale_freq_shift,
computation_device=self.computation_device,
) )
return t_emb return t_emb

View File

@@ -104,3 +104,131 @@ class TileWorker:
# Done! # Done!
model_output = model_output.to(device=inference_device, dtype=inference_dtype) model_output = model_output.to(device=inference_device, dtype=inference_dtype)
return model_output return model_output
class FastTileWorker:
def __init__(self):
pass
def build_mask(self, data, is_bound):
_, _, H, W = data.shape
h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
border_width = (H + W) // 4
pad = torch.ones_like(h) * border_width
mask = torch.stack([
pad if is_bound[0] else h + 1,
pad if is_bound[1] else H - h,
pad if is_bound[2] else w + 1,
pad if is_bound[3] else W - w
]).min(dim=0).values
mask = mask.clip(1, border_width)
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
mask = rearrange(mask, "H W -> 1 H W")
return mask
def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
# Prepare
B, C, H, W = model_input.shape
border_width = int(tile_stride*0.5) if border_width is None else border_width
weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device)
values = torch.zeros((B, C, H, W), dtype=tile_dtype, device=tile_device)
# Split tasks
tasks = []
for h in range(0, H, tile_stride):
for w in range(0, W, tile_stride):
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
continue
h_, w_ = h + tile_size, w + tile_size
if h_ > H: h, h_ = H - tile_size, H
if w_ > W: w, w_ = W - tile_size, W
tasks.append((h, h_, w, w_))
# Run
for hl, hr, wl, wr in tasks:
# Forward
hidden_states_batch = forward_fn(hl, hr, wl, wr).to(dtype=tile_dtype, device=tile_device)
mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
weight[:, :, hl:hr, wl:wr] += mask
values /= weight
return values
class TileWorker2Dto3D:
"""
Process 3D tensors, but only enable TileWorker on 2D.
"""
def __init__(self):
pass
def build_mask(self, T, H, W, dtype, device, is_bound, border_width):
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
border_width = (H + W) // 4 if border_width is None else border_width
pad = torch.ones_like(h) * border_width
mask = torch.stack([
pad if is_bound[0] else t + 1,
pad if is_bound[1] else T - t,
pad if is_bound[2] else h + 1,
pad if is_bound[3] else H - h,
pad if is_bound[4] else w + 1,
pad if is_bound[5] else W - w
]).min(dim=0).values
mask = mask.clip(1, border_width)
mask = (mask / border_width).to(dtype=dtype, device=device)
mask = rearrange(mask, "T H W -> 1 1 T H W")
return mask
def tiled_forward(
self,
forward_fn,
model_input,
tile_size, tile_stride,
tile_device="cpu", tile_dtype=torch.float32,
computation_device="cuda", computation_dtype=torch.float32,
border_width=None, scales=[1, 1, 1, 1],
progress_bar=lambda x:x
):
B, C, T, H, W = model_input.shape
scale_C, scale_T, scale_H, scale_W = scales
tile_size_H, tile_size_W = tile_size
tile_stride_H, tile_stride_W = tile_stride
value = torch.zeros((B, int(C*scale_C), int(T*scale_T), int(H*scale_H), int(W*scale_W)), dtype=tile_dtype, device=tile_device)
weight = torch.zeros((1, 1, int(T*scale_T), int(H*scale_H), int(W*scale_W)), dtype=tile_dtype, device=tile_device)
# Split tasks
tasks = []
for h in range(0, H, tile_stride_H):
for w in range(0, W, tile_stride_W):
if (h-tile_stride_H >= 0 and h-tile_stride_H+tile_size_H >= H) or (w-tile_stride_W >= 0 and w-tile_stride_W+tile_size_W >= W):
continue
h_, w_ = h + tile_size_H, w + tile_size_W
if h_ > H: h, h_ = max(H - tile_size_H, 0), H
if w_ > W: w, w_ = max(W - tile_size_W, 0), W
tasks.append((h, h_, w, w_))
# Run
for hl, hr, wl, wr in progress_bar(tasks):
mask = self.build_mask(
int(T*scale_T), int((hr-hl)*scale_H), int((wr-wl)*scale_W),
tile_dtype, tile_device,
is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W),
border_width=border_width
)
grid_input = model_input[:, :, :, hl:hr, wl:wr].to(dtype=computation_dtype, device=computation_device)
grid_output = forward_fn(grid_input).to(dtype=tile_dtype, device=tile_device)
value[:, :, :, int(hl*scale_H):int(hr*scale_H), int(wl*scale_W):int(wr*scale_W)] += grid_output * mask
weight[:, :, :, int(hl*scale_H):int(hr*scale_H), int(wl*scale_W):int(wr*scale_W)] += mask
value = value / weight
return value

182
diffsynth/models/utils.py Normal file
View File

@@ -0,0 +1,182 @@
import torch, os
from safetensors import safe_open
from contextlib import contextmanager
import hashlib
@contextmanager
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
old_register_parameter = torch.nn.Module.register_parameter
if include_buffers:
old_register_buffer = torch.nn.Module.register_buffer
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(device)
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)
return wrapper
if include_buffers:
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ["empty", "zeros", "ones", "full"]
}
else:
tensor_constructors_to_patch = {}
try:
torch.nn.Module.register_parameter = register_empty_parameter
if include_buffers:
torch.nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch.keys():
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
yield
finally:
torch.nn.Module.register_parameter = old_register_parameter
if include_buffers:
torch.nn.Module.register_buffer = old_register_buffer
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)
def load_state_dict_from_folder(file_path, torch_dtype=None):
state_dict = {}
for file_name in os.listdir(file_path):
if "." in file_name and file_name.split(".")[-1] in [
"safetensors", "bin", "ckpt", "pth", "pt"
]:
state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
return state_dict
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def search_for_embeddings(state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += search_for_embeddings(state_dict[k])
return embeddings
def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-3:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
return name
return None
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
matched_keys = set()
with torch.no_grad():
for name in source_state_dict:
rename = search_parameter(source_state_dict[name], target_state_dict)
if rename is not None:
print(f'"{name}": "{rename}",')
matched_keys.add(rename)
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
length = source_state_dict[name].shape[0] // 3
rename = []
for i in range(3):
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
if None not in rename:
print(f'"{name}": {rename},')
for rename_ in rename:
matched_keys.add(rename_)
for name in target_state_dict:
if name not in matched_keys:
print("Cannot find", name, target_state_dict[name].shape)
def search_for_files(folder, extensions):
files = []
if os.path.isdir(folder):
for file in sorted(os.listdir(folder)):
files += search_for_files(os.path.join(folder, file), extensions)
elif os.path.isfile(folder):
for extension in extensions:
if folder.endswith(extension):
files.append(folder)
break
return files
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, torch.Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()

View File

@@ -0,0 +1,254 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def fp16_clamp(x):
if x.dtype == torch.float16 and torch.isinf(x).any():
clamp = torch.finfo(x.dtype).max - 1000
x = torch.clamp(x, min=-clamp, max=clamp)
return x
class GELU(nn.Module):
def forward(self, x):
return 0.5 * x * (1.0 + torch.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super(T5LayerNorm, self).__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
self.eps)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
x = x.type_as(self.weight)
return self.weight * x
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
assert dim_attn % num_heads == 0
super(T5Attention, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads
# layers
self.q = nn.Linear(dim, dim_attn, bias=False)
self.k = nn.Linear(dim, dim_attn, bias=False)
self.v = nn.Linear(dim, dim_attn, bias=False)
self.o = nn.Linear(dim_attn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, context=None, mask=None, pos_bias=None):
"""
x: [B, L1, C].
context: [B, L2, C] or None.
mask: [B, L2] or [B, L1, L2] or None.
"""
# check inputs
context = x if context is None else context
b, n, c = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).view(b, -1, n, c)
k = self.k(context).view(b, -1, n, c)
v = self.v(context).view(b, -1, n, c)
# attention bias
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
if pos_bias is not None:
attn_bias += pos_bias
if mask is not None:
assert mask.ndim in [2, 3]
mask = mask.view(b, 1, 1,
-1) if mask.ndim == 2 else mask.unsqueeze(1)
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
# compute attention (T5 does not use scaling)
attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
x = torch.einsum('bnij,bjnc->binc', attn, v)
# output
x = x.reshape(b, -1, n * c)
x = self.o(x)
x = self.dropout(x)
return x
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.1):
super(T5FeedForward, self).__init__()
self.dim = dim
self.dim_ffn = dim_ffn
# layers
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x) * self.gate(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class T5SelfAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.1):
super(T5SelfAttention, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True)
def forward(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.size(1), x.size(1))
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
return x
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
super(T5RelativeEmbedding, self).__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
self.max_dist = max_dist
# layers
self.embedding = nn.Embedding(num_buckets, num_heads)
def forward(self, lq, lk):
device = self.embedding.weight.device
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
# torch.arange(lq).unsqueeze(1).to(device)
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
torch.arange(lq, device=device).unsqueeze(1)
rel_pos = self._relative_position_bucket(rel_pos)
rel_pos_embeds = self.embedding(rel_pos)
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
0) # [1, N, Lq, Lk]
return rel_pos_embeds.contiguous()
def _relative_position_bucket(self, rel_pos):
# preprocess
if self.bidirectional:
num_buckets = self.num_buckets // 2
rel_buckets = (rel_pos > 0).long() * num_buckets
rel_pos = torch.abs(rel_pos)
else:
num_buckets = self.num_buckets
rel_buckets = 0
rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
# embeddings for small and large positions
max_exact = num_buckets // 2
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
math.log(self.max_dist / max_exact) *
(num_buckets - max_exact)).long()
rel_pos_large = torch.min(
rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
return rel_buckets
def init_weights(m):
if isinstance(m, T5LayerNorm):
nn.init.ones_(m.weight)
elif isinstance(m, T5FeedForward):
nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
elif isinstance(m, T5Attention):
nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
nn.init.normal_(m.k.weight, std=m.dim**-0.5)
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
elif isinstance(m, T5RelativeEmbedding):
nn.init.normal_(
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
class WanXTextEncoder(torch.nn.Module):
def __init__(self,
vocab=256384,
dim=4096,
dim_attn=4096,
dim_ffn=10240,
num_heads=64,
num_layers=24,
num_buckets=32,
shared_pos=False,
dropout=0.1):
super(WanXTextEncoder, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
])
self.norm = T5LayerNorm(dim)
# initialize weights
self.apply(init_weights)
def forward(self, ids, mask=None):
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.size(1),
x.size(1)) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x

View File

@@ -0,0 +1,794 @@
from einops import rearrange, repeat
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
CACHE_T = 2
def block_causal_mask(x, block_size):
# params
b, n, s, _, device = *x.size(), x.device
assert s % block_size == 0
num_blocks = s // block_size
# build mask
mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
for i in range(num_blocks):
mask[:, :,
i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
return mask
class CausalConv3d(nn.Conv3d):
"""
Causal 3d convolusion.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(dim,
dim * 2, (3, 1, 1),
padding=(1, 0, 0))
elif mode == 'downsample2d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == 'downsample3d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(dim,
dim, (3, 1, 1),
stride=(2, 1, 1),
padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
conv_weight.data[:, :, 1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False), nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1))
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
class AttentionBlock(nn.Module):
"""
Causal self-attention with a single head.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
nn.init.zeros_(self.proj.weight)
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.norm(x)
# compute query, key, value
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
0, 1, 3, 2).contiguous().chunk(3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(
q,
k,
v,
#attn_mask=block_causal_mask(q, block_size=h * w)
)
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
# output
x = self.proj(x)
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
return x + identity
class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[
i] else 'downsample2d'
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout))
# output blocks
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
count += 1
return count
class VideoVAE_(nn.Module):
def __init__(self,
dim=96,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z)
return x_recon, mu, log_var
def encode(self, x, scale):
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
else:
scale = scale.to(dtype=mu.dtype, device=mu.device)
mu = (mu - scale[0]) * scale[1]
return mu
def decode(self, z, scale):
self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
scale = scale.to(dtype=z.dtype, device=z.device)
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2) # may add tensor offload
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
class WanXVideoVAE(nn.Module):
def __init__(self, z_dim=16):
super().__init__()
mean = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
std = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
self.mean = torch.tensor(mean)
self.std = torch.tensor(std)
self.scale = [self.mean, 1.0 / self.std]
# init model
self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
self.upsampling_factor = 8
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
if not right_bound:
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
return x
def build_mask(self, data, is_bound, border_width):
_, _, _, H, W = data.shape
h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
h = repeat(h, "H -> H W", H=H, W=W)
w = repeat(w, "W -> H W", H=H, W=W)
mask = torch.stack([h, w]).min(dim=0).values
mask = rearrange(mask, "H W -> 1 1 1 H W")
return mask
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
_, _, T, H, W = hidden_states.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))
data_device = "cpu"
computation_device = device
out_T = T * 4 - 3
weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
mask = self.build_mask(
hidden_states_batch,
is_bound=(h==0, h_>=H, w==0, w_>=W),
border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
).to(dtype=hidden_states.dtype, device=data_device)
target_h = h * self.upsampling_factor
target_w = w * self.upsampling_factor
values[
:,
:,
:,
target_h:target_h + hidden_states_batch.shape[3],
target_w:target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
:,
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
values = values.float().clamp_(-1, 1)
return values
def tiled_encode(self, video, device, tile_size, tile_stride):
_, _, T, H, W = video.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))
data_device = "cpu"
computation_device = device
out_T = (T + 3) // 4
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
mask = self.build_mask(
hidden_states_batch,
is_bound=(h==0, h_>=H, w==0, w_>=W),
border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
).to(dtype=video.dtype, device=data_device)
target_h = h // self.upsampling_factor
target_w = w // self.upsampling_factor
values[
:,
:,
:,
target_h:target_h + hidden_states_batch.shape[3],
target_w:target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
:,
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
values = values.float()
return values
def single_encode(self, video, device):
video = video.to(device)
x = self.model.encode(video, self.scale)
return x.float()
def single_decode(self, hidden_state, device):
hidden_state = hidden_state.to(device)
video = self.model.decode(hidden_state, self.scale)
return video.float().clamp_(-1, 1)
def encode(self, videos, device, tiled=False, tile_size=(272, 272), tile_stride=(144, 128)):
videos = [video.to("cpu") for video in videos]
hidden_states = []
for video in videos:
video = video.unsqueeze(0)
if tiled:
assert tile_size[0] % self.upsampling_factor == 0 and tile_size[1] % self.upsampling_factor == 0, f"tile_size must be devisible by {self.upsampling_factor}"
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
else:
hidden_state = self.single_encode(video, device)
hidden_state = hidden_state.squeeze(0)
hidden_states.append(hidden_state)
return hidden_states
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
videos = []
for hidden_state in hidden_states:
hidden_state = hidden_state.unsqueeze(0)
if tiled:
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
else:
video = self.single_decode(hidden_state, device)
video = video.squeeze(0)
videos.append(video)
return videos
@staticmethod
def state_dict_converter():
return WanXVideoVAEStateDictConverter()
class WanXVideoVAEStateDictConverter:
def __init__(self):
pass
def from_civitai(self, state_dict):
state_dict_ = {}
for name in state_dict['model_state']:
state_dict_['model.' + name] = state_dict['model_state'][name]
return state_dict_

View File

@@ -5,5 +5,10 @@ from .sdxl_video import SDXLVideoPipeline
from .sd3_image import SD3ImagePipeline from .sd3_image import SD3ImagePipeline
from .hunyuan_image import HunyuanDiTImagePipeline from .hunyuan_image import HunyuanDiTImagePipeline
from .svd_video import SVDVideoPipeline from .svd_video import SVDVideoPipeline
from .flux_image import FluxImagePipeline
from .cog_video import CogVideoPipeline
from .omnigen_image import OmnigenImagePipeline
from .pipeline_runner import SDVideoPipelineRunner from .pipeline_runner import SDVideoPipelineRunner
from .hunyuan_video import HunyuanVideoPipeline
from .step_video import StepVideoPipeline
KolorsImagePipeline = SDXLImagePipeline KolorsImagePipeline = SDXLImagePipeline

View File

@@ -1,15 +1,30 @@
import torch import torch
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from torchvision.transforms import GaussianBlur
class BasePipeline(torch.nn.Module): class BasePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16): def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
super().__init__() super().__init__()
self.device = device self.device = device
self.torch_dtype = torch_dtype self.torch_dtype = torch_dtype
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.cpu_offload = False
self.model_names = []
def check_resize_height_width(self, height, width):
if height % self.height_division_factor != 0:
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
if width % self.width_division_factor != 0:
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
return height, width
def preprocess_image(self, image): def preprocess_image(self, image):
@@ -22,7 +37,7 @@ class BasePipeline(torch.nn.Module):
def vae_output_to_image(self, vae_output): def vae_output_to_image(self, vae_output):
image = vae_output[0].cpu().permute(1, 2, 0).numpy() image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image return image
@@ -32,3 +47,81 @@ class BasePipeline(torch.nn.Module):
video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video] video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
return video return video
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
if len(latents) > 0:
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
height, width = value.shape[-2:]
weight = torch.ones_like(value)
for latent, mask, scale in zip(latents, masks, scales):
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
mask = blur(mask)
value += latent * mask * scale
weight += mask * scale
value /= weight
return value
def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
if special_kwargs is None:
noise_pred_global = inference_callback(prompt_emb_global)
else:
noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
if special_local_kwargs_list is None:
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
else:
noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
return noise_pred
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
local_prompts = local_prompts or []
masks = masks or []
mask_scales = mask_scales or []
extended_prompt_dict = self.prompter.extend_prompt(prompt)
prompt = extended_prompt_dict.get("prompt", prompt)
local_prompts += extended_prompt_dict.get("prompts", [])
masks += extended_prompt_dict.get("masks", [])
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
return prompt, local_prompts, masks, mask_scales
def enable_cpu_offload(self):
self.cpu_offload = True
def load_models_to_device(self, loadmodel_names=[]):
# only load models to device if cpu_offload is enabled
if not self.cpu_offload:
return
# offload the unneeded models to cpu
for model_name in self.model_names:
if model_name not in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
else:
model.cpu()
# load the needed models to device
for model_name in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
for module in model.modules():
if hasattr(module, "onload"):
module.onload()
else:
model.to(self.device)
# fresh the cuda cache
torch.cuda.empty_cache()
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
generator = None if seed is None else torch.Generator(device).manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
return noise

View File

@@ -0,0 +1,135 @@
from ..models import ModelManager, FluxTextEncoder2, CogDiT, CogVAEEncoder, CogVAEDecoder
from ..prompters import CogPrompter
from ..schedulers import EnhancedDDIMScheduler
from .base import BasePipeline
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
from einops import rearrange
class CogVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
self.scheduler = EnhancedDDIMScheduler(rescale_zero_terminal_snr=True, prediction_type="v_prediction")
self.prompter = CogPrompter()
# models
self.text_encoder: FluxTextEncoder2 = None
self.dit: CogDiT = None
self.vae_encoder: CogVAEEncoder = None
self.vae_decoder: CogVAEDecoder = None
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
self.text_encoder = model_manager.fetch_model("flux_text_encoder_2")
self.dit = model_manager.fetch_model("cog_dit")
self.vae_encoder = model_manager.fetch_model("cog_vae_encoder")
self.vae_decoder = model_manager.fetch_model("cog_vae_decoder")
self.prompter.fetch_models(self.text_encoder)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
pipe = CogVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype
)
pipe.fetch_models(model_manager, prompt_refiner_classes)
return pipe
def tensor2video(self, frames):
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
frames = [Image.fromarray(frame) for frame in frames]
return frames
def encode_prompt(self, prompt, positive=True):
prompt_emb = self.prompter.encode_prompt(prompt, device=self.device, positive=positive)
return {"prompt_emb": prompt_emb}
def prepare_extra_input(self, latents):
return {"image_rotary_emb": self.dit.prepare_rotary_positional_embeddings(latents.shape[3], latents.shape[4], latents.shape[2], device=self.device)}
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
input_video=None,
cfg_scale=7.0,
denoising_strength=1.0,
num_frames=49,
height=480,
width=720,
num_inference_steps=20,
tiled=False,
tile_size=(60, 90),
tile_stride=(30, 45),
seed=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
# Prepare latent tensors
noise = self.generate_noise((1, 16, num_frames // 4 + 1, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
if denoising_strength == 1.0:
latents = noise.clone()
else:
input_video = self.preprocess_images(input_video)
input_video = torch.stack(input_video, dim=2)
latents = self.vae_encoder.encode_video(input_video, **tiler_kwargs, progress_bar=progress_bar_cmd).to(dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
if not tiled: latents = latents.to(self.device)
# Encode prompt
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
# Extra input
extra_input = self.prepare_extra_input(latents)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
noise_pred_posi = self.dit(
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# DDIM
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Update progress bar
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
video = self.vae_decoder.decode_video(latents.to("cpu"), **tiler_kwargs, progress_bar=progress_bar_cmd)
video = self.tensor2video(video[0])
return video

View File

@@ -136,6 +136,40 @@ def lets_dance_xl(
device = "cuda", device = "cuda",
vram_limit_level = 0, vram_limit_level = 0,
): ):
# 0. Text embedding alignment (only for video processing)
if encoder_hidden_states.shape[0] != sample.shape[0]:
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
if add_text_embeds.shape[0] != sample.shape[0]:
add_text_embeds = add_text_embeds.repeat(sample.shape[0], 1)
# 1. ControlNet
controlnet_insert_block_id = 22
if controlnet is not None and controlnet_frames is not None:
res_stacks = []
# process controlnet frames with batch
for batch_id in range(0, sample.shape[0], controlnet_batch_size):
batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
res_stack = controlnet(
sample[batch_id: batch_id_],
timestep,
encoder_hidden_states[batch_id: batch_id_],
controlnet_frames[:, batch_id: batch_id_],
add_time_id=add_time_id,
add_text_embeds=add_text_embeds,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
unet=unet, # for Kolors, some modules in ControlNets will be replaced.
)
if vram_limit_level >= 1:
res_stack = [res.cpu() for res in res_stack]
res_stacks.append(res_stack)
# concat the residual
additional_res_stack = []
for i in range(len(res_stacks[0])):
res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
additional_res_stack.append(res)
else:
additional_res_stack = None
# 2. time # 2. time
t_emb = unet.time_proj(timestep).to(sample.dtype) t_emb = unet.time_proj(timestep).to(sample.dtype)
t_emb = unet.time_embedding(t_emb) t_emb = unet.time_embedding(t_emb)
@@ -156,11 +190,31 @@ def lets_dance_xl(
# 4. blocks # 4. blocks
for block_id, block in enumerate(unet.blocks): for block_id, block in enumerate(unet.blocks):
hidden_states, time_emb, text_emb, res_stack = block( # 4.1 UNet
hidden_states, time_emb, text_emb, res_stack, if isinstance(block, PushBlock):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
if vram_limit_level>=1:
res_stack[-1] = res_stack[-1].cpu()
elif isinstance(block, PopBlock):
if vram_limit_level>=1:
res_stack[-1] = res_stack[-1].to(device)
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
else:
hidden_states_input = hidden_states
hidden_states_output = []
for batch_id in range(0, sample.shape[0], unet_batch_size):
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
hidden_states, _, _, _ = block(
hidden_states_input[batch_id: batch_id_],
time_emb[batch_id: batch_id_],
text_emb[batch_id: batch_id_],
res_stack,
cross_frame_attention=cross_frame_attention,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {})
) )
hidden_states_output.append(hidden_states)
hidden_states = torch.concat(hidden_states_output, dim=0)
# 4.2 AnimateDiff # 4.2 AnimateDiff
if motion_modules is not None: if motion_modules is not None:
if block_id in motion_modules.call_block_id: if block_id in motion_modules.call_block_id:
@@ -169,6 +223,10 @@ def lets_dance_xl(
hidden_states, time_emb, text_emb, res_stack, hidden_states, time_emb, text_emb, res_stack,
batch_size=1 batch_size=1
) )
# 4.3 ControlNet
if block_id == controlnet_insert_block_id and additional_res_stack is not None:
hidden_states += additional_res_stack.pop().to(device)
res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
# 5. output # 5. output
hidden_states = unet.conv_norm_out(hidden_states) hidden_states = unet.conv_norm_out(hidden_states)

View File

@@ -0,0 +1,646 @@
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler
from .base import BasePipeline
from typing import List
import torch
from tqdm import tqdm
import numpy as np
from PIL import Image
from ..models.tiler import FastTileWorker
from transformers import SiglipVisionModel
from copy import deepcopy
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
from ..models.flux_dit import RMSNorm
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
class FluxImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
self.scheduler = FlowMatchScheduler()
self.prompter = FluxPrompter()
# models
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: FluxTextEncoder2 = None
self.dit: FluxDiT = None
self.vae_decoder: FluxVAEDecoder = None
self.vae_encoder: FluxVAEEncoder = None
self.controlnet: FluxMultiControlNetManager = None
self.ipadapter: FluxIpAdapter = None
self.ipadapter_image_encoder: SiglipVisionModel = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
def enable_vram_management(self, num_persistent_param_in_dit=None):
dtype = next(iter(self.text_encoder_1.parameters())).dtype
enable_vram_management(
self.text_encoder_1,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.text_encoder_2.parameters())).dtype
enable_vram_management(
self.text_encoder_2,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
T5LayerNorm: AutoWrappedModule,
T5DenseActDense: AutoWrappedModule,
T5DenseGatedActDense: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.dit.parameters())).dtype
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cuda",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.vae_decoder.parameters())).dtype
enable_vram_management(
self.vae_decoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.GroupNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.vae_encoder.parameters())).dtype
enable_vram_management(
self.vae_encoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.GroupNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
def denoising_model(self):
return self.dit
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[]):
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
self.dit = model_manager.fetch_model("flux_dit")
self.vae_decoder = model_manager.fetch_model("flux_vae_decoder")
self.vae_encoder = model_manager.fetch_model("flux_vae_encoder")
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes)
# ControlNets
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id, device=self.device, skip_processor=config.skip_processor),
model_manager.fetch_model("flux_controlnet", config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = FluxMultiControlNetManager(controlnet_units)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
pipe = FluxImagePipeline(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype if torch_dtype is None else torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
return pipe
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
def prepare_extra_input(self, latents=None, guidance=1.0):
latent_image_ids = self.dit.prepare_image_ids(latents)
guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
return {"image_ids": latent_image_ids, "guidance": guidance}
def apply_controlnet_mask_on_latents(self, latents, mask):
mask = (self.preprocess_image(mask) + 1) / 2
mask = mask.mean(dim=1, keepdim=True)
mask = mask.to(dtype=self.torch_dtype, device=self.device)
mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
latents = torch.concat([latents, mask], dim=1)
return latents
def apply_controlnet_mask_on_image(self, image, mask):
mask = mask.resize(image.size)
mask = self.preprocess_image(mask).mean(dim=[0, 1])
image = np.array(image)
image[mask > 0] = 0
image = Image.fromarray(image)
return image
def prepare_controlnet_input(self, controlnet_image, controlnet_inpaint_mask, tiler_kwargs):
if isinstance(controlnet_image, Image.Image):
controlnet_image = [controlnet_image] * len(self.controlnet.processors)
controlnet_frames = []
for i in range(len(self.controlnet.processors)):
# image annotator
image = self.controlnet.process_image(controlnet_image[i], processor_id=i)[0]
if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
image = self.apply_controlnet_mask_on_image(image, controlnet_inpaint_mask)
# image to tensor
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
# vae encoder
image = self.encode_image(image, **tiler_kwargs)
if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
image = self.apply_controlnet_mask_on_latents(image, controlnet_inpaint_mask)
# store it
controlnet_frames.append(image)
return controlnet_frames
def prepare_ipadapter_inputs(self, images, height=384, width=384):
images = [image.convert("RGB").resize((width, height), resample=3) for image in images]
images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images]
return torch.cat(images, dim=0)
def inpaint_fusion(self, latents, inpaint_latents, pred_noise, fg_mask, bg_mask, progress_id, background_weight=0.):
# inpaint noise
inpaint_noise = (latents - inpaint_latents) / self.scheduler.sigmas[progress_id]
# merge noise
weight = torch.ones_like(inpaint_noise)
inpaint_noise[fg_mask] = pred_noise[fg_mask]
inpaint_noise[bg_mask] += pred_noise[bg_mask] * background_weight
weight[bg_mask] += background_weight
inpaint_noise /= weight
return inpaint_noise
def preprocess_masks(self, masks, height, width, dim):
out_masks = []
for mask in masks:
mask = self.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0
mask = mask.repeat(1, dim, 1, 1).to(device=self.device, dtype=self.torch_dtype)
out_masks.append(mask)
return out_masks
def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, enable_eligen_inpaint=False):
fg_mask, bg_mask = None, None
if enable_eligen_inpaint:
masks_ = deepcopy(entity_masks)
fg_masks = torch.cat([self.preprocess_image(mask.resize((width//8, height//8))).mean(dim=1, keepdim=True) for mask in masks_])
fg_masks = (fg_masks > 0).float()
fg_mask = fg_masks.sum(dim=0, keepdim=True).repeat(1, 16, 1, 1) > 0
bg_mask = ~fg_mask
entity_masks = self.preprocess_masks(entity_masks, height//8, width//8, 1)
entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w
entity_prompts = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)['prompt_emb'].unsqueeze(0)
return entity_prompts, entity_masks, fg_mask, bg_mask
def prepare_latents(self, input_image, height, width, seed, tiled, tile_size, tile_stride):
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
input_latents = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
input_latents = None
return latents, input_latents
def prepare_ipadapter(self, ipadapter_images, ipadapter_scale):
if ipadapter_images is not None:
self.load_models_to_device(['ipadapter_image_encoder'])
ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images)
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output
self.load_models_to_device(['ipadapter'])
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
return ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega
def prepare_controlnet(self, controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative):
if controlnet_image is not None:
self.load_models_to_device(['vae_encoder'])
controlnet_kwargs_posi = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)}
if len(masks) > 0 and controlnet_inpaint_mask is not None:
print("The controlnet_inpaint_mask will be overridden by masks.")
local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks]
else:
local_controlnet_kwargs = None
else:
controlnet_kwargs_posi, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks)
controlnet_kwargs_nega = controlnet_kwargs_posi if enable_controlnet_on_negative else {}
return controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs
def prepare_eligen(self, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale):
if eligen_entity_masks is not None:
entity_prompt_emb_posi, entity_masks_posi, fg_mask, bg_mask = self.prepare_entity_inputs(eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint)
if enable_eligen_on_negative and cfg_scale != 1.0:
entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks_posi.shape[1], 1, 1)
entity_masks_nega = entity_masks_posi
else:
entity_prompt_emb_nega, entity_masks_nega = None, None
else:
entity_prompt_emb_posi, entity_masks_posi, entity_prompt_emb_nega, entity_masks_nega = None, None, None, None
fg_mask, bg_mask = None, None
eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi}
eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega}
return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask
def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
# Extend prompt
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
# Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
@torch.no_grad()
def __call__(
self,
# Prompt
prompt,
negative_prompt="",
cfg_scale=1.0,
embedded_guidance=3.5,
t5_sequence_length=512,
# Image
input_image=None,
denoising_strength=1.0,
height=1024,
width=1024,
seed=None,
# Steps
num_inference_steps=30,
# local prompts
local_prompts=(),
masks=(),
mask_scales=(),
# ControlNet
controlnet_image=None,
controlnet_inpaint_mask=None,
enable_controlnet_on_negative=False,
# IP-Adapter
ipadapter_images=None,
ipadapter_scale=1.0,
# EliGen
eligen_entity_prompts=None,
eligen_entity_masks=None,
enable_eligen_on_negative=False,
enable_eligen_inpaint=False,
# TeaCache
tea_cache_l1_thresh=None,
# Tile
tiled=False,
tile_size=128,
tile_stride=64,
# Progress bar
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride)
# Prompt
prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale)
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# Entity control
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
# IP-Adapter
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = self.prepare_ipadapter(ipadapter_images, ipadapter_scale)
# ControlNets
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
# TeaCache
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
# Denoise
self.load_models_to_device(['dit', 'controlnet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Positive side
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
)
noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
special_kwargs=controlnet_kwargs_posi, special_local_kwargs_list=local_controlnet_kwargs
)
# Inpaint
if enable_eligen_inpaint:
noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id)
# Classifier-free guidance
if cfg_scale != 1.0:
# Negative side
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Iterate
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, **tiler_kwargs)
# Offload all models
self.load_models_to_device([])
return image
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
def check(self, dit: FluxDiT, hidden_states, conditioning):
inp = hidden_states.clone()
temb_ = conditioning.clone()
modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_)
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step += 1
if self.step == self.num_inference_steps:
self.step = 0
if should_calc:
self.previous_hidden_states = hidden_states.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
def lets_dance_flux(
dit: FluxDiT,
controlnet: FluxMultiControlNetManager = None,
hidden_states=None,
timestep=None,
prompt_emb=None,
pooled_prompt_emb=None,
guidance=None,
text_ids=None,
image_ids=None,
controlnet_frames=None,
tiled=False,
tile_size=128,
tile_stride=64,
entity_prompt_emb=None,
entity_masks=None,
ipadapter_kwargs_list={},
tea_cache: TeaCache = None,
**kwargs
):
if tiled:
def flux_forward_fn(hl, hr, wl, wr):
tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
return lets_dance_flux(
dit=dit,
controlnet=controlnet,
hidden_states=hidden_states[:, :, hl: hr, wl: wr],
timestep=timestep,
prompt_emb=prompt_emb,
pooled_prompt_emb=pooled_prompt_emb,
guidance=guidance,
text_ids=text_ids,
image_ids=None,
controlnet_frames=tiled_controlnet_frames,
tiled=False,
**kwargs
)
return FastTileWorker().tiled_forward(
flux_forward_fn,
hidden_states,
tile_size=tile_size,
tile_stride=tile_stride,
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
controlnet_extra_kwargs = {
"hidden_states": hidden_states,
"timestep": timestep,
"prompt_emb": prompt_emb,
"pooled_prompt_emb": pooled_prompt_emb,
"guidance": guidance,
"text_ids": text_ids,
"image_ids": image_ids,
"tiled": tiled,
"tile_size": tile_size,
"tile_stride": tile_stride,
}
controlnet_res_stack, controlnet_single_res_stack = controlnet(
controlnet_frames, **controlnet_extra_kwargs
)
if image_ids is None:
image_ids = dit.prepare_image_ids(hidden_states)
conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb)
if dit.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
height, width = hidden_states.shape[-2:]
hidden_states = dit.patchify(hidden_states)
hidden_states = dit.x_embedder(hidden_states)
if entity_prompt_emb is not None and entity_masks is not None:
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
else:
prompt_emb = dit.context_embedder(prompt_emb)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
attention_mask = None
# TeaCache
if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
else:
tea_cache_update = False
if tea_cache_update:
hidden_states = tea_cache.update(hidden_states)
else:
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
# Single Blocks
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
num_joint_blocks = len(dit.blocks)
for block_id, block in enumerate(dit.single_blocks):
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
if tea_cache is not None:
tea_cache.store(hidden_states)
hidden_states = dit.final_norm_out(hidden_states, conditioning)
hidden_states = dit.final_proj_out(hidden_states)
hidden_states = dit.unpatchify(hidden_states, height, width)
return hidden_states

View File

@@ -125,7 +125,7 @@ class ImageSizeManager:
class HunyuanDiTImagePipeline(BasePipeline): class HunyuanDiTImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16): def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype) super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03) self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
self.prompter = HunyuanDiTPrompter() self.prompter = HunyuanDiTPrompter()
self.image_size_manager = ImageSizeManager() self.image_size_manager = ImageSizeManager()
@@ -135,6 +135,7 @@ class HunyuanDiTImagePipeline(BasePipeline):
self.dit: HunyuanDiT = None self.dit: HunyuanDiT = None
self.vae_decoder: SDXLVAEDecoder = None self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None self.vae_encoder: SDXLVAEEncoder = None
self.model_names = ['text_encoder', 'text_encoder_t5', 'dit', 'vae_decoder', 'vae_encoder']
def denoising_model(self): def denoising_model(self):
@@ -153,9 +154,9 @@ class HunyuanDiTImagePipeline(BasePipeline):
@staticmethod @staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]): def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
pipe = HunyuanDiTImagePipeline( pipe = HunyuanDiTImagePipeline(
device=model_manager.device, device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype, torch_dtype=model_manager.torch_dtype,
) )
pipe.fetch_models(model_manager, prompt_refiner_classes) pipe.fetch_models(model_manager, prompt_refiner_classes)
@@ -209,6 +210,9 @@ class HunyuanDiTImagePipeline(BasePipeline):
def __call__( def __call__(
self, self,
prompt, prompt,
local_prompts=[],
masks=[],
mask_scales=[],
negative_prompt="", negative_prompt="",
cfg_scale=7.5, cfg_scale=7.5,
clip_skip=1, clip_skip=1,
@@ -222,15 +226,19 @@ class HunyuanDiTImagePipeline(BasePipeline):
tiled=False, tiled=False,
tile_size=64, tile_size=64,
tile_stride=32, tile_stride=32,
seed=None,
progress_bar_cmd=tqdm, progress_bar_cmd=tqdm,
progress_bar_st=None, progress_bar_st=None,
): ):
height, width = self.check_resize_height_width(height, width)
# Prepare scheduler # Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength) self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors # Prepare latent tensors
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
if input_image is not None: if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32) image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype) latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
@@ -238,21 +246,24 @@ class HunyuanDiTImagePipeline(BasePipeline):
latents = noise.clone() latents = noise.clone()
# Encode prompts # Encode prompts
self.load_models_to_device(['text_encoder', 'text_encoder_t5'])
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
if cfg_scale != 1.0: if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
# Prepare positional id # Prepare positional id
extra_input = self.prepare_extra_input(latents, tiled, tile_size) extra_input = self.prepare_extra_input(latents, tiled, tile_size)
# Denoise # Denoise
self.load_models_to_device(['dit'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device) timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
# Positive side # Positive side
noise_pred_posi = self.dit( inference_callback = lambda prompt_emb_posi: self.dit(latents, timestep=timestep, **prompt_emb_posi, **extra_input)
latents, timestep=timestep, **prompt_emb_posi, **extra_input, noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
)
if cfg_scale != 1.0: if cfg_scale != 1.0:
# Negative side # Negative side
noise_pred_nega = self.dit( noise_pred_nega = self.dit(
@@ -269,6 +280,9 @@ class HunyuanDiTImagePipeline(BasePipeline):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image # Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# Offload all models
self.load_models_to_device([])
return image return image

View File

@@ -0,0 +1,265 @@
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder
from ..models.hunyuan_video_dit import HunyuanVideoDiT
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import HunyuanVideoPrompter
import torch
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
class HunyuanVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
self.prompter = HunyuanVideoPrompter()
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: HunyuanVideoLLMEncoder = None
self.dit: HunyuanVideoDiT = None
self.vae_decoder: HunyuanVideoVAEDecoder = None
self.vae_encoder: HunyuanVideoVAEEncoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
self.vram_management = False
def enable_vram_management(self):
self.vram_management = True
self.enable_cpu_offload()
self.text_encoder_2.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
self.dit.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
def fetch_models(self, model_manager: ModelManager):
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
self.dit = model_manager.fetch_model("hunyuan_video_dit")
self.vae_decoder = model_manager.fetch_model("hunyuan_video_vae_decoder")
self.vae_encoder = model_manager.fetch_model("hunyuan_video_vae_encoder")
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, enable_vram_management=True):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
if enable_vram_management:
pipe.enable_vram_management()
return pipe
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
def prepare_extra_input(self, latents=None, guidance=1.0):
freqs_cos, freqs_sin = self.dit.prepare_freqs(latents)
guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
return {"freqs_cos": freqs_cos, "freqs_sin": freqs_sin, "guidance": guidance}
def tensor2video(self, frames):
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
frames = [Image.fromarray(frame) for frame in frames]
return frames
def encode_video(self, frames, tile_size=(17, 30, 30), tile_stride=(12, 20, 20)):
tile_size = ((tile_size[0] - 1) * 4 + 1, tile_size[1] * 8, tile_size[2] * 8)
tile_stride = (tile_stride[0] * 4, tile_stride[1] * 8, tile_stride[2] * 8)
latents = self.vae_encoder.encode_video(frames, tile_size=tile_size, tile_stride=tile_stride)
return latents
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
input_video=None,
denoising_strength=1.0,
seed=None,
rand_device=None,
height=720,
width=1280,
num_frames=129,
embedded_guidance=6.0,
cfg_scale=1.0,
num_inference_steps=30,
tea_cache_l1_thresh=None,
tile_size=(17, 30, 30),
tile_stride=(12, 20, 20),
step_processor=None,
progress_bar_cmd=lambda x: x,
progress_bar_st=None,
):
# Tiler parameters
tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Initialize noise
rand_device = self.device if rand_device is None else rand_device
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
if input_video is not None:
self.load_models_to_device(['vae_encoder'])
input_video = self.preprocess_images(input_video)
input_video = torch.stack(input_video, dim=2)
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = noise
# Encode prompts
self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# TeaCache
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
# Denoise
self.load_models_to_device([] if self.vram_management else ["dit"])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
# Inference
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
if cfg_scale != 1.0:
noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# (Experimental feature, may be removed in the future)
if step_processor is not None:
self.load_models_to_device(['vae_decoder'])
rendered_frames = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents, to_final=True)
rendered_frames = self.vae_decoder.decode_video(rendered_frames, **tiler_kwargs)
rendered_frames = self.tensor2video(rendered_frames[0])
rendered_frames = step_processor(rendered_frames, original_frames=input_video)
self.load_models_to_device(['vae_encoder'])
rendered_frames = self.preprocess_images(rendered_frames)
rendered_frames = torch.stack(rendered_frames, dim=2)
target_latents = self.encode_video(rendered_frames).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.scheduler.return_to_timestep(self.scheduler.timesteps[progress_id], latents, target_latents)
self.load_models_to_device([] if self.vram_management else ["dit"])
# Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Decode
self.load_models_to_device(['vae_decoder'])
frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
self.load_models_to_device([])
frames = self.tensor2video(frames[0])
return frames
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
def check(self, dit: HunyuanVideoDiT, img, vec):
img_ = img.clone()
vec_ = vec.clone()
img_mod1_shift, img_mod1_scale, _, _, _, _ = dit.double_blocks[0].component_a.mod(vec_).chunk(6, dim=-1)
normed_inp = dit.double_blocks[0].component_a.norm1(img_)
modulated_inp = normed_inp * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1)
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step += 1
if self.step == self.num_inference_steps:
self.step = 0
if should_calc:
self.previous_hidden_states = img.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
def lets_dance_hunyuan_video(
dit: HunyuanVideoDiT,
x: torch.Tensor,
t: torch.Tensor,
prompt_emb: torch.Tensor = None,
text_mask: torch.Tensor = None,
pooled_prompt_emb: torch.Tensor = None,
freqs_cos: torch.Tensor = None,
freqs_sin: torch.Tensor = None,
guidance: torch.Tensor = None,
tea_cache: TeaCache = None,
**kwargs
):
B, C, T, H, W = x.shape
vec = dit.time_in(t, dtype=torch.float32) + dit.vector_in(pooled_prompt_emb) + dit.guidance_in(guidance * 1000, dtype=torch.float32)
img = dit.img_in(x)
txt = dit.txt_in(prompt_emb, t, text_mask)
# TeaCache
if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, img, vec)
else:
tea_cache_update = False
if tea_cache_update:
print("TeaCache skip forward.")
img = tea_cache.update(img)
else:
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
x = torch.concat([img, txt], dim=1)
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
x = block(x, vec, (freqs_cos, freqs_sin))
img = x[:, :-256]
if tea_cache is not None:
tea_cache.store(img)
img = dit.final_layer(img, vec)
img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
return img

View File

@@ -0,0 +1,289 @@
from ..models.omnigen import OmniGenTransformer
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models.model_manager import ModelManager
from ..prompters.omnigen_prompter import OmniGenPrompter
from ..schedulers import FlowMatchScheduler
from .base import BasePipeline
from typing import Optional, Dict, Any, Tuple, List
from transformers.cache_utils import DynamicCache
import torch, os
from tqdm import tqdm
class OmniGenCache(DynamicCache):
def __init__(self,
num_tokens_for_img: int, offload_kv_cache: bool=False) -> None:
if not torch.cuda.is_available():
print("No avaliable GPU, offload_kv_cache wiil be set to False, which will result in large memory usage and time cost when input multiple images!!!")
offload_kv_cache = False
raise RuntimeError("OffloadedCache can only be used with a GPU")
super().__init__()
self.original_device = []
self.prefetch_stream = torch.cuda.Stream()
self.num_tokens_for_img = num_tokens_for_img
self.offload_kv_cache = offload_kv_cache
def prefetch_layer(self, layer_idx: int):
"Starts prefetching the next layer cache"
if layer_idx < len(self):
with torch.cuda.stream(self.prefetch_stream):
# Prefetch next layer tensors to GPU
device = self.original_device[layer_idx]
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
def evict_previous_layer(self, layer_idx: int):
"Moves the previous layer cache to the CPU"
if len(self) > 2:
# We do it on the default stream so it occurs after all earlier computations on these tensors are done
if layer_idx == 0:
prev_layer_idx = -1
else:
prev_layer_idx = (layer_idx - 1) % len(self)
self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
if layer_idx < len(self):
if self.offload_kv_cache:
# Evict the previous layer if necessary
torch.cuda.current_stream().synchronize()
self.evict_previous_layer(layer_idx)
# Load current layer cache to its original device if not already there
original_device = self.original_device[layer_idx]
# self.prefetch_stream.synchronize(original_device)
torch.cuda.synchronize(self.prefetch_stream)
key_tensor = self.key_cache[layer_idx]
value_tensor = self.value_cache[layer_idx]
# Prefetch the next layer
self.prefetch_layer((layer_idx + 1) % len(self))
else:
key_tensor = self.key_cache[layer_idx]
value_tensor = self.value_cache[layer_idx]
return (key_tensor, value_tensor)
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
Return:
A tuple containing the updated key and value states.
"""
# Update the cache
if len(self.key_cache) < layer_idx:
raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
elif len(self.key_cache) == layer_idx:
# only cache the states for condition tokens
key_states = key_states[..., :-(self.num_tokens_for_img+1), :]
value_states = value_states[..., :-(self.num_tokens_for_img+1), :]
# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
self.key_cache.append(key_states)
self.value_cache.append(value_states)
self.original_device.append(key_states.device)
if self.offload_kv_cache:
self.evict_previous_layer(layer_idx)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
else:
# only cache the states for condition tokens
key_tensor, value_tensor = self[layer_idx]
k = torch.cat([key_tensor, key_states], dim=-2)
v = torch.cat([value_tensor, value_states], dim=-2)
return k, v
class OmnigenImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(num_train_timesteps=1, shift=1, inverse_timesteps=True, sigma_min=0, sigma_max=1)
# models
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
self.transformer: OmniGenTransformer = None
self.prompter: OmniGenPrompter = None
self.model_names = ['transformer', 'vae_decoder', 'vae_encoder']
def denoising_model(self):
return self.transformer
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
# Main models
self.transformer, model_path = model_manager.fetch_model("omnigen_transformer", require_model_path=True)
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
self.prompter = OmniGenPrompter.from_pretrained(os.path.dirname(model_path))
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
pipe = OmnigenImagePipeline(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, prompt_refiner_classes=[])
return pipe
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def encode_images(self, images, tiled=False, tile_size=64, tile_stride=32):
latents = [self.encode_image(image.to(device=self.device), tiled, tile_size, tile_stride).to(self.torch_dtype) for image in images]
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def encode_prompt(self, prompt, clip_skip=1, positive=True):
prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
return {"encoder_hidden_states": prompt_emb}
def prepare_extra_input(self, latents=None):
return {}
def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
if isinstance(position_ids, list):
for i in range(len(position_ids)):
position_ids[i] = position_ids[i][:, -(num_tokens_for_img+1):]
else:
position_ids = position_ids[:, -(num_tokens_for_img+1):]
return position_ids
def crop_attention_mask_for_cache(self, attention_mask, num_tokens_for_img):
if isinstance(attention_mask, list):
return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
return attention_mask[..., -(num_tokens_for_img+1):, :]
@torch.no_grad()
def __call__(
self,
prompt,
reference_images=[],
cfg_scale=2.0,
image_cfg_scale=2.0,
use_kv_cache=True,
offload_kv_cache=True,
input_image=None,
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
seed=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = latents.repeat(3, 1, 1, 1)
# Encode prompts
input_data = self.prompter(prompt, reference_images, height=height, width=width, use_img_cfg=True, separate_cfg_input=True, use_input_image_size_as_output=False)
# Encode images
reference_latents = [self.encode_images(images, **tiler_kwargs) for images in input_data['input_pixel_values']]
# Pack all parameters
model_kwargs = dict(input_ids=[input_ids.to(self.device) for input_ids in input_data['input_ids']],
input_img_latents=reference_latents,
input_image_sizes=input_data['input_image_sizes'],
attention_mask=[attention_mask.to(self.device) for attention_mask in input_data["attention_mask"]],
position_ids=[position_ids.to(self.device) for position_ids in input_data["position_ids"]],
cfg_scale=cfg_scale,
img_cfg_scale=image_cfg_scale,
use_img_cfg=True,
use_kv_cache=use_kv_cache,
offload_model=False,
)
# Denoise
self.load_models_to_device(['transformer'])
cache = [OmniGenCache(latents.size(-1)*latents.size(-2) // 4, offload_kv_cache) for _ in range(len(model_kwargs['input_ids']))] if use_kv_cache else None
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).repeat(latents.shape[0]).to(self.device)
# Forward
noise_pred, cache = self.transformer.forward_with_separate_cfg(latents, timestep, past_key_values=cache, **model_kwargs)
# Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Update KV cache
if progress_id == 0 and use_kv_cache:
num_tokens_for_img = latents.size(-1)*latents.size(-2) // 4
if isinstance(cache, list):
model_kwargs['input_ids'] = [None] * len(cache)
else:
model_kwargs['input_ids'] = None
model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
del cache
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# offload all models
self.load_models_to_device([])
return image

View File

@@ -10,7 +10,7 @@ from tqdm import tqdm
class SD3ImagePipeline(BasePipeline): class SD3ImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16): def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype) super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
self.scheduler = FlowMatchScheduler() self.scheduler = FlowMatchScheduler()
self.prompter = SD3Prompter() self.prompter = SD3Prompter()
# models # models
@@ -20,6 +20,7 @@ class SD3ImagePipeline(BasePipeline):
self.dit: SD3DiT = None self.dit: SD3DiT = None
self.vae_decoder: SD3VAEDecoder = None self.vae_decoder: SD3VAEDecoder = None
self.vae_encoder: SD3VAEEncoder = None self.vae_encoder: SD3VAEEncoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'text_encoder_3', 'dit', 'vae_decoder', 'vae_encoder']
def denoising_model(self): def denoising_model(self):
@@ -29,7 +30,6 @@ class SD3ImagePipeline(BasePipeline):
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]): def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1") self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("sd3_text_encoder_2") self.text_encoder_2 = model_manager.fetch_model("sd3_text_encoder_2")
if "sd3_text_encoder_3" in model_manager.model:
self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3") self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3")
self.dit = model_manager.fetch_model("sd3_dit") self.dit = model_manager.fetch_model("sd3_dit")
self.vae_decoder = model_manager.fetch_model("sd3_vae_decoder") self.vae_decoder = model_manager.fetch_model("sd3_vae_decoder")
@@ -39,9 +39,9 @@ class SD3ImagePipeline(BasePipeline):
@staticmethod @staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]): def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
pipe = SD3ImagePipeline( pipe = SD3ImagePipeline(
device=model_manager.device, device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype, torch_dtype=model_manager.torch_dtype,
) )
pipe.fetch_models(model_manager, prompt_refiner_classes) pipe.fetch_models(model_manager, prompt_refiner_classes)
@@ -59,9 +59,9 @@ class SD3ImagePipeline(BasePipeline):
return image return image
def encode_prompt(self, prompt, positive=True): def encode_prompt(self, prompt, positive=True, t5_sequence_length=77):
prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt( prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
) )
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb} return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
@@ -74,6 +74,9 @@ class SD3ImagePipeline(BasePipeline):
def __call__( def __call__(
self, self,
prompt, prompt,
local_prompts=[],
masks=[],
mask_scales=[],
negative_prompt="", negative_prompt="",
cfg_scale=7.5, cfg_scale=7.5,
input_image=None, input_image=None,
@@ -81,12 +84,16 @@ class SD3ImagePipeline(BasePipeline):
height=1024, height=1024,
width=1024, width=1024,
num_inference_steps=20, num_inference_steps=20,
t5_sequence_length=77,
tiled=False, tiled=False,
tile_size=128, tile_size=128,
tile_stride=64, tile_stride=64,
seed=None,
progress_bar_cmd=tqdm, progress_bar_cmd=tqdm,
progress_bar_st=None, progress_bar_st=None,
): ):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters # Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
@@ -95,25 +102,30 @@ class SD3ImagePipeline(BasePipeline):
# Prepare latent tensors # Prepare latent tensors
if input_image is not None: if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs) latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else: else:
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# Encode prompts # Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, positive=True) self.load_models_to_device(['text_encoder_1', 'text_encoder_2', 'text_encoder_3'])
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False) prompt_emb_posi = self.encode_prompt(prompt, positive=True, t5_sequence_length=t5_sequence_length)
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
# Denoise # Denoise
self.load_models_to_device(['dit'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device) timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance # Classifier-free guidance
noise_pred_posi = self.dit( inference_callback = lambda prompt_emb_posi: self.dit(
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs,
) )
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
noise_pred_nega = self.dit( noise_pred_nega = self.dit(
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs,
) )
@@ -127,6 +139,9 @@ class SD3ImagePipeline(BasePipeline):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image # Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# offload all models
self.load_models_to_device([])
return image return image

View File

@@ -25,6 +25,7 @@ class SDImagePipeline(BasePipeline):
self.controlnet: MultiControlNetManager = None self.controlnet: MultiControlNetManager = None
self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
self.ipadapter: SDIpAdapter = None self.ipadapter: SDIpAdapter = None
self.model_names = ['text_encoder', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter']
def denoising_model(self): def denoising_model(self):
@@ -57,9 +58,9 @@ class SDImagePipeline(BasePipeline):
@staticmethod @staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]): def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
pipe = SDImagePipeline( pipe = SDImagePipeline(
device=model_manager.device, device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype, torch_dtype=model_manager.torch_dtype,
) )
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes=[]) pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes=[])
@@ -90,6 +91,9 @@ class SDImagePipeline(BasePipeline):
def __call__( def __call__(
self, self,
prompt, prompt,
local_prompts=[],
masks=[],
mask_scales=[],
negative_prompt="", negative_prompt="",
cfg_scale=7.5, cfg_scale=7.5,
clip_skip=1, clip_skip=1,
@@ -104,9 +108,12 @@ class SDImagePipeline(BasePipeline):
tiled=False, tiled=False,
tile_size=64, tile_size=64,
tile_stride=32, tile_stride=32,
seed=None,
progress_bar_cmd=tqdm, progress_bar_cmd=tqdm,
progress_bar_st=None, progress_bar_st=None,
): ):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters # Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
@@ -115,20 +122,25 @@ class SDImagePipeline(BasePipeline):
# Prepare latent tensors # Prepare latent tensors
if input_image is not None: if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs) latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else: else:
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# Encode prompts # Encode prompts
self.load_models_to_device(['text_encoder'])
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True) prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False) prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, positive=True) for prompt_local in local_prompts]
# IP-Adapter # IP-Adapter
if ipadapter_images is not None: if ipadapter_images is not None:
self.load_models_to_device(['ipadapter_image_encoder'])
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images) ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
self.load_models_to_device(['ipadapter'])
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)} ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))} ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else: else:
@@ -136,6 +148,7 @@ class SDImagePipeline(BasePipeline):
# Prepare ControlNets # Prepare ControlNets
if controlnet_image is not None: if controlnet_image is not None:
self.load_models_to_device(['controlnet'])
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype) controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
controlnet_image = controlnet_image.unsqueeze(1) controlnet_image = controlnet_image.unsqueeze(1)
controlnet_kwargs = {"controlnet_frames": controlnet_image} controlnet_kwargs = {"controlnet_frames": controlnet_image}
@@ -143,16 +156,18 @@ class SDImagePipeline(BasePipeline):
controlnet_kwargs = {"controlnet_frames": None} controlnet_kwargs = {"controlnet_frames": None}
# Denoise # Denoise
self.load_models_to_device(['controlnet', 'unet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device) timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance # Classifier-free guidance
noise_pred_posi = lets_dance( inference_callback = lambda prompt_emb_posi: lets_dance(
self.unet, motion_modules=None, controlnet=self.controlnet, self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, sample=latents, timestep=timestep,
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi, **prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
device=self.device, device=self.device,
) )
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
noise_pred_nega = lets_dance( noise_pred_nega = lets_dance(
self.unet, motion_modules=None, controlnet=self.controlnet, self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega, sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
@@ -168,6 +183,9 @@ class SDImagePipeline(BasePipeline):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image # Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# offload all models
self.load_models_to_device([])
return image return image

View File

@@ -166,9 +166,12 @@ class SDVideoPipeline(SDImagePipeline):
tiled=False, tiled=False,
tile_size=64, tile_size=64,
tile_stride=32, tile_stride=32,
seed=None,
progress_bar_cmd=tqdm, progress_bar_cmd=tqdm,
progress_bar_st=None, progress_bar_st=None,
): ):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters, batch size ... # Tiler parameters, batch size ...
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
other_kwargs = { other_kwargs = {
@@ -182,9 +185,9 @@ class SDVideoPipeline(SDImagePipeline):
# Prepare latent tensors # Prepare latent tensors
if self.motion_modules is None: if self.motion_modules is None:
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1) noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
else: else:
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype) noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
if input_frames is None or denoising_strength == 1.0: if input_frames is None or denoising_strength == 1.0:
latents = noise latents = noise
else: else:

View File

@@ -9,6 +9,7 @@ from .dancer import lets_dance_xl
from typing import List from typing import List
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from einops import repeat
@@ -25,9 +26,10 @@ class SDXLImagePipeline(BasePipeline):
self.unet: SDXLUNet = None self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None self.vae_encoder: SDXLVAEEncoder = None
# self.controlnet: MultiControlNetManager = None (TODO) self.controlnet: MultiControlNetManager = None
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
self.ipadapter: SDXLIpAdapter = None self.ipadapter: SDXLIpAdapter = None
self.model_names = ['text_encoder', 'text_encoder_2', 'text_encoder_kolors', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter']
def denoising_model(self): def denoising_model(self):
@@ -43,7 +45,16 @@ class SDXLImagePipeline(BasePipeline):
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder") self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder") self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
# ControlNets (TODO) # ControlNets
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id, device=self.device),
model_manager.fetch_model("sdxl_controlnet", config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units)
# IP-Adapters # IP-Adapters
self.ipadapter = model_manager.fetch_model("sdxl_ipadapter") self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
@@ -61,9 +72,9 @@ class SDXLImagePipeline(BasePipeline):
@staticmethod @staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]): def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
pipe = SDXLImagePipeline( pipe = SDXLImagePipeline(
device=model_manager.device, device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype, torch_dtype=model_manager.torch_dtype,
) )
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes) pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
@@ -93,13 +104,17 @@ class SDXLImagePipeline(BasePipeline):
def prepare_extra_input(self, latents=None): def prepare_extra_input(self, latents=None):
height, width = latents.shape[2] * 8, latents.shape[3] * 8 height, width = latents.shape[2] * 8, latents.shape[3] * 8
return {"add_time_id": torch.tensor([height, width, 0, 0, height, width], device=self.device)} add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device).repeat(latents.shape[0])
return {"add_time_id": add_time_id}
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
prompt, prompt,
local_prompts=[],
masks=[],
mask_scales=[],
negative_prompt="", negative_prompt="",
cfg_scale=7.5, cfg_scale=7.5,
clip_skip=1, clip_skip=1,
@@ -116,9 +131,12 @@ class SDXLImagePipeline(BasePipeline):
tiled=False, tiled=False,
tile_size=64, tile_size=64,
tile_stride=32, tile_stride=32,
seed=None,
progress_bar_cmd=tqdm, progress_bar_cmd=tqdm,
progress_bar_st=None, progress_bar_st=None,
): ):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters # Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
@@ -127,16 +145,19 @@ class SDXLImagePipeline(BasePipeline):
# Prepare latent tensors # Prepare latent tensors
if input_image is not None: if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs) latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else: else:
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# Encode prompts # Encode prompts
self.load_models_to_device(['text_encoder', 'text_encoder_2', 'text_encoder_kolors'])
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False) prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
# IP-Adapter # IP-Adapter
if ipadapter_images is not None: if ipadapter_images is not None:
@@ -144,32 +165,43 @@ class SDXLImagePipeline(BasePipeline):
self.ipadapter.set_less_adapter() self.ipadapter.set_less_adapter()
else: else:
self.ipadapter.set_full_adapter() self.ipadapter.set_full_adapter()
self.load_models_to_device(['ipadapter_image_encoder'])
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images) ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
self.load_models_to_device(['ipadapter'])
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)} ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))} ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else: else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}} ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets (TODO) # Prepare ControlNets
if controlnet_image is not None:
self.load_models_to_device(['controlnet'])
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
controlnet_image = controlnet_image.unsqueeze(1)
controlnet_kwargs = {"controlnet_frames": controlnet_image}
else:
controlnet_kwargs = {"controlnet_frames": None} controlnet_kwargs = {"controlnet_frames": None}
# Prepare extra input # Prepare extra input
extra_input = self.prepare_extra_input(latents) extra_input = self.prepare_extra_input(latents)
# Denoise # Denoise
self.load_models_to_device(['controlnet', 'unet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device) timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance # Classifier-free guidance
noise_pred_posi = lets_dance_xl( inference_callback = lambda prompt_emb_posi: lets_dance_xl(
self.unet, motion_modules=None, controlnet=None, self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, **extra_input, sample=latents, timestep=timestep, **extra_input,
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi, **prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
device=self.device, device=self.device,
) )
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
if cfg_scale != 1.0: if cfg_scale != 1.0:
noise_pred_nega = lets_dance_xl( noise_pred_nega = lets_dance_xl(
self.unet, motion_modules=None, controlnet=None, self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, **extra_input, sample=latents, timestep=timestep, **extra_input,
**prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
device=self.device, device=self.device,
@@ -186,6 +218,9 @@ class SDXLImagePipeline(BasePipeline):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image # Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
# offload all models
self.load_models_to_device([])
return image return image

View File

@@ -120,9 +120,12 @@ class SDXLVideoPipeline(SDXLImagePipeline):
tiled=False, tiled=False,
tile_size=64, tile_size=64,
tile_stride=32, tile_stride=32,
seed=None,
progress_bar_cmd=tqdm, progress_bar_cmd=tqdm,
progress_bar_st=None, progress_bar_st=None,
): ):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters, batch size ... # Tiler parameters, batch size ...
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
@@ -131,9 +134,9 @@ class SDXLVideoPipeline(SDXLImagePipeline):
# Prepare latent tensors # Prepare latent tensors
if self.motion_modules is None: if self.motion_modules is None:
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1) noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
else: else:
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype) noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
if input_frames is None or denoising_strength == 1.0: if input_frames is None or denoising_strength == 1.0:
latents = noise latents = noise
else: else:

View File

@@ -0,0 +1,209 @@
from ..models import ModelManager
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder
from ..models.stepvideo_text_encoder import STEP1TextEncoder
from ..models.stepvideo_dit import StepVideoModel
from ..models.stepvideo_vae import StepVideoVAE
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import StepVideoPrompter
import torch
from einops import rearrange
import numpy as np
from PIL import Image
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
from transformers.models.bert.modeling_bert import BertEmbeddings
from ..models.stepvideo_dit import RMSNorm
from ..models.stepvideo_vae import CausalConv, CausalConvAfterNorm, Upsample2D, BaseGroupNorm
class StepVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler(sigma_min=0.0, extra_one_step=True, shift=13.0, reverse_sigmas=True, num_train_timesteps=1)
self.prompter = StepVideoPrompter()
self.text_encoder_1: HunyuanDiTCLIPTextEncoder = None
self.text_encoder_2: STEP1TextEncoder = None
self.dit: StepVideoModel = None
self.vae: StepVideoVAE = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae']
def enable_vram_management(self, num_persistent_param_in_dit=None):
dtype = next(iter(self.text_encoder_1.parameters())).dtype
enable_vram_management(
self.text_encoder_1,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
BertEmbeddings: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=torch.float32,
computation_device=self.device,
),
)
dtype = next(iter(self.text_encoder_2.parameters())).dtype
enable_vram_management(
self.text_encoder_2,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
RMSNorm: AutoWrappedModule,
torch.nn.Embedding: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.dit.parameters())).dtype
enable_vram_management(
self.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.vae.parameters())).dtype
enable_vram_management(
self.vae,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
CausalConv: AutoWrappedModule,
CausalConvAfterNorm: AutoWrappedModule,
Upsample2D: AutoWrappedModule,
BaseGroupNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
def fetch_models(self, model_manager: ModelManager):
self.text_encoder_1 = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
self.text_encoder_2 = model_manager.fetch_model("stepvideo_text_encoder_2")
self.dit = model_manager.fetch_model("stepvideo_dit")
self.vae = model_manager.fetch_model("stepvideo_vae")
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = StepVideoPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
return pipe
def encode_prompt(self, prompt, positive=True):
clip_embeds, llm_embeds, llm_mask = self.prompter.encode_prompt(prompt, device=self.device, positive=positive)
clip_embeds = clip_embeds.to(dtype=self.torch_dtype, device=self.device)
llm_embeds = llm_embeds.to(dtype=self.torch_dtype, device=self.device)
llm_mask = llm_mask.to(dtype=self.torch_dtype, device=self.device)
return {"encoder_hidden_states_2": clip_embeds, "encoder_hidden_states": llm_embeds, "encoder_attention_mask": llm_mask}
def tensor2video(self, frames):
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
frames = [Image.fromarray(frame) for frame in frames]
return frames
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
input_video=None,
denoising_strength=1.0,
seed=None,
rand_device="cpu",
height=544,
width=992,
num_frames=204,
cfg_scale=9.0,
num_inference_steps=30,
tiled=True,
tile_size=(34, 34),
tile_stride=(16, 16),
smooth_scale=0.6,
progress_bar_cmd=lambda x: x,
progress_bar_st=None,
):
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Initialize noise
latents = self.generate_noise((1, max(num_frames//17*3, 1), 64, height//16, width//16), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
# Encode prompts
self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
# Denoise
self.load_models_to_device(["dit"])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
# Inference
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Decode
self.load_models_to_device(['vae'])
frames = self.vae.decode(latents, device=self.device, smooth_scale=smooth_scale, **tiler_kwargs)
self.load_models_to_device([])
frames = self.tensor2video(frames[0])
return frames

View File

@@ -49,9 +49,9 @@ class SVDVideoPipeline(BasePipeline):
return image_emb return image_emb
def encode_image_with_vae(self, image, noise_aug_strength): def encode_image_with_vae(self, image, noise_aug_strength, seed=None):
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
noise = torch.randn(image.shape, device="cpu", dtype=self.torch_dtype).to(self.device) noise = self.generate_noise(image.shape, seed=seed, device=self.device, dtype=self.torch_dtype)
image = image + noise_aug_strength * noise image = image + noise_aug_strength * noise
image_emb = self.vae_encoder(image) / self.vae_encoder.scaling_factor image_emb = self.vae_encoder(image) / self.vae_encoder.scaling_factor
return image_emb return image_emb
@@ -126,14 +126,17 @@ class SVDVideoPipeline(BasePipeline):
num_inference_steps=20, num_inference_steps=20,
post_normalize=True, post_normalize=True,
contrast_enhance_scale=1.2, contrast_enhance_scale=1.2,
seed=None,
progress_bar_cmd=tqdm, progress_bar_cmd=tqdm,
progress_bar_st=None, progress_bar_st=None,
): ):
height, width = self.check_resize_height_width(height, width)
# Prepare scheduler # Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength) self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
# Prepare latent tensors # Prepare latent tensors
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).to(self.device) noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
if denoising_strength == 1.0: if denoising_strength == 1.0:
latents = noise.clone() latents = noise.clone()
else: else:
@@ -147,7 +150,7 @@ class SVDVideoPipeline(BasePipeline):
# Encode image # Encode image
image_emb_clip_posi = self.encode_image_with_clip(input_image) image_emb_clip_posi = self.encode_image_with_clip(input_image)
image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi) image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi)
image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength), "B C H W -> (B T) C H W", T=num_frames) image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength, seed=seed), "B C H W -> (B T) C H W", T=num_frames)
image_emb_vae_nega = torch.zeros_like(image_emb_vae_posi) image_emb_vae_nega = torch.zeros_like(image_emb_vae_posi)
# Prepare classifier-free guidance # Prepare classifier-free guidance

View File

@@ -1,6 +1,12 @@
from .prompt_refiners import Translator, BeautifulPrompt from .prompt_refiners import Translator, BeautifulPrompt, QwenPrompt
from .sd_prompter import SDPrompter from .sd_prompter import SDPrompter
from .sdxl_prompter import SDXLPrompter from .sdxl_prompter import SDXLPrompter
from .sd3_prompter import SD3Prompter from .sd3_prompter import SD3Prompter
from .hunyuan_dit_prompter import HunyuanDiTPrompter from .hunyuan_dit_prompter import HunyuanDiTPrompter
from .kolors_prompter import KolorsPrompter from .kolors_prompter import KolorsPrompter
from .flux_prompter import FluxPrompter
from .omost import OmostPromter
from .cog_prompter import CogPrompter
from .hunyuan_video_prompter import HunyuanVideoPrompter
from .stepvideo_prompter import StepVideoPrompter
from .wanx_prompter import WanXPrompter

View File

@@ -37,15 +37,21 @@ def tokenize_long_prompt(tokenizer, prompt, max_length=None):
class BasePrompter: class BasePrompter:
def __init__(self, refiners=[]): def __init__(self):
self.refiners = refiners self.refiners = []
self.extenders = []
def load_prompt_refiners(self, model_nameger: ModelManager, refiner_classes=[]): def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
for refiner_class in refiner_classes: for refiner_class in refiner_classes:
refiner = refiner_class.from_model_manager(model_nameger) refiner = refiner_class.from_model_manager(model_manager)
self.refiners.append(refiner) self.refiners.append(refiner)
def load_prompt_extenders(self,model_manager:ModelManager,extender_classes=[]):
for extender_class in extender_classes:
extender = extender_class.from_model_manager(model_manager)
self.extenders.append(extender)
@torch.no_grad() @torch.no_grad()
def process_prompt(self, prompt, positive=True): def process_prompt(self, prompt, positive=True):
@@ -55,3 +61,10 @@ class BasePrompter:
for refiner in self.refiners: for refiner in self.refiners:
prompt = refiner(prompt, positive=positive) prompt = refiner(prompt, positive=positive)
return prompt return prompt
@torch.no_grad()
def extend_prompt(self, prompt:str, positive=True):
extended_prompt = dict(prompt=prompt)
for extender in self.extenders:
extended_prompt = extender(extended_prompt)
return extended_prompt

View File

@@ -0,0 +1,46 @@
from .base_prompter import BasePrompter
from ..models.flux_text_encoder import FluxTextEncoder2
from transformers import T5TokenizerFast
import os
class CogPrompter(BasePrompter):
def __init__(
self,
tokenizer_path=None
):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(base_path, "tokenizer_configs/cog/tokenizer")
super().__init__()
self.tokenizer = T5TokenizerFast.from_pretrained(tokenizer_path)
self.text_encoder: FluxTextEncoder2 = None
def fetch_models(self, text_encoder: FluxTextEncoder2 = None):
self.text_encoder = text_encoder
def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True,
).input_ids.to(device)
prompt_emb = text_encoder(input_ids)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return prompt_emb
def encode_prompt(
self,
prompt,
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder, self.tokenizer, 226, device)
return prompt_emb

View File

@@ -0,0 +1,74 @@
from .base_prompter import BasePrompter
from ..models.flux_text_encoder import FluxTextEncoder2
from ..models.sd3_text_encoder import SD3TextEncoder1
from transformers import CLIPTokenizer, T5TokenizerFast
import os, torch
class FluxPrompter(BasePrompter):
def __init__(
self,
tokenizer_1_path=None,
tokenizer_2_path=None
):
if tokenizer_1_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_1_path = os.path.join(base_path, "tokenizer_configs/flux/tokenizer_1")
if tokenizer_2_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/flux/tokenizer_2")
super().__init__()
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
self.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_path)
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: FluxTextEncoder2 = None
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: FluxTextEncoder2 = None):
self.text_encoder_1 = text_encoder_1
self.text_encoder_2 = text_encoder_2
def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device):
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True
).input_ids.to(device)
pooled_prompt_emb, _ = text_encoder(input_ids)
return pooled_prompt_emb
def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True,
).input_ids.to(device)
prompt_emb = text_encoder(input_ids)
return prompt_emb
def encode_prompt(
self,
prompt,
positive=True,
device="cuda",
t5_sequence_length=512,
):
prompt = self.process_prompt(prompt, positive=positive)
# CLIP
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
# T5
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
# text_ids
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)
return prompt_emb, pooled_prompt_emb, text_ids

View File

@@ -0,0 +1,143 @@
from .base_prompter import BasePrompter
from ..models.sd3_text_encoder import SD3TextEncoder1
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
from transformers import CLIPTokenizer, LlamaTokenizerFast
import os, torch
PROMPT_TEMPLATE_ENCODE = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
PROMPT_TEMPLATE_ENCODE_VIDEO = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
PROMPT_TEMPLATE = {
"dit-llm-encode": {
"template": PROMPT_TEMPLATE_ENCODE,
"crop_start": 36,
},
"dit-llm-encode-video": {
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
"crop_start": 95,
},
}
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
class HunyuanVideoPrompter(BasePrompter):
def __init__(
self,
tokenizer_1_path=None,
tokenizer_2_path=None,
):
if tokenizer_1_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_1_path = os.path.join(
base_path, "tokenizer_configs/hunyuan_video/tokenizer_1")
if tokenizer_2_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_2_path = os.path.join(
base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
super().__init__()
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: HunyuanVideoLLMEncoder = None
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
self.text_encoder_1 = text_encoder_1
self.text_encoder_2 = text_encoder_2
def apply_text_to_template(self, text, template):
assert isinstance(template, str)
if isinstance(text, list):
return [self.apply_text_to_template(text_) for text_ in text]
elif isinstance(text, str):
# Will send string to tokenizer. Used for llm
return template.format(text)
else:
raise TypeError(f"Unsupported prompt type: {type(text)}")
def encode_prompt_using_clip(self, prompt, max_length, device):
tokenized_result = self.tokenizer_1(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True
)
input_ids = tokenized_result.input_ids.to(device)
attention_mask = tokenized_result.attention_mask.to(device)
return self.text_encoder_1(input_ids=input_ids, extra_mask=attention_mask)[0]
def encode_prompt_using_llm(self,
prompt,
max_length,
device,
crop_start,
hidden_state_skip_layer=2,
use_attention_mask=True):
max_length += crop_start
inputs = self.tokenizer_2(prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True)
input_ids = inputs.input_ids.to(device)
attention_mask = inputs.attention_mask.to(device)
last_hidden_state = self.text_encoder_2(input_ids, attention_mask, hidden_state_skip_layer)
# crop out
if crop_start > 0:
last_hidden_state = last_hidden_state[:, crop_start:]
attention_mask = (attention_mask[:, crop_start:] if use_attention_mask else None)
return last_hidden_state, attention_mask
def encode_prompt(self,
prompt,
positive=True,
device="cuda",
clip_sequence_length=77,
llm_sequence_length=256,
data_type='video',
use_template=True,
hidden_state_skip_layer=2,
use_attention_mask=True):
prompt = self.process_prompt(prompt, positive=positive)
# apply template
if use_template:
template = self.prompt_template_video if data_type == 'video' else self.prompt_template
prompt_formated = self.apply_text_to_template(prompt, template['template'])
else:
prompt_formated = prompt
# Text encoder
if data_type == 'video':
crop_start = self.prompt_template_video.get("crop_start", 0)
else:
crop_start = self.prompt_template.get("crop_start", 0)
# CLIP
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
# LLM
prompt_emb, attention_mask = self.encode_prompt_using_llm(
prompt_formated, llm_sequence_length, device, crop_start,
hidden_state_skip_layer, use_attention_mask)
return prompt_emb, pooled_prompt_emb, attention_mask

View File

@@ -245,6 +245,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None, pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None, return_attention_mask: Optional[bool] = None,
padding_side: Optional[str] = None,
) -> dict: ) -> dict:
""" """
Pad encoded inputs (on left/right and up to predefined length or max length in the batch) Pad encoded inputs (on left/right and up to predefined length or max length in the batch)

View File

@@ -0,0 +1,356 @@
import os
import re
from typing import Dict, List
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoTokenizer
from huggingface_hub import snapshot_download
import numpy as np
def crop_arr(pil_image, max_image_size):
while min(*pil_image.size) >= 2 * max_image_size:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
)
if max(*pil_image.size) > max_image_size:
scale = max_image_size / max(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
if min(*pil_image.size) < 16:
scale = 16 / min(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
arr = np.array(pil_image)
crop_y1 = (arr.shape[0] % 16) // 2
crop_y2 = arr.shape[0] % 16 - crop_y1
crop_x1 = (arr.shape[1] % 16) // 2
crop_x2 = arr.shape[1] % 16 - crop_x1
arr = arr[crop_y1:arr.shape[0]-crop_y2, crop_x1:arr.shape[1]-crop_x2]
return Image.fromarray(arr)
class OmniGenPrompter:
def __init__(self,
text_tokenizer,
max_image_size: int=1024):
self.text_tokenizer = text_tokenizer
self.max_image_size = max_image_size
self.image_transform = transforms.Compose([
transforms.Lambda(lambda pil_image: crop_arr(pil_image, max_image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
self.collator = OmniGenCollator()
self.separate_collator = OmniGenSeparateCollator()
@classmethod
def from_pretrained(cls, model_name):
if not os.path.exists(model_name):
cache_folder = os.getenv('HF_HUB_CACHE')
model_name = snapshot_download(repo_id=model_name,
cache_dir=cache_folder,
allow_patterns="*.json")
text_tokenizer = AutoTokenizer.from_pretrained(model_name)
return cls(text_tokenizer)
def process_image(self, image):
return self.image_transform(image)
def process_multi_modal_prompt(self, text, input_images):
text = self.add_prefix_instruction(text)
if input_images is None or len(input_images) == 0:
model_inputs = self.text_tokenizer(text)
return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
pattern = r"<\|image_\d+\|>"
prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
for i in range(1, len(prompt_chunks)):
if prompt_chunks[i][0] == 1:
prompt_chunks[i] = prompt_chunks[i][1:]
image_tags = re.findall(pattern, text)
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
unique_image_ids = sorted(list(set(image_ids)))
assert unique_image_ids == list(range(1, len(unique_image_ids)+1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
# total images must be the same as the number of image tags
assert len(unique_image_ids) == len(input_images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
input_images = [input_images[x-1] for x in image_ids]
all_input_ids = []
img_inx = []
idx = 0
for i in range(len(prompt_chunks)):
all_input_ids.extend(prompt_chunks[i])
if i != len(prompt_chunks) -1:
start_inx = len(all_input_ids)
size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
img_inx.append([start_inx, start_inx+size])
all_input_ids.extend([0]*size)
return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
def add_prefix_instruction(self, prompt):
user_prompt = '<|user|>\n'
generation_prompt = 'Generate an image according to the following instructions\n'
assistant_prompt = '<|assistant|>\n<|diffusion|>'
prompt_suffix = "<|end|>\n"
prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
return prompt
def __call__(self,
instructions: List[str],
input_images: List[List[str]] = None,
height: int = 1024,
width: int = 1024,
negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
use_img_cfg: bool = True,
separate_cfg_input: bool = False,
use_input_image_size_as_output: bool=False,
) -> Dict:
if input_images is None:
use_img_cfg = False
if isinstance(instructions, str):
instructions = [instructions]
input_images = [input_images]
input_data = []
for i in range(len(instructions)):
cur_instruction = instructions[i]
cur_input_images = None if input_images is None else input_images[i]
if cur_input_images is not None and len(cur_input_images) > 0:
cur_input_images = [self.process_image(x) for x in cur_input_images]
else:
cur_input_images = None
assert "<img><|image_1|></img>" not in cur_instruction
mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
neg_mllm_input, img_cfg_mllm_input = None, None
neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
if use_img_cfg:
if cur_input_images is not None and len(cur_input_images) >= 1:
img_cfg_prompt = [f"<img><|image_{i+1}|></img>" for i in range(len(cur_input_images))]
img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
else:
img_cfg_mllm_input = neg_mllm_input
if use_input_image_size_as_output:
input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [mllm_input['pixel_values'][0].size(-2), mllm_input['pixel_values'][0].size(-1)]))
else:
input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
if separate_cfg_input:
return self.separate_collator(input_data)
return self.collator(input_data)
class OmniGenCollator:
def __init__(self, pad_token_id=2, hidden_size=3072):
self.pad_token_id = pad_token_id
self.hidden_size = hidden_size
def create_position(self, attention_mask, num_tokens_for_output_images):
position_ids = []
text_length = attention_mask.size(-1)
img_length = max(num_tokens_for_output_images)
for mask in attention_mask:
temp_l = torch.sum(mask)
temp_position = [0]*(text_length-temp_l) + [i for i in range(temp_l+img_length+1)] # we add a time embedding into the sequence, so add one more token
position_ids.append(temp_position)
return torch.LongTensor(position_ids)
def create_mask(self, attention_mask, num_tokens_for_output_images):
extended_mask = []
padding_images = []
text_length = attention_mask.size(-1)
img_length = max(num_tokens_for_output_images)
seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
inx = 0
for mask in attention_mask:
temp_l = torch.sum(mask)
pad_l = text_length - temp_l
temp_mask = torch.tril(torch.ones(size=(temp_l+1, temp_l+1)))
image_mask = torch.zeros(size=(temp_l+1, img_length))
temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
image_mask = torch.ones(size=(img_length, temp_l+img_length+1))
temp_mask = torch.cat([temp_mask, image_mask], dim=0)
if pad_l > 0:
pad_mask = torch.zeros(size=(temp_l+1+img_length, pad_l))
temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
pad_mask = torch.ones(size=(pad_l, seq_len))
temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
true_img_length = num_tokens_for_output_images[inx]
pad_img_length = img_length - true_img_length
if pad_img_length > 0:
temp_mask[:, -pad_img_length:] = 0
temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
else:
temp_padding_imgs = None
extended_mask.append(temp_mask.unsqueeze(0))
padding_images.append(temp_padding_imgs)
inx += 1
return torch.cat(extended_mask, dim=0), padding_images
def adjust_attention_for_input_images(self, attention_mask, image_sizes):
for b_inx in image_sizes.keys():
for start_inx, end_inx in image_sizes[b_inx]:
attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
return attention_mask
def pad_input_ids(self, input_ids, image_sizes):
max_l = max([len(x) for x in input_ids])
padded_ids = []
attention_mask = []
new_image_sizes = []
for i in range(len(input_ids)):
temp_ids = input_ids[i]
temp_l = len(temp_ids)
pad_l = max_l - temp_l
if pad_l == 0:
attention_mask.append([1]*max_l)
padded_ids.append(temp_ids)
else:
attention_mask.append([0]*pad_l+[1]*temp_l)
padded_ids.append([self.pad_token_id]*pad_l+temp_ids)
if i in image_sizes:
new_inx = []
for old_inx in image_sizes[i]:
new_inx.append([x+pad_l for x in old_inx])
image_sizes[i] = new_inx
return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
def process_mllm_input(self, mllm_inputs, target_img_size):
num_tokens_for_output_images = []
for img_size in target_img_size:
num_tokens_for_output_images.append(img_size[0]*img_size[1]//16//16)
pixel_values, image_sizes = [], {}
b_inx = 0
for x in mllm_inputs:
if x['pixel_values'] is not None:
pixel_values.extend(x['pixel_values'])
for size in x['image_sizes']:
if b_inx not in image_sizes:
image_sizes[b_inx] = [size]
else:
image_sizes[b_inx].append(size)
b_inx += 1
pixel_values = [x.unsqueeze(0) for x in pixel_values]
input_ids = [x['input_ids'] for x in mllm_inputs]
padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
def __call__(self, features):
mllm_inputs = [f[0] for f in features]
cfg_mllm_inputs = [f[1] for f in features]
img_cfg_mllm_input = [f[2] for f in features]
target_img_size = [f[3] for f in features]
if img_cfg_mllm_input[0] is not None:
mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
target_img_size = target_img_size + target_img_size + target_img_size
else:
mllm_inputs = mllm_inputs + cfg_mllm_inputs
target_img_size = target_img_size + target_img_size
all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
data = {"input_ids": all_padded_input_ids,
"attention_mask": all_attention_mask,
"position_ids": all_position_ids,
"input_pixel_values": all_pixel_values,
"input_image_sizes": all_image_sizes,
"padding_images": all_padding_images,
}
return data
class OmniGenSeparateCollator(OmniGenCollator):
def __call__(self, features):
mllm_inputs = [f[0] for f in features]
cfg_mllm_inputs = [f[1] for f in features]
img_cfg_mllm_input = [f[2] for f in features]
target_img_size = [f[3] for f in features]
all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
all_padded_input_ids.append(padded_input_ids)
all_attention_mask.append(attention_mask)
all_position_ids.append(position_ids)
all_pixel_values.append(pixel_values)
all_image_sizes.append(image_sizes)
all_padding_images.append(padding_images)
if cfg_mllm_inputs[0] is not None:
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(cfg_mllm_inputs, target_img_size)
all_padded_input_ids.append(padded_input_ids)
all_attention_mask.append(attention_mask)
all_position_ids.append(position_ids)
all_pixel_values.append(pixel_values)
all_image_sizes.append(image_sizes)
all_padding_images.append(padding_images)
if img_cfg_mllm_input[0] is not None:
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(img_cfg_mllm_input, target_img_size)
all_padded_input_ids.append(padded_input_ids)
all_attention_mask.append(attention_mask)
all_position_ids.append(position_ids)
all_pixel_values.append(pixel_values)
all_image_sizes.append(image_sizes)
all_padding_images.append(padding_images)
data = {"input_ids": all_padded_input_ids,
"attention_mask": all_attention_mask,
"position_ids": all_position_ids,
"input_pixel_values": all_pixel_values,
"input_image_sizes": all_image_sizes,
"padding_images": all_padding_images,
}
return data

View File

@@ -0,0 +1,323 @@
from transformers import AutoTokenizer, TextIteratorStreamer
import difflib
import torch
import numpy as np
import re
from ..models.model_manager import ModelManager
from PIL import Image
valid_colors = { # r, g, b
'aliceblue': (240, 248, 255), 'antiquewhite': (250, 235, 215), 'aqua': (0, 255, 255),
'aquamarine': (127, 255, 212), 'azure': (240, 255, 255), 'beige': (245, 245, 220),
'bisque': (255, 228, 196), 'black': (0, 0, 0), 'blanchedalmond': (255, 235, 205), 'blue': (0, 0, 255),
'blueviolet': (138, 43, 226), 'brown': (165, 42, 42), 'burlywood': (222, 184, 135),
'cadetblue': (95, 158, 160), 'chartreuse': (127, 255, 0), 'chocolate': (210, 105, 30),
'coral': (255, 127, 80), 'cornflowerblue': (100, 149, 237), 'cornsilk': (255, 248, 220),
'crimson': (220, 20, 60), 'cyan': (0, 255, 255), 'darkblue': (0, 0, 139), 'darkcyan': (0, 139, 139),
'darkgoldenrod': (184, 134, 11), 'darkgray': (169, 169, 169), 'darkgrey': (169, 169, 169),
'darkgreen': (0, 100, 0), 'darkkhaki': (189, 183, 107), 'darkmagenta': (139, 0, 139),
'darkolivegreen': (85, 107, 47), 'darkorange': (255, 140, 0), 'darkorchid': (153, 50, 204),
'darkred': (139, 0, 0), 'darksalmon': (233, 150, 122), 'darkseagreen': (143, 188, 143),
'darkslateblue': (72, 61, 139), 'darkslategray': (47, 79, 79), 'darkslategrey': (47, 79, 79),
'darkturquoise': (0, 206, 209), 'darkviolet': (148, 0, 211), 'deeppink': (255, 20, 147),
'deepskyblue': (0, 191, 255), 'dimgray': (105, 105, 105), 'dimgrey': (105, 105, 105),
'dodgerblue': (30, 144, 255), 'firebrick': (178, 34, 34), 'floralwhite': (255, 250, 240),
'forestgreen': (34, 139, 34), 'fuchsia': (255, 0, 255), 'gainsboro': (220, 220, 220),
'ghostwhite': (248, 248, 255), 'gold': (255, 215, 0), 'goldenrod': (218, 165, 32),
'gray': (128, 128, 128), 'grey': (128, 128, 128), 'green': (0, 128, 0), 'greenyellow': (173, 255, 47),
'honeydew': (240, 255, 240), 'hotpink': (255, 105, 180), 'indianred': (205, 92, 92),
'indigo': (75, 0, 130), 'ivory': (255, 255, 240), 'khaki': (240, 230, 140), 'lavender': (230, 230, 250),
'lavenderblush': (255, 240, 245), 'lawngreen': (124, 252, 0), 'lemonchiffon': (255, 250, 205),
'lightblue': (173, 216, 230), 'lightcoral': (240, 128, 128), 'lightcyan': (224, 255, 255),
'lightgoldenrodyellow': (250, 250, 210), 'lightgray': (211, 211, 211), 'lightgrey': (211, 211, 211),
'lightgreen': (144, 238, 144), 'lightpink': (255, 182, 193), 'lightsalmon': (255, 160, 122),
'lightseagreen': (32, 178, 170), 'lightskyblue': (135, 206, 250), 'lightslategray': (119, 136, 153),
'lightslategrey': (119, 136, 153), 'lightsteelblue': (176, 196, 222), 'lightyellow': (255, 255, 224),
'lime': (0, 255, 0), 'limegreen': (50, 205, 50), 'linen': (250, 240, 230), 'magenta': (255, 0, 255),
'maroon': (128, 0, 0), 'mediumaquamarine': (102, 205, 170), 'mediumblue': (0, 0, 205),
'mediumorchid': (186, 85, 211), 'mediumpurple': (147, 112, 219), 'mediumseagreen': (60, 179, 113),
'mediumslateblue': (123, 104, 238), 'mediumspringgreen': (0, 250, 154),
'mediumturquoise': (72, 209, 204), 'mediumvioletred': (199, 21, 133), 'midnightblue': (25, 25, 112),
'mintcream': (245, 255, 250), 'mistyrose': (255, 228, 225), 'moccasin': (255, 228, 181),
'navajowhite': (255, 222, 173), 'navy': (0, 0, 128), 'navyblue': (0, 0, 128),
'oldlace': (253, 245, 230), 'olive': (128, 128, 0), 'olivedrab': (107, 142, 35),
'orange': (255, 165, 0), 'orangered': (255, 69, 0), 'orchid': (218, 112, 214),
'palegoldenrod': (238, 232, 170), 'palegreen': (152, 251, 152), 'paleturquoise': (175, 238, 238),
'palevioletred': (219, 112, 147), 'papayawhip': (255, 239, 213), 'peachpuff': (255, 218, 185),
'peru': (205, 133, 63), 'pink': (255, 192, 203), 'plum': (221, 160, 221), 'powderblue': (176, 224, 230),
'purple': (128, 0, 128), 'rebeccapurple': (102, 51, 153), 'red': (255, 0, 0),
'rosybrown': (188, 143, 143), 'royalblue': (65, 105, 225), 'saddlebrown': (139, 69, 19),
'salmon': (250, 128, 114), 'sandybrown': (244, 164, 96), 'seagreen': (46, 139, 87),
'seashell': (255, 245, 238), 'sienna': (160, 82, 45), 'silver': (192, 192, 192),
'skyblue': (135, 206, 235), 'slateblue': (106, 90, 205), 'slategray': (112, 128, 144),
'slategrey': (112, 128, 144), 'snow': (255, 250, 250), 'springgreen': (0, 255, 127),
'steelblue': (70, 130, 180), 'tan': (210, 180, 140), 'teal': (0, 128, 128), 'thistle': (216, 191, 216),
'tomato': (255, 99, 71), 'turquoise': (64, 224, 208), 'violet': (238, 130, 238),
'wheat': (245, 222, 179), 'white': (255, 255, 255), 'whitesmoke': (245, 245, 245),
'yellow': (255, 255, 0), 'yellowgreen': (154, 205, 50)
}
valid_locations = { # x, y in 90*90
'in the center': (45, 45),
'on the left': (15, 45),
'on the right': (75, 45),
'on the top': (45, 15),
'on the bottom': (45, 75),
'on the top-left': (15, 15),
'on the top-right': (75, 15),
'on the bottom-left': (15, 75),
'on the bottom-right': (75, 75)
}
valid_offsets = { # x, y in 90*90
'no offset': (0, 0),
'slightly to the left': (-10, 0),
'slightly to the right': (10, 0),
'slightly to the upper': (0, -10),
'slightly to the lower': (0, 10),
'slightly to the upper-left': (-10, -10),
'slightly to the upper-right': (10, -10),
'slightly to the lower-left': (-10, 10),
'slightly to the lower-right': (10, 10)}
valid_areas = { # w, h in 90*90
"a small square area": (50, 50),
"a small vertical area": (40, 60),
"a small horizontal area": (60, 40),
"a medium-sized square area": (60, 60),
"a medium-sized vertical area": (50, 80),
"a medium-sized horizontal area": (80, 50),
"a large square area": (70, 70),
"a large vertical area": (60, 90),
"a large horizontal area": (90, 60)
}
def safe_str(x):
return x.strip(',. ') + '.'
def closest_name(input_str, options):
input_str = input_str.lower()
closest_match = difflib.get_close_matches(input_str, list(options.keys()), n=1, cutoff=0.5)
assert isinstance(closest_match, list) and len(closest_match) > 0, f'The value [{input_str}] is not valid!'
result = closest_match[0]
if result != input_str:
print(f'Automatically corrected [{input_str}] -> [{result}].')
return result
class Canvas:
@staticmethod
def from_bot_response(response: str):
matched = re.search(r'```python\n(.*?)\n```', response, re.DOTALL)
assert matched, 'Response does not contain codes!'
code_content = matched.group(1)
assert 'canvas = Canvas()' in code_content, 'Code block must include valid canvas var!'
local_vars = {'Canvas': Canvas}
exec(code_content, {}, local_vars)
canvas = local_vars.get('canvas', None)
assert isinstance(canvas, Canvas), 'Code block must produce valid canvas var!'
return canvas
def __init__(self):
self.components = []
self.color = None
self.record_tags = True
self.prefixes = []
self.suffixes = []
return
def set_global_description(self, description: str, detailed_descriptions: list, tags: str,
HTML_web_color_name: str):
assert isinstance(description, str), 'Global description is not valid!'
assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
'Global detailed_descriptions is not valid!'
assert isinstance(tags, str), 'Global tags is not valid!'
HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
self.color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
self.prefixes = [description]
self.suffixes = detailed_descriptions
if self.record_tags:
self.suffixes = self.suffixes + [tags]
self.prefixes = [safe_str(x) for x in self.prefixes]
self.suffixes = [safe_str(x) for x in self.suffixes]
return
def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str,
detailed_descriptions: list, tags: str, atmosphere: str, style: str,
quality_meta: str, HTML_web_color_name: str):
assert isinstance(description, str), 'Local description is wrong!'
assert isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0, \
f'The distance_to_viewer for [{description}] is not positive float number!'
assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
f'The detailed_descriptions for [{description}] is not valid!'
assert isinstance(tags, str), f'The tags for [{description}] is not valid!'
assert isinstance(atmosphere, str), f'The atmosphere for [{description}] is not valid!'
assert isinstance(style, str), f'The style for [{description}] is not valid!'
assert isinstance(quality_meta, str), f'The quality_meta for [{description}] is not valid!'
location = closest_name(location, valid_locations)
offset = closest_name(offset, valid_offsets)
area = closest_name(area, valid_areas)
HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
xb, yb = valid_locations[location]
xo, yo = valid_offsets[offset]
w, h = valid_areas[area]
rect = (yb + yo - h // 2, yb + yo + h // 2, xb + xo - w // 2, xb + xo + w // 2)
rect = [max(0, min(90, i)) for i in rect]
color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
prefixes = self.prefixes + [description]
suffixes = detailed_descriptions
if self.record_tags:
suffixes = suffixes + [tags, atmosphere, style, quality_meta]
prefixes = [safe_str(x) for x in prefixes]
suffixes = [safe_str(x) for x in suffixes]
self.components.append(dict(
rect=rect,
distance_to_viewer=distance_to_viewer,
color=color,
prefixes=prefixes,
suffixes=suffixes,
location=location,
))
return
def process(self):
# sort components
self.components = sorted(self.components, key=lambda x: x['distance_to_viewer'], reverse=True)
# compute initial latent
# print(self.color)
initial_latent = np.zeros(shape=(90, 90, 3), dtype=np.float32) + self.color
for component in self.components:
a, b, c, d = component['rect']
initial_latent[a:b, c:d] = 0.7 * component['color'] + 0.3 * initial_latent[a:b, c:d]
initial_latent = initial_latent.clip(0, 255).astype(np.uint8)
# compute conditions
bag_of_conditions = [
dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes,location= "full")
]
for i, component in enumerate(self.components):
a, b, c, d = component['rect']
m = np.zeros(shape=(90, 90), dtype=np.float32)
m[a:b, c:d] = 1.0
bag_of_conditions.append(dict(
mask = m,
prefixes = component['prefixes'],
suffixes = component['suffixes'],
location = component['location'],
))
return dict(
initial_latent = initial_latent,
bag_of_conditions = bag_of_conditions,
)
class OmostPromter(torch.nn.Module):
def __init__(self,model = None,tokenizer = None, template = "",device="cpu"):
super().__init__()
self.model=model
self.tokenizer = tokenizer
self.device = device
if template == "":
template = r'''You are a helpful AI assistant to compose images using the below python class `Canvas`:
```python
class Canvas:
def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, HTML_web_color_name: str):
pass
def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, quality_meta: str, HTML_web_color_name: str):
assert location in ["in the center", "on the left", "on the right", "on the top", "on the bottom", "on the top-left", "on the top-right", "on the bottom-left", "on the bottom-right"]
assert offset in ["no offset", "slightly to the left", "slightly to the right", "slightly to the upper", "slightly to the lower", "slightly to the upper-left", "slightly to the upper-right", "slightly to the lower-left", "slightly to the lower-right"]
assert area in ["a small square area", "a small vertical area", "a small horizontal area", "a medium-sized square area", "a medium-sized vertical area", "a medium-sized horizontal area", "a large square area", "a large vertical area", "a large horizontal area"]
assert distance_to_viewer > 0
pass
```'''
self.template = template
@staticmethod
def from_model_manager(model_manager: ModelManager):
model, model_path = model_manager.fetch_model("omost_prompt", require_model_path=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)
omost = OmostPromter(
model= model,
tokenizer = tokenizer,
device = model_manager.device
)
return omost
def __call__(self,prompt_dict:dict):
raw_prompt=prompt_dict["prompt"]
conversation = [{"role": "system", "content": self.template}]
conversation.append({"role": "user", "content": raw_prompt})
input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True).to(self.device)
streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
attention_mask = torch.ones(input_ids.shape, dtype=torch.bfloat16, device=self.device)
generate_kwargs = dict(
input_ids = input_ids,
streamer = streamer,
# stopping_criteria=stopping_criteria,
# max_new_tokens=max_new_tokens,
do_sample = True,
attention_mask = attention_mask,
pad_token_id = self.tokenizer.eos_token_id,
# temperature=temperature,
# top_p=top_p,
)
self.model.generate(**generate_kwargs)
outputs = []
for text in streamer:
outputs.append(text)
llm_outputs = "".join(outputs)
canvas = Canvas.from_bot_response(llm_outputs)
canvas_output = canvas.process()
prompts = [" ".join(_["prefixes"]+_["suffixes"][:2]) for _ in canvas_output["bag_of_conditions"]]
canvas_output["prompt"] = prompts[0]
canvas_output["prompts"] = prompts[1:]
raw_masks = [_["mask"] for _ in canvas_output["bag_of_conditions"]]
masks=[]
for mask in raw_masks:
mask[mask>0.5]=255
mask = np.stack([mask] * 3, axis=-1).astype("uint8")
masks.append(Image.fromarray(mask))
canvas_output["masks"] = masks
prompt_dict.update(canvas_output)
print(f"Your prompt is extended by Omost:\n")
cnt = 0
for component,pmt in zip(canvas_output["bag_of_conditions"],prompts):
loc = component["location"]
cnt += 1
print(f"Component {cnt} - Location : {loc}\nPrompt:{pmt}\n")
return prompt_dict

View File

@@ -1,8 +1,7 @@
from transformers import AutoTokenizer from transformers import AutoTokenizer
from ..models.model_manager import ModelManager from ..models.model_manager import ModelManager
import torch import torch
from .omost import OmostPromter
class BeautifulPrompt(torch.nn.Module): class BeautifulPrompt(torch.nn.Module):
def __init__(self, tokenizer_path=None, model=None, template=""): def __init__(self, tokenizer_path=None, model=None, template=""):
@@ -13,8 +12,8 @@ class BeautifulPrompt(torch.nn.Module):
@staticmethod @staticmethod
def from_model_manager(model_nameger: ModelManager): def from_model_manager(model_manager: ModelManager):
model, model_path = model_nameger.fetch_model("beautiful_prompt", require_model_path=True) model, model_path = model_manager.fetch_model("beautiful_prompt", require_model_path=True)
template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:' template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
if model_path.endswith("v2"): if model_path.endswith("v2"):
template = """Converts a simple image description into a prompt. \ template = """Converts a simple image description into a prompt. \
@@ -55,6 +54,60 @@ but make sure there is a correlation between the input and output.\n\
class QwenPrompt(torch.nn.Module):
# This class leverages the open-source Qwen model to translate Chinese prompts into English,
# with an integrated optimization mechanism for enhanced translation quality.
def __init__(self, tokenizer_path=None, model=None, system_prompt=""):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
self.system_prompt = system_prompt
@staticmethod
def from_model_manager(model_nameger: ModelManager):
model, model_path = model_nameger.fetch_model("qwen_prompt", require_model_path=True)
system_prompt = """You are an English image describer. Here are some example image styles:\n\n1. Extreme close-up: Clear focus on a single object with a blurred background, highlighted under natural sunlight.\n2. Vintage: A photograph of a historical scene, using techniques such as Daguerreotype or cyanotype.\n3. Anime: A stylized cartoon image, emphasizing hyper-realistic portraits and luminous brushwork.\n4. Candid: A natural, unposed shot capturing spontaneous moments, often with cinematic qualities.\n5. Landscape: A photorealistic image of natural scenery, such as a sunrise over the sea.\n6. Design: Colorful and detailed illustrations, often in the style of 2D game art or botanical illustrations.\n7. Urban: An ultrarealistic scene in a modern setting, possibly a cityscape viewed from indoors.\n\nYour task is to translate a given Chinese image description into a concise and precise English description. Ensure that the imagery is vivid and descriptive, and include stylistic elements to enrich the description.\nPlease note the following points:\n\n1. Capture the essence and mood of the Chinese description without including direct phrases or words from the examples provided.\n2. You should add appropriate words to make the images described in the prompt more aesthetically pleasing. If the Chinese description does not specify a style, you need to add some stylistic descriptions based on the essence of the Chinese text.\n3. The generated English description should not exceed 200 words.\n\n"""
qwen_prompt = QwenPrompt(
tokenizer_path=model_path,
model=model,
system_prompt=system_prompt
)
return qwen_prompt
def __call__(self, raw_prompt, positive=True, **kwargs):
if positive:
messages = [{
'role': 'system',
'content': self.system_prompt
}, {
'role': 'user',
'content': raw_prompt
}]
text = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
generated_ids = self.model.generate(
model_inputs.input_ids,
max_new_tokens=512
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(f"Your prompt is refined by Qwen: {prompt}")
return prompt
else:
return raw_prompt
class Translator(torch.nn.Module): class Translator(torch.nn.Module):
def __init__(self, tokenizer_path=None, model=None): def __init__(self, tokenizer_path=None, model=None):
super().__init__() super().__init__()
@@ -63,8 +116,8 @@ class Translator(torch.nn.Module):
@staticmethod @staticmethod
def from_model_manager(model_nameger: ModelManager): def from_model_manager(model_manager: ModelManager):
model, model_path = model_nameger.fetch_model("translator", require_model_path=True) model, model_path = model_manager.fetch_model("translator", require_model_path=True)
translator = Translator(tokenizer_path=model_path, model=model) translator = Translator(tokenizer_path=model_path, model=model)
return translator return translator

View File

@@ -67,7 +67,8 @@ class SD3Prompter(BasePrompter):
self, self,
prompt, prompt,
positive=True, positive=True,
device="cuda" device="cuda",
t5_sequence_length=77,
): ):
prompt = self.process_prompt(prompt, positive=positive) prompt = self.process_prompt(prompt, positive=positive)
@@ -77,9 +78,9 @@ class SD3Prompter(BasePrompter):
# T5 # T5
if self.text_encoder_3 is None: if self.text_encoder_3 is None:
prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], 256, 4096), dtype=prompt_emb_1.dtype, device=device) prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], t5_sequence_length, 4096), dtype=prompt_emb_1.dtype, device=device)
else: else:
prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, 256, device) prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, t5_sequence_length, device)
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16 prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
# Merge # Merge

View File

@@ -1,5 +1,5 @@
from .base_prompter import BasePrompter, tokenize_long_prompt from .base_prompter import BasePrompter, tokenize_long_prompt
from ..models.model_manager import ModelManager, load_state_dict, search_for_embeddings from ..models.utils import load_state_dict, search_for_embeddings
from ..models import SDTextEncoder from ..models import SDTextEncoder
from transformers import CLIPTokenizer from transformers import CLIPTokenizer
import torch, os import torch, os

View File

@@ -0,0 +1,56 @@
from .base_prompter import BasePrompter
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder
from ..models.stepvideo_text_encoder import STEP1TextEncoder
from transformers import BertTokenizer
import os, torch
class StepVideoPrompter(BasePrompter):
def __init__(
self,
tokenizer_1_path=None,
):
if tokenizer_1_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_1_path = os.path.join(
base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
super().__init__()
self.tokenizer_1 = BertTokenizer.from_pretrained(tokenizer_1_path)
def fetch_models(self, text_encoder_1: HunyuanDiTCLIPTextEncoder = None, text_encoder_2: STEP1TextEncoder = None):
self.text_encoder_1 = text_encoder_1
self.text_encoder_2 = text_encoder_2
def encode_prompt_using_clip(self, prompt, max_length, device):
text_inputs = self.tokenizer_1(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
prompt_embeds = self.text_encoder_1(
text_inputs.input_ids.to(device),
attention_mask=text_inputs.attention_mask.to(device),
)
return prompt_embeds
def encode_prompt_using_llm(self, prompt, max_length, device):
y, y_mask = self.text_encoder_2(prompt, max_length=max_length, device=device)
return y, y_mask
def encode_prompt(self,
prompt,
positive=True,
device="cuda"):
prompt = self.process_prompt(prompt, positive=positive)
clip_embeds = self.encode_prompt_using_clip(prompt, max_length=77, device=device)
llm_embeds, llm_mask = self.encode_prompt_using_llm(prompt, max_length=320, device=device)
llm_mask = torch.nn.functional.pad(llm_mask, (clip_embeds.shape[1], 0), value=1)
return clip_embeds, llm_embeds, llm_mask

View File

@@ -0,0 +1,103 @@
from .base_prompter import BasePrompter
from ..models.wanx_text_encoder import WanXTextEncoder
from transformers import AutoTokenizer
import os, torch
import ftfy
import html
import string
import regex as re
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
def canonicalize(text, keep_punctuation_exact_string=None):
text = text.replace('_', ' ')
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(
part.translate(str.maketrans('', '', string.punctuation))
for part in text.split(keep_punctuation_exact_string))
else:
text = text.translate(str.maketrans('', '', string.punctuation))
text = text.lower()
text = re.sub(r'\s+', ' ', text)
return text.strip()
class HuggingfaceTokenizer:
def __init__(self, name, seq_len=None, clean=None, **kwargs):
assert clean in (None, 'whitespace', 'lower', 'canonicalize')
self.name = name
self.seq_len = seq_len
self.clean = clean
# init tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
self.vocab_size = self.tokenizer.vocab_size
def __call__(self, sequence, **kwargs):
return_mask = kwargs.pop('return_mask', False)
# arguments
_kwargs = {'return_tensors': 'pt'}
if self.seq_len is not None:
_kwargs.update({
'padding': 'max_length',
'truncation': True,
'max_length': self.seq_len
})
_kwargs.update(**kwargs)
# tokenization
if isinstance(sequence, str):
sequence = [sequence]
if self.clean:
sequence = [self._clean(u) for u in sequence]
ids = self.tokenizer(sequence, **_kwargs)
# output
if return_mask:
return ids.input_ids, ids.attention_mask
else:
return ids.input_ids
def _clean(self, text):
if self.clean == 'whitespace':
text = whitespace_clean(basic_clean(text))
elif self.clean == 'lower':
text = whitespace_clean(basic_clean(text)).lower()
elif self.clean == 'canonicalize':
text = canonicalize(basic_clean(text))
return text
class WanXPrompter(BasePrompter):
def __init__(self, tokenizer_path=None, text_len=512):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(
base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
super().__init__()
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean='whitespace')
self.text_encoder = None
def fetch_models(self, text_encoder: WanXTextEncoder = None):
self.text_encoder = text_encoder
def encode_prompt(self, prompt, device="cuda"):
ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
ids = ids.to(device)
mask = mask.to(device)
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_emb = self.text_encoder(ids, mask)
prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
return prompt_emb

View File

@@ -10,7 +10,7 @@ class ContinuousODEScheduler():
self.set_timesteps(num_inference_steps) self.set_timesteps(num_inference_steps)
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0): def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, **kwargs):
ramp = torch.linspace(1-denoising_strength, 1, num_inference_steps) ramp = torch.linspace(1-denoising_strength, 1, num_inference_steps)
min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho)) min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho))
max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho)) max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho))

View File

@@ -3,7 +3,7 @@ import torch, math
class EnhancedDDIMScheduler(): class EnhancedDDIMScheduler():
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon"): def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon", rescale_zero_terminal_snr=False):
self.num_train_timesteps = num_train_timesteps self.num_train_timesteps = num_train_timesteps
if beta_schedule == "scaled_linear": if beta_schedule == "scaled_linear":
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32)) betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
@@ -11,12 +11,34 @@ class EnhancedDDIMScheduler():
betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
else: else:
raise NotImplementedError(f"{beta_schedule} is not implemented") raise NotImplementedError(f"{beta_schedule} is not implemented")
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist() self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
if rescale_zero_terminal_snr:
self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
self.alphas_cumprod = self.alphas_cumprod.tolist()
self.set_timesteps(10) self.set_timesteps(10)
self.prediction_type = prediction_type self.prediction_type = prediction_type
def set_timesteps(self, num_inference_steps, denoising_strength=1.0): def rescale_zero_terminal_snr(self, alphas_cumprod):
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt.square() # Revert sqrt
return alphas_bar
def set_timesteps(self, num_inference_steps, denoising_strength=1.0, **kwargs):
# The timesteps are aligned to 999...0, which is different from other implementations, # The timesteps are aligned to 999...0, which is different from other implementations,
# but I think this implementation is more reasonable in theory. # but I think this implementation is more reasonable in theory.
max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0) max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
@@ -77,3 +99,7 @@ class EnhancedDDIMScheduler():
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])]) sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return target return target
def training_weight(self, timestep):
return 1.0

View File

@@ -4,19 +4,35 @@ import torch
class FlowMatchScheduler(): class FlowMatchScheduler():
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002): def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
self.num_train_timesteps = num_train_timesteps self.num_train_timesteps = num_train_timesteps
self.shift = shift self.shift = shift
self.sigma_max = sigma_max self.sigma_max = sigma_max
self.sigma_min = sigma_min self.sigma_min = sigma_min
self.inverse_timesteps = inverse_timesteps
self.extra_one_step = extra_one_step
self.reverse_sigmas = reverse_sigmas
self.set_timesteps(num_inference_steps) self.set_timesteps(num_inference_steps)
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0): def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
if self.extra_one_step:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
else:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps) self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
if self.inverse_timesteps:
self.sigmas = torch.flip(self.sigmas, dims=[0])
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas) self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
if self.reverse_sigmas:
self.sigmas = 1 - self.sigmas
self.timesteps = self.sigmas * self.num_train_timesteps self.timesteps = self.sigmas * self.num_train_timesteps
if training:
x = self.timesteps
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
self.linear_timesteps_weights = bsmntw_weighing
def step(self, model_output, timestep, sample, to_final=False): def step(self, model_output, timestep, sample, to_final=False):
@@ -25,7 +41,7 @@ class FlowMatchScheduler():
timestep_id = torch.argmin((self.timesteps - timestep).abs()) timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id] sigma = self.sigmas[timestep_id]
if to_final or timestep_id + 1 >= len(self.timesteps): if to_final or timestep_id + 1 >= len(self.timesteps):
sigma_ = 0 sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
else: else:
sigma_ = self.sigmas[timestep_id + 1] sigma_ = self.sigmas[timestep_id + 1]
prev_sample = sample + model_output * (sigma_ - sigma) prev_sample = sample + model_output * (sigma_ - sigma)
@@ -33,8 +49,12 @@ class FlowMatchScheduler():
def return_to_timestep(self, timestep, sample, sample_stablized): def return_to_timestep(self, timestep, sample, sample_stablized):
# This scheduler doesn't support this function. if isinstance(timestep, torch.Tensor):
pass timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
model_output = (sample - sample_stablized) / sigma
return model_output
def add_noise(self, original_samples, noise, timestep): def add_noise(self, original_samples, noise, timestep):
@@ -49,3 +69,9 @@ class FlowMatchScheduler():
def training_target(self, sample, noise, timestep): def training_target(self, sample, noise, timestep):
target = noise - sample target = noise - sample
return target return target
def training_weight(self, timestep):
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
weights = self.linear_timesteps_weights[timestep_id]
return weights

View File

@@ -0,0 +1,102 @@
{
"<extra_id_0>": 32099,
"<extra_id_10>": 32089,
"<extra_id_11>": 32088,
"<extra_id_12>": 32087,
"<extra_id_13>": 32086,
"<extra_id_14>": 32085,
"<extra_id_15>": 32084,
"<extra_id_16>": 32083,
"<extra_id_17>": 32082,
"<extra_id_18>": 32081,
"<extra_id_19>": 32080,
"<extra_id_1>": 32098,
"<extra_id_20>": 32079,
"<extra_id_21>": 32078,
"<extra_id_22>": 32077,
"<extra_id_23>": 32076,
"<extra_id_24>": 32075,
"<extra_id_25>": 32074,
"<extra_id_26>": 32073,
"<extra_id_27>": 32072,
"<extra_id_28>": 32071,
"<extra_id_29>": 32070,
"<extra_id_2>": 32097,
"<extra_id_30>": 32069,
"<extra_id_31>": 32068,
"<extra_id_32>": 32067,
"<extra_id_33>": 32066,
"<extra_id_34>": 32065,
"<extra_id_35>": 32064,
"<extra_id_36>": 32063,
"<extra_id_37>": 32062,
"<extra_id_38>": 32061,
"<extra_id_39>": 32060,
"<extra_id_3>": 32096,
"<extra_id_40>": 32059,
"<extra_id_41>": 32058,
"<extra_id_42>": 32057,
"<extra_id_43>": 32056,
"<extra_id_44>": 32055,
"<extra_id_45>": 32054,
"<extra_id_46>": 32053,
"<extra_id_47>": 32052,
"<extra_id_48>": 32051,
"<extra_id_49>": 32050,
"<extra_id_4>": 32095,
"<extra_id_50>": 32049,
"<extra_id_51>": 32048,
"<extra_id_52>": 32047,
"<extra_id_53>": 32046,
"<extra_id_54>": 32045,
"<extra_id_55>": 32044,
"<extra_id_56>": 32043,
"<extra_id_57>": 32042,
"<extra_id_58>": 32041,
"<extra_id_59>": 32040,
"<extra_id_5>": 32094,
"<extra_id_60>": 32039,
"<extra_id_61>": 32038,
"<extra_id_62>": 32037,
"<extra_id_63>": 32036,
"<extra_id_64>": 32035,
"<extra_id_65>": 32034,
"<extra_id_66>": 32033,
"<extra_id_67>": 32032,
"<extra_id_68>": 32031,
"<extra_id_69>": 32030,
"<extra_id_6>": 32093,
"<extra_id_70>": 32029,
"<extra_id_71>": 32028,
"<extra_id_72>": 32027,
"<extra_id_73>": 32026,
"<extra_id_74>": 32025,
"<extra_id_75>": 32024,
"<extra_id_76>": 32023,
"<extra_id_77>": 32022,
"<extra_id_78>": 32021,
"<extra_id_79>": 32020,
"<extra_id_7>": 32092,
"<extra_id_80>": 32019,
"<extra_id_81>": 32018,
"<extra_id_82>": 32017,
"<extra_id_83>": 32016,
"<extra_id_84>": 32015,
"<extra_id_85>": 32014,
"<extra_id_86>": 32013,
"<extra_id_87>": 32012,
"<extra_id_88>": 32011,
"<extra_id_89>": 32010,
"<extra_id_8>": 32091,
"<extra_id_90>": 32009,
"<extra_id_91>": 32008,
"<extra_id_92>": 32007,
"<extra_id_93>": 32006,
"<extra_id_94>": 32005,
"<extra_id_95>": 32004,
"<extra_id_96>": 32003,
"<extra_id_97>": 32002,
"<extra_id_98>": 32001,
"<extra_id_99>": 32000,
"<extra_id_9>": 32090
}

View File

@@ -0,0 +1,125 @@
{
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>"
],
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

Binary file not shown.

View File

@@ -0,0 +1,940 @@
{
"add_prefix_space": true,
"added_tokens_decoder": {
"0": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32000": {
"content": "<extra_id_99>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32001": {
"content": "<extra_id_98>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32002": {
"content": "<extra_id_97>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32003": {
"content": "<extra_id_96>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32004": {
"content": "<extra_id_95>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32005": {
"content": "<extra_id_94>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32006": {
"content": "<extra_id_93>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32007": {
"content": "<extra_id_92>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32008": {
"content": "<extra_id_91>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32009": {
"content": "<extra_id_90>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32010": {
"content": "<extra_id_89>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32011": {
"content": "<extra_id_88>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32012": {
"content": "<extra_id_87>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32013": {
"content": "<extra_id_86>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32014": {
"content": "<extra_id_85>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32015": {
"content": "<extra_id_84>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32016": {
"content": "<extra_id_83>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32017": {
"content": "<extra_id_82>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32018": {
"content": "<extra_id_81>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32019": {
"content": "<extra_id_80>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32020": {
"content": "<extra_id_79>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32021": {
"content": "<extra_id_78>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32022": {
"content": "<extra_id_77>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32023": {
"content": "<extra_id_76>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32024": {
"content": "<extra_id_75>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32025": {
"content": "<extra_id_74>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32026": {
"content": "<extra_id_73>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32027": {
"content": "<extra_id_72>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32028": {
"content": "<extra_id_71>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32029": {
"content": "<extra_id_70>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32030": {
"content": "<extra_id_69>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32031": {
"content": "<extra_id_68>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32032": {
"content": "<extra_id_67>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32033": {
"content": "<extra_id_66>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32034": {
"content": "<extra_id_65>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32035": {
"content": "<extra_id_64>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32036": {
"content": "<extra_id_63>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32037": {
"content": "<extra_id_62>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32038": {
"content": "<extra_id_61>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32039": {
"content": "<extra_id_60>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32040": {
"content": "<extra_id_59>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32041": {
"content": "<extra_id_58>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32042": {
"content": "<extra_id_57>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32043": {
"content": "<extra_id_56>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32044": {
"content": "<extra_id_55>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32045": {
"content": "<extra_id_54>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32046": {
"content": "<extra_id_53>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32047": {
"content": "<extra_id_52>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32048": {
"content": "<extra_id_51>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32049": {
"content": "<extra_id_50>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32050": {
"content": "<extra_id_49>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32051": {
"content": "<extra_id_48>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32052": {
"content": "<extra_id_47>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32053": {
"content": "<extra_id_46>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32054": {
"content": "<extra_id_45>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32055": {
"content": "<extra_id_44>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32056": {
"content": "<extra_id_43>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32057": {
"content": "<extra_id_42>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32058": {
"content": "<extra_id_41>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32059": {
"content": "<extra_id_40>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32060": {
"content": "<extra_id_39>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32061": {
"content": "<extra_id_38>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32062": {
"content": "<extra_id_37>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32063": {
"content": "<extra_id_36>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32064": {
"content": "<extra_id_35>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32065": {
"content": "<extra_id_34>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32066": {
"content": "<extra_id_33>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32067": {
"content": "<extra_id_32>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32068": {
"content": "<extra_id_31>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32069": {
"content": "<extra_id_30>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32070": {
"content": "<extra_id_29>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32071": {
"content": "<extra_id_28>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32072": {
"content": "<extra_id_27>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32073": {
"content": "<extra_id_26>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32074": {
"content": "<extra_id_25>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32075": {
"content": "<extra_id_24>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32076": {
"content": "<extra_id_23>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32077": {
"content": "<extra_id_22>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32078": {
"content": "<extra_id_21>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32079": {
"content": "<extra_id_20>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32080": {
"content": "<extra_id_19>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32081": {
"content": "<extra_id_18>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32082": {
"content": "<extra_id_17>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32083": {
"content": "<extra_id_16>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32084": {
"content": "<extra_id_15>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32085": {
"content": "<extra_id_14>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32086": {
"content": "<extra_id_13>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32087": {
"content": "<extra_id_12>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32088": {
"content": "<extra_id_11>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32089": {
"content": "<extra_id_10>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32090": {
"content": "<extra_id_9>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32091": {
"content": "<extra_id_8>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32092": {
"content": "<extra_id_7>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32093": {
"content": "<extra_id_6>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32094": {
"content": "<extra_id_5>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32095": {
"content": "<extra_id_4>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32096": {
"content": "<extra_id_3>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32097": {
"content": "<extra_id_2>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32098": {
"content": "<extra_id_1>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32099": {
"content": "<extra_id_0>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>"
],
"clean_up_tokenization_spaces": true,
"eos_token": "</s>",
"extra_ids": 100,
"legacy": true,
"model_max_length": 226,
"pad_token": "<pad>",
"sp_model_kwargs": {},
"tokenizer_class": "T5Tokenizer",
"unk_token": "<unk>"
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,30 @@
{
"bos_token": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

View File

@@ -0,0 +1,30 @@
{
"add_prefix_space": false,
"added_tokens_decoder": {
"49406": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
},
"49407": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"bos_token": "<|startoftext|>",
"clean_up_tokenization_spaces": true,
"do_lower_case": true,
"eos_token": "<|endoftext|>",
"errors": "replace",
"model_max_length": 77,
"pad_token": "<|endoftext|>",
"tokenizer_class": "CLIPTokenizer",
"unk_token": "<|endoftext|>"
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,125 @@
{
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>"
],
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,940 @@
{
"add_prefix_space": true,
"added_tokens_decoder": {
"0": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32000": {
"content": "<extra_id_99>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32001": {
"content": "<extra_id_98>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32002": {
"content": "<extra_id_97>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32003": {
"content": "<extra_id_96>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32004": {
"content": "<extra_id_95>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32005": {
"content": "<extra_id_94>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32006": {
"content": "<extra_id_93>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32007": {
"content": "<extra_id_92>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32008": {
"content": "<extra_id_91>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32009": {
"content": "<extra_id_90>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32010": {
"content": "<extra_id_89>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32011": {
"content": "<extra_id_88>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32012": {
"content": "<extra_id_87>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32013": {
"content": "<extra_id_86>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32014": {
"content": "<extra_id_85>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32015": {
"content": "<extra_id_84>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32016": {
"content": "<extra_id_83>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32017": {
"content": "<extra_id_82>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32018": {
"content": "<extra_id_81>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32019": {
"content": "<extra_id_80>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32020": {
"content": "<extra_id_79>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32021": {
"content": "<extra_id_78>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32022": {
"content": "<extra_id_77>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32023": {
"content": "<extra_id_76>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32024": {
"content": "<extra_id_75>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32025": {
"content": "<extra_id_74>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32026": {
"content": "<extra_id_73>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32027": {
"content": "<extra_id_72>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32028": {
"content": "<extra_id_71>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32029": {
"content": "<extra_id_70>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32030": {
"content": "<extra_id_69>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32031": {
"content": "<extra_id_68>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32032": {
"content": "<extra_id_67>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32033": {
"content": "<extra_id_66>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32034": {
"content": "<extra_id_65>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32035": {
"content": "<extra_id_64>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32036": {
"content": "<extra_id_63>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32037": {
"content": "<extra_id_62>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32038": {
"content": "<extra_id_61>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32039": {
"content": "<extra_id_60>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32040": {
"content": "<extra_id_59>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32041": {
"content": "<extra_id_58>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32042": {
"content": "<extra_id_57>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32043": {
"content": "<extra_id_56>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32044": {
"content": "<extra_id_55>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32045": {
"content": "<extra_id_54>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32046": {
"content": "<extra_id_53>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32047": {
"content": "<extra_id_52>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32048": {
"content": "<extra_id_51>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32049": {
"content": "<extra_id_50>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32050": {
"content": "<extra_id_49>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32051": {
"content": "<extra_id_48>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32052": {
"content": "<extra_id_47>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32053": {
"content": "<extra_id_46>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32054": {
"content": "<extra_id_45>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32055": {
"content": "<extra_id_44>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32056": {
"content": "<extra_id_43>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32057": {
"content": "<extra_id_42>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32058": {
"content": "<extra_id_41>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32059": {
"content": "<extra_id_40>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32060": {
"content": "<extra_id_39>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32061": {
"content": "<extra_id_38>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32062": {
"content": "<extra_id_37>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32063": {
"content": "<extra_id_36>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32064": {
"content": "<extra_id_35>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32065": {
"content": "<extra_id_34>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32066": {
"content": "<extra_id_33>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32067": {
"content": "<extra_id_32>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32068": {
"content": "<extra_id_31>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32069": {
"content": "<extra_id_30>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32070": {
"content": "<extra_id_29>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32071": {
"content": "<extra_id_28>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32072": {
"content": "<extra_id_27>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32073": {
"content": "<extra_id_26>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32074": {
"content": "<extra_id_25>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32075": {
"content": "<extra_id_24>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32076": {
"content": "<extra_id_23>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32077": {
"content": "<extra_id_22>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32078": {
"content": "<extra_id_21>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32079": {
"content": "<extra_id_20>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32080": {
"content": "<extra_id_19>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32081": {
"content": "<extra_id_18>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32082": {
"content": "<extra_id_17>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32083": {
"content": "<extra_id_16>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32084": {
"content": "<extra_id_15>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32085": {
"content": "<extra_id_14>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32086": {
"content": "<extra_id_13>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32087": {
"content": "<extra_id_12>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32088": {
"content": "<extra_id_11>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32089": {
"content": "<extra_id_10>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32090": {
"content": "<extra_id_9>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32091": {
"content": "<extra_id_8>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32092": {
"content": "<extra_id_7>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32093": {
"content": "<extra_id_6>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32094": {
"content": "<extra_id_5>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32095": {
"content": "<extra_id_4>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32096": {
"content": "<extra_id_3>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32097": {
"content": "<extra_id_2>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32098": {
"content": "<extra_id_1>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32099": {
"content": "<extra_id_0>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>"
],
"clean_up_tokenization_spaces": true,
"eos_token": "</s>",
"extra_ids": 100,
"legacy": true,
"model_max_length": 512,
"pad_token": "<pad>",
"sp_model_kwargs": {},
"tokenizer_class": "T5Tokenizer",
"unk_token": "<unk>"
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,30 @@
{
"bos_token": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

View File

@@ -0,0 +1,30 @@
{
"add_prefix_space": false,
"added_tokens_decoder": {
"49406": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
},
"49407": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"bos_token": "<|startoftext|>",
"clean_up_tokenization_spaces": true,
"do_lower_case": true,
"eos_token": "<|endoftext|>",
"errors": "replace",
"model_max_length": 77,
"pad_token": "<|endoftext|>",
"tokenizer_class": "CLIPTokenizer",
"unk_token": "<|endoftext|>"
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,30 @@
{
"bos_token": {
"content": "<|begin_of_text|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|end_of_text|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,7 @@ from peft import LoraConfig, inject_adapter_in_model
import torch, os import torch, os
from ..data.simple_text_image import TextImageDataset from ..data.simple_text_image import TextImageDataset
from modelscope.hub.api import HubApi from modelscope.hub.api import HubApi
from ..models.utils import load_state_dict
@@ -11,11 +12,14 @@ class LightningModelForT2ILoRA(pl.LightningModule):
self, self,
learning_rate=1e-4, learning_rate=1e-4,
use_gradient_checkpointing=True, use_gradient_checkpointing=True,
state_dict_converter=None,
): ):
super().__init__() super().__init__()
# Set parameters # Set parameters
self.learning_rate = learning_rate self.learning_rate = learning_rate
self.use_gradient_checkpointing = use_gradient_checkpointing self.use_gradient_checkpointing = use_gradient_checkpointing
self.state_dict_converter = state_dict_converter
self.lora_alpha = None
def load_models(self): def load_models(self):
@@ -30,12 +34,16 @@ class LightningModelForT2ILoRA(pl.LightningModule):
self.pipe.denoising_model().train() self.pipe.denoising_model().train()
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out"): def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None, state_dict_converter=None):
# Add LoRA to UNet # Add LoRA to UNet
self.lora_alpha = lora_alpha
if init_lora_weights == "kaiming":
init_lora_weights = True
lora_config = LoraConfig( lora_config = LoraConfig(
r=lora_rank, r=lora_rank,
lora_alpha=lora_alpha, lora_alpha=lora_alpha,
init_lora_weights="gaussian", init_lora_weights=init_lora_weights,
target_modules=lora_target_modules.split(","), target_modules=lora_target_modules.split(","),
) )
model = inject_adapter_in_model(lora_config, model) model = inject_adapter_in_model(lora_config, model)
@@ -44,6 +52,17 @@ class LightningModelForT2ILoRA(pl.LightningModule):
if param.requires_grad: if param.requires_grad:
param.data = param.to(torch.float32) param.data = param.to(torch.float32)
# Lora pretrained lora weights
if pretrained_lora_path is not None:
state_dict = load_state_dict(pretrained_lora_path)
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
all_keys = [i for i, _ in model.named_parameters()]
num_updated_keys = len(all_keys) - len(missing_keys)
num_unexpected_keys = len(unexpected_keys)
print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
# Data # Data
@@ -52,6 +71,9 @@ class LightningModelForT2ILoRA(pl.LightningModule):
# Prepare input parameters # Prepare input parameters
self.pipe.device = self.device self.pipe.device = self.device
prompt_emb = self.pipe.encode_prompt(text, positive=True) prompt_emb = self.pipe.encode_prompt(text, positive=True)
if "latents" in batch:
latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
else:
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device)) latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
noise = torch.randn_like(latents) noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
@@ -65,7 +87,8 @@ class LightningModelForT2ILoRA(pl.LightningModule):
noisy_latents, timestep=timestep, **prompt_emb, **extra_input, noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
use_gradient_checkpointing=self.use_gradient_checkpointing use_gradient_checkpointing=self.use_gradient_checkpointing
) )
loss = torch.nn.functional.mse_loss(noise_pred, training_target) loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
# Record log # Record log
self.log("train_loss", loss, prog_bar=True) self.log("train_loss", loss, prog_bar=True)
@@ -83,9 +106,13 @@ class LightningModelForT2ILoRA(pl.LightningModule):
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters())) trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
state_dict = self.pipe.denoising_model().state_dict() state_dict = self.pipe.denoising_model().state_dict()
lora_state_dict = {}
for name, param in state_dict.items(): for name, param in state_dict.items():
if name in trainable_param_names: if name in trainable_param_names:
checkpoint[name] = param lora_state_dict[name] = param
if self.state_dict_converter is not None:
lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
checkpoint.update(lora_state_dict)
@@ -152,7 +179,7 @@ def add_general_parsers(parser):
"--precision", "--precision",
type=str, type=str,
default="16-mixed", default="16-mixed",
choices=["32", "16", "16-mixed"], choices=["32", "16", "16-mixed", "bf16"],
help="Training precision", help="Training precision",
) )
parser.add_argument( parser.add_argument(
@@ -173,6 +200,13 @@ def add_general_parsers(parser):
default=4.0, default=4.0,
help="The weight of the LoRA update matrices.", help="The weight of the LoRA update matrices.",
) )
parser.add_argument(
"--init_lora_weights",
type=str,
default="kaiming",
choices=["gaussian", "kaiming"],
help="The initializing method of LoRA weight.",
)
parser.add_argument( parser.add_argument(
"--use_gradient_checkpointing", "--use_gradient_checkpointing",
default=False, default=False,
@@ -210,6 +244,12 @@ def add_general_parsers(parser):
default=None, default=None,
help="Access key on ModelScope (https://www.modelscope.cn/). Required if you want to upload the model to ModelScope.", help="Access key on ModelScope (https://www.modelscope.cn/). Required if you want to upload the model to ModelScope.",
) )
parser.add_argument(
"--pretrained_lora_path",
type=str,
default=None,
help="Pretrained LoRA path. Required if the training is resumed.",
)
return parser return parser

View File

@@ -0,0 +1 @@
from .layers import *

View File

@@ -0,0 +1,95 @@
import torch, copy
from ..models.utils import init_weights_on_device
def cast_to(weight, dtype, device):
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
class AutoWrappedModule(torch.nn.Module):
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
super().__init__()
self.module = module.to(dtype=offload_dtype, device=offload_device)
self.offload_dtype = offload_dtype
self.offload_device = offload_device
self.onload_dtype = onload_dtype
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.state = 0
def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.module.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.module.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def forward(self, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
module = self.module
else:
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
return module(*args, **kwargs)
class AutoWrappedLinear(torch.nn.Linear):
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
with init_weights_on_device(device=torch.device("meta")):
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
self.weight = module.weight
self.bias = module.bias
self.offload_dtype = offload_dtype
self.offload_device = offload_device
self.onload_dtype = onload_dtype
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.state = 0
def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def forward(self, x, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
weight, bias = self.weight, self.bias
else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
return torch.nn.functional.linear(x, weight, bias)
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
for name, module in model.named_children():
for source_module, target_module in module_map.items():
if isinstance(module, source_module):
num_param = sum(p.numel() for p in module.parameters())
if max_num_param is not None and total_num_param + num_param > max_num_param:
module_config_ = overflow_module_config
else:
module_config_ = module_config
module_ = target_module(module, **module_config_)
setattr(model, name, module_)
total_num_param += num_param
break
else:
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
return total_num_param
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
model.vram_management_enabled = True

View File

@@ -1,21 +0,0 @@
name: DiffSynthStudio
channels:
- pytorch
- nvidia
- defaults
dependencies:
- python=3.9.16
- pip=23.0.1
- cudatoolkit
- pytorch
- cupy
- pip:
- transformers
- controlnet-aux==0.0.7
- streamlit
- streamlit-drawable-canvas
- imageio
- imageio[ffmpeg]
- safetensors
- einops
- sentencepiece

43
examples/ArtAug/README.md Normal file
View File

@@ -0,0 +1,43 @@
# FLUX Aesthetics Enhancement LoRA
## Introduction
This is a LoRA model trained for FLUX.1-dev, which enhances the aesthetic quality of images generated by the model. The improvements include, but are not limited to: rich details, beautiful lighting and shadows, aesthetic composition, and clear visuals. This model does not require any trigger words.
* Paper: https://arxiv.org/abs/2412.12888
* Github: https://github.com/modelscope/DiffSynth-Studio
* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
* Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (Coming soon)
## Methodology
![workflow](https://github.com/user-attachments/assets/cee969af-d49f-4480-911c-bedc1c095f9b)
The ArtAug project is inspired by reasoning approaches like GPT-o1, which rely on model interaction and self-correction. We developed a framework aimed at enhancing the capabilities of image generation models through interaction with image understanding models. The training process of ArtAug consists of the following steps:
1. **Synthesis-Understanding Interaction**: After generating an image using the image generation model, we employ a multimodal large language model (Qwen2-VL-72B) to analyze the image content and provide suggestions for modifications, which then lead to the regeneration of a higher quality image.
2. **Data Generation and Filtering**: Interactive generation involves long inference times and sometimes produce poor image content. Therefore, we generate a large batch of image pairs offline, filter them, and use them for subsequent training.
3. **Differential Training**: We apply differential training techniques to train a LoRA model, enabling it to learn the differences between images before and after enhancement, rather than directly training on the dataset of enhanced images.
4. **Iterative Enhancement**: The trained LoRA model is fused into the base model, and the entire process is repeated multiple times with the fused model until the interaction algorithm no longer provides significant enhancements. The LoRA models produced in each iteration are combined to produce this final model.
This model integrates the aesthetic understanding of Qwen2-VL-72B into FLUX.1[dev], leading to an improvement in the quality of generated images.
## Usage
Please see [./artaug_flux.py](./artaug_flux.py) for more details.
Since this model is encapsulated in the universal FLUX LoRA format, it can be loaded by most LoRA loaders, allowing you to integrate this LoRA model into your own workflow.
## Examples
|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
|-|-|
|![image_1_base](https://github.com/user-attachments/assets/e1d5c505-b423-45fe-be01-25c2758f5417)|![image_1_enhance](https://github.com/user-attachments/assets/335908e3-d0bd-41c2-9d99-d10528a2d719)|
|![image_2_base](https://github.com/user-attachments/assets/7f38e8d4-3c62-492e-bd96-be60f0855037)|![image_2_enhance](https://github.com/user-attachments/assets/ae3a1daf-7a7c-44fd-bdbc-1d2a83bc3de3)|
|![image_3_base](https://github.com/user-attachments/assets/e2ae4879-9202-45d6-9df7-fbcbd2093d19)|![image_3_enhance](https://github.com/user-attachments/assets/4df6e5b9-65de-408b-88c6-51db39aad801)|
|![image_4_base](https://github.com/user-attachments/assets/dbc65387-60df-4a18-b1bb-45eaa5be5c1d)|![image_4_enhance](https://github.com/user-attachments/assets/fc19860d-3e28-468b-b013-8745255ac6db)|
|![image_5_base](https://github.com/user-attachments/assets/bb65c1ba-c0c6-4d3b-b3ef-bdbbb5f03a48)|![image_5_enhance](https://github.com/user-attachments/assets/03570c62-9a0b-428f-8c86-6e01c1421202)|
|![image_6_base](https://github.com/user-attachments/assets/18e9a4e7-2afd-4ca9-bc49-7736042c25dc)|![image_6_enhance](https://github.com/user-attachments/assets/aa73571f-098a-4e65-9eda-b9729ba379cd)|

View File

@@ -0,0 +1,14 @@
import torch
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
lora_path = download_customized_models(
model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1",
origin_file_path="merged_lora.safetensors",
local_dir="models/lora"
)[0]
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
model_manager.load_lora(lora_path, lora_alpha=1.0)
pipe = FluxImagePipeline.from_model_manager(model_manager)
image = pipe(prompt="a house", seed=0)
image.save("image_artaug.jpg")

View File

@@ -0,0 +1,91 @@
# ControlNet
We provide extensive ControlNet support. Taking the FLUX model as an example, we support many different ControlNet models that can be freely combined, even if their structures differ. Additionally, ControlNet models are compatible with high-resolution refinement and partition control techniques, enabling very powerful controllable image generation.
These examples are in [`flux_controlnet.py`](./flux_controlnet.py).
## Canny/Depth/Normal: Structure Control
Structural control is the most fundamental capability of the ControlNet model. By using Canny to extract edge information, or by utilizing depth maps and normal maps, we can extract the structure of an image, which can then serve as control information during the image generation process.
Model link: https://modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha
For example, if we generate an image of a cat and use a model like InstantX/FLUX.1-dev-Controlnet-Union-alpha that supports multiple control conditions, we can simultaneously enable both Canny and Depth controls to transform the environment into a twilight setting.
|![image_5](https://github.com/user-attachments/assets/19d2abc4-36ae-4163-a8da-df5732d1a737)|![image_6](https://github.com/user-attachments/assets/28378271-3782-484c-bd51-3d3311dd85c6)|
|-|-|
The control strength of ControlNet for structure can be adjusted. For example, in the case below, when we move the girl from summer to winter, we can appropriately lower the control strength of ControlNet so that the model will adapt to the content of the image and change her into warm clothes.
|![image_7](https://github.com/user-attachments/assets/a7b8555b-bfd9-4e92-aa77-16bca81b07e3)|![image_8](https://github.com/user-attachments/assets/a1bab36b-6cce-4f29-8233-4cb824b524a8)|
|-|-|
## Upscaler/Tile/Blur: High-Resolution Image Synthesis
There are many ControlNet models that support high definition, such as:
Model link: https://modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler, https://modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha, https://modelscope.cn/models/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro
These models can transform blurry, noisy low-quality images into clear ones. In DiffSynth-Studio, the native high-resolution patch processing technology supported by the framework can overcome the resolution limitations of the models, enabling image generation at resolutions of 2048 or even higher, significantly enhancing the capabilities of these models. In the example below, we can see that in the high-definition image enlarged to 2048 resolution, the cat's fur is rendered in exquisite detail, and the skin texture of the characters is delicate and realistic.
|![image_1](https://github.com/user-attachments/assets/9038158a-118c-4ad7-ab01-22865f6a06fc)|![image_2](https://github.com/user-attachments/assets/88583a33-cd74-4cb9-8fd4-c6e14c0ada0c)|
|-|-|
|![image_3](https://github.com/user-attachments/assets/13061ecf-bb57-448a-82c6-7e4655c9cd85)|![image_4](https://github.com/user-attachments/assets/0b7ae80f-de58-4d1d-a49c-ad17e7631bdc)|
|-|-|
## Inpaint: Image Restoration
The Inpaint ControlNet model can repaint specific areas in an image. For example, we can put sunglasses on a cat.
Model link: https://modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta
|![image_9](https://github.com/user-attachments/assets/babddad0-2d67-4624-b77a-c953250ebdab)|![mask_9](https://github.com/user-attachments/assets/d5bc2878-1817-457a-bdfa-200f955233d3)|![image_10](https://github.com/user-attachments/assets/e3197f2c-190b-4522-83ab-a2e0451b39f6)|
|-|-|-|
However, we noticed that the head movements of the cat have changed. If we want to preserve the original structural features, we can use the Canny, Depth, and Normal models. DiffSynth-Studio provides seamless support for ControlNet of different structures. By using a Normal ControlNet, we can ensure that the structure of the image remains unchanged during local redrawing.
Model link: https://modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Surface-Normals
|![image_11](https://github.com/user-attachments/assets/c028e6fc-5125-4cba-b35a-b6211c2e6600)|![mask_11](https://github.com/user-attachments/assets/1928ee9a-7594-4c6e-9c71-5bd0b043d8f4)|![image_12](https://github.com/user-attachments/assets/97b3b9e1-f821-405e-971b-9e1c31a209aa)|
|-|-|-|
## MultiControlNet+MultiDiffusion: Fine-Grained Control
DiffSynth-Studio not only supports the simultaneous activation of multiple ControlNet structures, but also allows for the partitioned control of content within an image using different prompts. Additionally, it supports the chunk processing of ultra-high-resolution large images, enabling us to achieve extremely detailed high-level control. Next, we will showcase the creative process behind a beautiful image.
First, use the prompt "a beautiful Asian woman and a cat on a bed. The woman wears a dress" to generate a cat and a young girl.
![image_13](https://github.com/user-attachments/assets/8da006e4-0e68-4fa5-b407-31ef5dbe8e5a)
Then, enable Inpaint ControlNet and Canny ControlNet.
Model link: https://modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta, https://modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha
We control the image using two component.
|Prompt: an orange cat, highly detailed|Prompt: a girl wearing a red camisole|
|-|-|
|![mask_13_1](https://github.com/user-attachments/assets/188530a0-913c-48db-a7f1-62f0384bfdc3)|![mask_13_2](https://github.com/user-attachments/assets/99c4d0d5-8cc3-47a0-8e56-ceb37db4dfdc)|
Generate!
![image_14](https://github.com/user-attachments/assets/f5b9d3dd-a690-4597-91a8-a019c6fc2523)
The background is a bit blurry, so we use deblurring LoRA for image-to-image generation.
Model link: https://modelscope.cn/models/LiblibAI/FLUX.1-dev-LoRA-AntiBlur
![image_15](https://github.com/user-attachments/assets/32ed2667-2260-4d80-aaa9-4435d6920a2a)
The entire image is much clearer now. Next, let's use the high-definition model to increase the resolution to 4096*4096!
Model link: https://modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler
![image_17](https://github.com/user-attachments/assets/1a688a12-1544-4973-8aca-aa3a23cb34c1)
Zoom in to see details.
![image_17_cropped](https://github.com/user-attachments/assets/461a1fbc-9ffa-4da5-80fd-e1af9667c804)
Enjoy!

Some files were not shown because too many files have changed in this diff Show More