Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-19 23:08:13 +00:00)

Compare commits: v1.0.0...cache_lear (938 commits)
(Commit table omitted: the rendered page listed only bare SHA-1 hashes for the 938 compared commits, with empty Author and Date columns and no commit messages.)
BIN  .github/workflows/logo.gif (vendored, new file)

Binary file not shown. Size: 146 KiB
4  .github/workflows/publish.yaml (vendored)

```diff
@@ -20,9 +20,9 @@ jobs:
       with:
         python-version: '3.10'
     - name: Install wheel
-      run: pip install wheel && pip install -r requirements.txt
+      run: pip install wheel==0.44.0 && pip install -r requirements.txt
     - name: Build DiffSynth
-      run: python setup.py sdist bdist_wheel
+      run: python -m build
     - name: Publish package to PyPI
       run: |
         pip install twine
```
175  .gitignore (vendored, new file)

@@ -0,0 +1,175 @@
```
/data
/models
/scripts
/diffusers
*.pkl
*.safetensors
*.pth
*.ckpt
*.pt
*.bin
*.DS_Store
*.msc
*.mv
log*.txt

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
```
15  (removed file; the filename was not captured in the rendered diff)

```diff
@@ -1,15 +0,0 @@
-# Set web page format
-import streamlit as st
-st.set_page_config(layout="wide")
-# Diasble virtual VRAM on windows system
-import torch
-torch.cuda.set_per_process_memory_fraction(0.999, 0)
-
-
-st.markdown("""
-# DiffSynth Studio
-
-[Source Code](https://github.com/Artiprocher/DiffSynth-Studio)
-
-Welcome to DiffSynth Studio.
-""")
```
953  README_zh.md (new file)

@@ -0,0 +1,953 @@

# DiffSynth-Studio

<a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>

[](https://pypi.org/project/DiffSynth/)
[](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
[](https://github.com/modelscope/DiffSynth-Studio/issues)
[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)

[Switch to English](./README.md)

## Introduction

> DiffSynth-Studio documentation: [Chinese version](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/), [English version](https://diffsynth-studio-doc.readthedocs.io/en/latest/)

Welcome to the magical world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by the [ModelScope](https://www.modelscope.cn/) team. We hope to incubate technical innovation through framework building, rally the strength of the open-source community, and explore the frontier of generative model technology!

DiffSynth currently comprises two open-source projects:
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): focused on aggressive technical exploration, aimed at academia, providing support for cutting-edge model capabilities.
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): focused on stable model deployment, aimed at industry, offering higher computational performance and more stable functionality.

[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) and [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) are the core engines of the ModelScope AIGC zone. You are welcome to try the productized features we have carefully built:

* ModelScope AIGC zone (for users in China): https://modelscope.cn/aigc/home
* ModelScope Civision (for global users): https://modelscope.ai/civision/home

We believe a well-built open-source framework lowers the barrier to technical exploration. On top of this codebase we have produced quite a few [interesting techniques](#创新成果). Perhaps you have plenty of bold ideas of your own; with DiffSynth-Studio you can bring them to life quickly. To that end, we have prepared detailed documentation for developers. We hope these documents help you understand how Diffusion models work, and we look forward to pushing the boundaries of the technology together with you.

## Update History

> DiffSynth-Studio has gone through a major version update, and some legacy features are no longer maintained. If you need the old functionality, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major update.

> This project currently has a limited number of developers; most of the work is handled by [Artiprocher](https://github.com/Artiprocher), so new features progress rather slowly and issues are answered and resolved at a limited pace. We are very sorry about this and ask for your understanding.

- **February 26, 2026** Added full fine-tuning and LoRA training support for the [LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) audio-video generation model; see the [documentation](docs/zh/Model_Details/LTX-2.md).

- **February 10, 2026** Added inference support for the [LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) audio-video generation model; see the [documentation](docs/zh/Model_Details/LTX-2.md). Training support will follow.

- **February 2, 2026** The first Research Tutorial document is online, walking you through training a small 0.1B text-to-image model from scratch; see the [documentation](/docs/zh/Research_Tutorial/train_from_scratch.md) and the [model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). We hope DiffSynth-Studio can grow into an even stronger Diffusion model training framework.

- **January 27, 2026** [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) is released, together with our [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) model, which you can try directly in the [ModelScope studio](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L); see the [documentation](/docs/zh/Model_Details/Z-Image.md).

- **January 19, 2026** Added support for the [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including full training and inference. The [documentation](/docs/zh/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.

- **January 12, 2026** We trained and open-sourced a text-guided layer-splitting model ([model link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an image and a text description, the model splits out the image layers related to the description. For more details, read our blog ([Chinese](https://modelscope.cn/learn/4938), [English](https://huggingface.co/blog/kelseye/qwen-image-layered-control)).

- **December 24, 2025** We trained an In-Context Editing LoRA model on top of Qwen-Image-Edit-2511 ([model link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)). The model takes three images as input: image A, image B, and image C; it infers the change from A to B on its own and applies that change to C, producing image D. For more details, read our blog ([Chinese](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g), [English](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)).

- **December 9, 2025** Using DiffSynth-Studio 2.0, we trained a crazy model: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image to LoRA). It takes an image as input and produces a LoRA as output. Although this version still has plenty of room for improvement in generalization and detail preservation, we open-source it to inspire more innovative research. For more details, see our [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l).

- **December 4, 2025** DiffSynth-Studio 2.0 released, with many new features:
  - [Documentation](/docs/zh/README.md) is online and still being continuously improved.
  - The [VRAM management](/docs/zh/Pipeline_Usage/VRAM_management.md) module has been upgraded with layer-level disk offload, freeing RAM and VRAM at the same time.
  - New model support:
    - Z-Image Turbo: [model](https://www.modelscope.ai/models/Tongyi-MAI/Z-Image-Turbo), [docs](/docs/zh/Model_Details/Z-Image.md), [code](/examples/z_image/)
    - FLUX.2-dev: [model](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev), [docs](/docs/zh/Model_Details/FLUX2.md), [code](/examples/flux2/)
  - Training framework upgrades:
    - [Split training](/docs/zh/Training/Split_Training.md): automatically splits training into a data-processing stage and a training stage (even when training a ControlNet or any other model). Text encoding, VAE encoding, and other computations that need no gradient backpropagation run in the data-processing stage; everything else runs in the training stage. Faster, with lower VRAM requirements.
    - [Differential LoRA training](/docs/zh/Training/Differential_LoRA.md): the training technique we previously used in [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), now available for LoRA training of any model.
    - [FP8 training](/docs/zh/Training/FP8_Precision.md): during training, FP8 can be applied to any non-trained model, i.e. any model whose gradients are disabled or whose gradients only affect LoRA weights.

<details>
<summary>More</summary>

- **November 4, 2025** Supported the [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) model, trained on Wan 2.1, which generates motion following a reference video.

- **October 30, 2025** Supported the [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which covers text-to-video, image-to-video, and video continuation. In this project it reuses the Wan framework for inference and training.

- **October 27, 2025** Supported the [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) model, another addition to the Wan model ecosystem.

- **September 23, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) released! Jointly developed and open-sourced by us and the Taotian experience design team, this model is built on Qwen-Image, designed for e-commerce poster scenarios, and supports precise regional layout control. See [our example code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py).

- **September 9, 2025** Our training framework now supports multiple training modes, currently adapted for Qwen-Image. Besides the standard SFT mode, Direct Distill is now supported; see [our example code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will keep refining it to support more comprehensive model training.

- **August 28, 2025** We supported Wan2.2-S2V, an audio-driven cinematic video generation model. See [./examples/wanvideo/](./examples/wanvideo/).

- **August 21, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) released! Compared with V1, the training dataset changed to [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), so the generated images better match Qwen-Image's own distribution and style. See [our example code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py).

- **August 21, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) structural-control LoRA model. Following the In-Context approach, it supports many kinds of structural control conditions, including canny, depth, lineart, softedge, normal, and openpose. See [our example code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py).

- **August 20, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) model, which improves Qwen-Image-Edit's results on low-resolution image inputs. See [our example code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py).

- **August 19, 2025** 🔥 Qwen-Image-Edit is open-sourced. Welcome, new member of the image-editing model family!

- **August 18, 2025** We trained and open-sourced the Qwen-Image inpainting ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint), with a lightweight architecture. See [our example code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).

- **August 15, 2025** We open-sourced the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) dataset, generated with the Qwen-Image model and containing 160,000 `1024 x 1024` images across general, English text-rendering, and Chinese text-rendering subsets. Every image comes with a caption plus annotations of entities and structural-control images. Developers can use this dataset to train models such as ControlNet and EliGen for Qwen-Image; we aim to push the technology forward through open source!

- **August 13, 2025** We trained and open-sourced the Qwen-Image ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth), with a lightweight architecture. See [our example code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py).

- **August 12, 2025** We trained and open-sourced the Qwen-Image ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny), with a lightweight architecture. See [our example code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py).

- **August 11, 2025** We open-sourced the Qwen-Image distillation-acceleration model [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA). It follows the same training procedure as [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), but its architecture is changed to LoRA, making it more compatible with other open-source ecosystem models.

- **August 7, 2025** We open-sourced the Qwen-Image entity-control LoRA model [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen). Qwen-Image-EliGen enables entity-level controllable text-to-image generation. For technical details, see the [paper](https://arxiv.org/abs/2501.01097). Training dataset: [EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet).

- **August 5, 2025** We open-sourced the Qwen-Image distillation-acceleration model [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), achieving roughly 5x speedup.

- **August 4, 2025** 🔥 Qwen-Image is open-sourced. Welcome, new member of the image generation model family!

- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) is open-sourced, a text-to-image model focused on aesthetic photography. We provided day-one, full-spectrum support, including low-VRAM layer-by-layer offload, LoRA training, and full training. See [./examples/flux/](./examples/flux/) for details.

- **July 28, 2025** Wan 2.2 is open-sourced, and we provided day-one, full-spectrum support, including low-VRAM layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, and full training. See [./examples/wanvideo/](./examples/wanvideo/) for details.

- **July 11, 2025** We propose Nexus-Gen, a unified framework that combines the language reasoning ability of LLMs with the image generation ability of diffusion models, supporting seamless image understanding, generation, and editing tasks.
  - Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
  - GitHub repo: https://github.com/modelscope/Nexus-Gen
  - Models: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
  - Training dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
  - Online demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)

- **June 15, 2025** The official ModelScope evaluation framework [EvalScope](https://github.com/modelscope/evalscope) now supports text-to-image generation evaluation. Follow the [best practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide to try it.

- **March 25, 2025** Our new open-source project [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) is now available! It focuses on stable model deployment for industry, with better engineering support, higher computational performance, and more stable functionality.

- **March 31, 2025** We support InfiniteYou, a face-identity preservation method for FLUX. See [./examples/InfiniteYou/](./examples/InfiniteYou/) for details.

- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of Tencent's open-source HunyuanVideo. See [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for details.

- **February 25, 2025** We support Wan-Video, a series of state-of-the-art video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).

- **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary), a state-of-the-art video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).

- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline that extends its capabilities to image inpainting tasks. EliGen integrates seamlessly with existing community models such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
  - Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
  - Models: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
  - Online demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
  - Training dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)

- **December 19, 2024** We implemented advanced VRAM management for HunyuanVideo, enabling generation of 129x720x1280 videos with 24GB of VRAM, or 129x512x384 videos with only 6GB. See [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for details.

- **December 18, 2024** We propose ArtAug, a method that improves text-to-image models through synthesis-understanding interaction. We trained an ArtAug enhancement module for FLUX.1-dev in LoRA format; it distills the aesthetic understanding of Qwen2-VL-72B into FLUX.1-dev, improving the quality of generated images.
  - Paper: https://arxiv.org/abs/2412.12888
  - Examples: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
  - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
  - Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (coming soon)

- **October 25, 2024** We provide extensive FLUX ControlNet support. The project supports many different ControlNet models, which can be freely combined even when their structures differ. ControlNet models are also compatible with the high-resolution refinement and regional control techniques, enabling very powerful controllable image generation. See [`./examples/ControlNet/`](./examples/ControlNet/) for details.

- **October 8, 2024** We release an extension LoRA based on CogVideoX-5B and ExVideo. You can download it from [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) or [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1).

- **August 22, 2024** CogVideoX-5B is now supported; see [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including:
  - Text to video
  - Video editing
  - Self-upscaling
  - Video frame interpolation

- **August 22, 2024** We implemented an interesting painter feature that supports all text-to-image models. You can now create stunning images with a brush, assisted by AI!
  - Use it in our [WebUI](#usage-in-webui).

- **August 21, 2024** DiffSynth-Studio now supports FLUX.
  - Enable CFG and the high-resolution fix to improve visual quality; see [here](/examples/image_synthesis/README.md).
  - LoRA, ControlNet, and other add-on models are coming soon.

- **June 21, 2024** We propose ExVideo, a post-training fine-tuning technique aimed at enhancing the capability of video generation models. We extended Stable Video Diffusion to generate long videos of up to 128 frames.
  - [Project page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
  - The source code is released in this repo; see [`examples/ExVideo`](./examples/ExVideo/).
  - Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
  - The technical report is available on [arXiv](https://arxiv.org/abs/2406.14130).
  - You can try ExVideo in this [demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!

- **June 13, 2024** DiffSynth Studio has been migrated to ModelScope, and the development team has grown from "me" to "us". Of course, I will still take part in future development and maintenance.

- **January 29, 2024** We propose Diffutoon, an outstanding toon-shading solution.
  - [Project page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
  - The source code is released in this project.
  - The technical report (IJCAI 2024) is available on [arXiv](https://arxiv.org/abs/2401.16224).

- **December 8, 2023** We decided to start a new project to unleash the potential of diffusion models, especially for video synthesis. Development of this project officially began.

- **November 15, 2023** We propose FastBlend, a powerful video deflickering algorithm.
  - The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
  - Demo videos are shown on Bilibili, covering three tasks:
    - [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
    - [Video frame interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
    - [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
  - The technical report is available on [arXiv](https://arxiv.org/abs/2311.09265).
  - An unofficial ComfyUI extension developed by community users is available on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).

- **October 1, 2023** We released an early version of this project, named FastSDXL, a first attempt at building a diffusion engine.
  - The source code is released on [GitHub](https://github.com/Artiprocher/FastSDXL).
  - FastSDXL includes a trainable OLSS scheduler for improved efficiency.
  - The original OLSS repository is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
  - The technical report (CIKM 2023) is available on [arXiv](https://arxiv.org/abs/2305.14677).
  - A demo video is available on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
  - Since OLSS requires additional training, we did not implement it in this project.

- **August 29, 2023** We propose DiffSynth, a video synthesis framework.
  - [Project page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
  - The source code is released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
  - The technical report (ECML PKDD 2024) is available on [arXiv](https://arxiv.org/abs/2308.03463).

</details>

## Installation

Install from source (recommended):

```
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

For other installation methods, including installation on non-NVIDIA GPUs, see the [installation documentation](/docs/zh/Pipeline_Usage/Setup.md).

## Core Framework

DiffSynth-Studio redesigns the inference and training pipelines of mainstream Diffusion models (including FLUX, Wan, and more), enabling efficient VRAM management and flexible model training.

<details>
<summary>Environment variable configuration</summary>

> Before running model inference or training, you can configure the model download source and other options via [environment variables](/docs/zh/Pipeline_Usage/Environment_Variables.md).
>
> By default, this project downloads models from ModelScope. Users outside China can download models from ModelScope's international site with the following configuration:
>
> ```python
> import os
> os.environ["MODELSCOPE_DOMAIN"] = "www.modelscope.ai"
> ```
>
> To download from other sites, modify the [DIFFSYNTH_DOWNLOAD_SOURCE environment variable](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source).
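
For reference, a minimal sketch of applying this configuration in a script. Our assumption here is simply that the variables must be set before the first model download is triggered; the accepted values of `DIFFSYNTH_DOWNLOAD_SOURCE` are listed in the linked documentation and are deliberately not hard-coded below.

```python
import os

# Set before building any pipeline, so the first model download already
# uses the configured domain. "www.modelscope.ai" is the international
# ModelScope domain quoted above.
os.environ["MODELSCOPE_DOMAIN"] = "www.modelscope.ai"

# To download from a different site entirely, set DIFFSYNTH_DOWNLOAD_SOURCE
# instead; valid values are listed in the linked environment-variable
# documentation, so none is hard-coded here.
# os.environ["DIFFSYNTH_DOWNLOAD_SOURCE"] = "..."
```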
</details>

### Image Generation Models



#### Z-Image: [/docs/zh/Model_Details/Z-Image.md](/docs/zh/Model_Details/Z-Image.md)

<details>

<summary>Quick start</summary>

Run the following code to quickly load the [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) model and run inference. FP8 quantization causes noticeable image-quality degradation, so we do not recommend enabling any quantization for the Z-Image Turbo model; only CPU offload is recommended, and as little as 8 GB of VRAM is enough.

```python
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
import torch

vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
image.save("image.jpg")
```
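
The `vram_limit` argument in the snippet above is worth unpacking: `torch.cuda.mem_get_info` returns free and total device memory in bytes, and the expression converts total VRAM to GiB, then reserves a 0.5 GiB margin. A minimal standalone illustration (the `print` is ours, not part of the pipeline API):

```python
import torch

# mem_get_info returns (free_bytes, total_bytes) for the given device.
free_bytes, total_bytes = torch.cuda.mem_get_info("cuda")

# Total VRAM in GiB, minus a 0.5 GiB safety margin, matching the
# vram_limit expression used in the quick-start code above.
vram_limit = total_bytes / (1024 ** 3) - 0.5

print(f"free: {free_bytes / (1024 ** 3):.2f} GiB, "
      f"total: {total_bytes / (1024 ** 3):.2f} GiB, "
      f"vram_limit: {vram_limit:.2f} GiB")
```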
</details>

<details>

<summary>Example code</summary>

Example code for Z-Image lives at: [/examples/z_image/](/examples/z_image/)

|Model ID|Inference|Low-VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|

</details>

#### FLUX.2: [/docs/zh/Model_Details/FLUX2.md](/docs/zh/Model_Details/FLUX2.md)

<details>

<summary>Quick start</summary>

Run the following code to quickly load the [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) model and run inference. VRAM management is enabled: the framework automatically decides how to load model parameters according to the remaining VRAM, and as little as 10 GB of VRAM is enough.

```python
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch

vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene."
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
image.save("image.jpg")
```
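
One way to read the `vram_config` above, based on the layer-level disk-offload feature described in the update history: idle layers live on disk, are staged into CPU RAM as FP8, moved to the GPU as FP8 just before they are needed, and cast to bfloat16 only for the actual computation. The comments below are our interpretation of the keys, not documented API semantics:

```python
import torch

# Annotated copy of the vram_config used above (interpretation only):
vram_config = {
    "offload_dtype": "disk",                 # storage format while a layer is fully offloaded
    "offload_device": "disk",                # idle layers are kept on disk, freeing RAM and VRAM
    "onload_dtype": torch.float8_e4m3fn,     # staged into CPU memory as FP8
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,  # transferred to the GPU as FP8 just before use
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,     # cast to bfloat16 only for the forward computation
    "computation_device": "cuda",
}
```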
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>示例代码</summary>
|
||||||
|
|
||||||
|
FLUX.2 的示例代码位于:[/examples/flux2/](/examples/flux2/)
|
||||||
|
|
||||||
|
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|
|-|-|-|-|-|-|-|
|
||||||
|
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||||
|
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|
||||||
|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|
||||||
|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|
||||||
|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
#### Qwen-Image: [/docs/zh/Model_Details/Qwen-Image.md](/docs/zh/Model_Details/Qwen-Image.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>快速开始</summary>
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": "disk",
|
||||||
|
"offload_device": "disk",
|
||||||
|
"onload_dtype": torch.float8_e4m3fn,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e4m3fn,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = QwenImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||||
|
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||||
|
image.save("image.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>

<details>

<summary>模型血缘</summary>

```mermaid
graph LR;
Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;
Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;
Qwen/Qwen-Image-->EliGen-Series;
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;
DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;
Qwen/Qwen-Image-->Distill-Series;
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;
Qwen/Qwen-Image-->ControlNet-Series;
ControlNet-Series-->Blockwise-ControlNet-Series;
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;
ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;
Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;
```

</details>

<details>

<summary>示例代码</summary>

Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/)

|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-2512](https://www.modelscope.cn/models/Qwen/Qwen-Image-2512)|[code](/examples/qwen_image/model_inference/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-2512.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-2512.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-2512.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[FireRedTeam/FireRed-Image-Edit-1.0](https://www.modelscope.cn/models/FireRedTeam/FireRed-Image-Edit-1.0)|[code](/examples/qwen_image/model_inference/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_inference_low_vram/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/full/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_full/FireRed-Image-Edit-1.0.py)|[code](/examples/qwen_image/model_training/lora/FireRed-Image-Edit-1.0.sh)|[code](/examples/qwen_image/model_training/validate_lora/FireRed-Image-Edit-1.0.py)|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|

</details>

#### FLUX.1: [/docs/zh/Model_Details/FLUX.md](/docs/zh/Model_Details/FLUX.md)

<details>

<summary>快速开始</summary>

运行以下代码可以快速加载 [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) 模型并进行推理。显存管理已启用,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。

```python
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig

vram_config = {
    "offload_dtype": torch.float8_e4m3fn,
    "offload_device": "cpu",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config),
    ],
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1,
)
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
image = pipe(prompt=prompt, seed=0)
image.save("image.jpg")
```

</details>

<details>

<summary>模型血缘</summary>

```mermaid
graph LR;
FLUX.1-Series-->black-forest-labs/FLUX.1-dev;
FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;
FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;
black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;
FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;
FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;
FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;
black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;
black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;
black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;
black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;
Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;
Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;
```

</details>

<details>

<summary>示例代码</summary>

FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/)。表中「额外参数」一列为推理时传入管线调用的关键字参数,表格之后附有一个示意用法。

|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)|
|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)|
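
以 FLUX.1-Kontext-dev 为例,下面给出「额外参数」的一个示意用法(示意性写法:沿用「快速开始」中的加载方式,并将 DiT 权重对应的 ModelConfig 换成 FLUX.1-Kontext-dev 仓库中的权重文件;假设 `kontext_images` 接收 PIL.Image 列表,参数名取自上表,具体取值格式请以表中链接的示例脚本为准):

```python
from PIL import Image

# 假设:pipe 已按「快速开始」的方式构建,且加载的是 FLUX.1-Kontext-dev 的权重
# 假设:kontext_images 接收 PIL.Image 列表(参数名取自上表「额外参数」一列)
reference_image = Image.open("image.jpg")
image = pipe(prompt="turn the dress into red", kontext_images=[reference_image], seed=0)
image.save("image_edited.jpg")
```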

</details>

### 视频生成模型

https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314

#### LTX-2: [/docs/zh/Model_Details/LTX-2.md](/docs/zh/Model_Details/LTX-2.md)

<details>

<summary>快速开始</summary>

运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启用,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。

```python
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2

vram_config = {
    "offload_dtype": torch.float8_e5m2,
    "offload_device": "cpu",
    "onload_dtype": torch.float8_e5m2,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e5m2,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
"""
Official model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
For the base LTX-2 models, the official checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
and the repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
We have repackaged the official checkpoints in the DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of the different submodules
and to avoid redundant memory usage when users only need part of the model.
"""
# Use the repackaged ModelConfig entries from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading
pipe = LTX2AudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
        ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
        ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
    stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)

# Use the following ModelConfig entries instead if you want to initialize the model from the official checkpoints in "Lightricks/LTX-2"
# pipe = LTX2AudioVideoPipeline.from_pretrained(
#     torch_dtype=torch.bfloat16,
#     device="cuda",
#     model_configs=[
#         ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
#         ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
#         ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
#     ],
#     tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
#     stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
#     vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
# )

prompt = "A girl is very happy, she is speaking: \"I enjoy working with DiffSynth-Studio, it's a perfect framework.\""
negative_prompt = (
    "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
    "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
    "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
    "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
    "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
    "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
    "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
    "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
    "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
    "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
    "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
height, width, num_frames = 512 * 2, 768 * 2, 121
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    seed=43,
    height=height,
    width=width,
    num_frames=num_frames,
    tiled=True,
    use_two_stage_pipeline=True,
)
write_video_audio_ltx2(
    video=video,
    audio=audio,
    output_path='ltx2_twostage.mp4',
    fps=24,
    audio_sample_rate=24000,
)
```

</details>

<details>

<summary>示例代码</summary>

LTX-2 的示例代码位于:[/examples/ltx2/](/examples/ltx2/)。表中「额外参数」一列为推理时传入管线调用的关键字参数,表格之后附有一个示意用法。

|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py)|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-In.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-In.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Out.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Left.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Dolly-Right.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Up.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Up.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Jib-Down.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Jib-Down.py)|-|-|-|-|
|[Lightricks/LTX-2-19b-LoRA-Camera-Control-Static](https://www.modelscope.cn/models/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-Camera-Control-Static.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-Camera-Control-Static.py)|-|-|-|-|
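
以图生音视频(I2AV)为例,下面给出「额外参数」的一个示意用法(示意性写法:沿用「快速开始」中构建的 pipe 与 negative_prompt;假设 `input_images` 接收 PIL.Image 列表,参数名取自上表,具体取值格式请以表中链接的示例脚本为准):

```python
from PIL import Image

# 假设:pipe、negative_prompt 已按「快速开始」的方式构建
# 假设:input_images 接收 PIL.Image 列表(参数名取自上表「额外参数」一列)
first_frame = Image.open("first_frame.jpg")
video, audio = pipe(
    prompt="A girl smiles and waves at the camera.",
    negative_prompt=negative_prompt,
    input_images=[first_frame],
    seed=0,
    height=1024, width=1536, num_frames=121,
    tiled=True,
    use_two_stage_pipeline=True,
)
write_video_audio_ltx2(video=video, audio=audio, output_path="ltx2_i2av.mp4", fps=24, audio_sample_rate=24000)
```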

</details>

#### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md)

<details>

<summary>快速开始</summary>

运行以下代码可以快速加载 [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) 模型并进行推理。显存管理已启用,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。

```python
import torch
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig

vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
)

video = pipe(
    prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    seed=0, tiled=True,
)
save_video(video, "video.mp4", fps=15, quality=5)
```

</details>

<details>

<summary>模型血缘</summary>

```mermaid
graph LR;
Wan-Series-->Wan2.1-Series;
Wan-Series-->Wan2.2-Series;
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;
Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;
Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;
iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;
Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;
Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;
Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;
Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;
Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;
Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;
Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;
Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;
Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;
```

</details>

<details>

<summary>示例代码</summary>

Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)。表中「额外参数」一列为推理时传入管线调用的关键字参数,表格之后附有一个示意用法。

|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
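
以图生视频(I2V)为例,下面给出「额外参数」的一个示意用法(示意性写法:沿用「快速开始」中的加载方式,并将各 ModelConfig 换成 Wan-AI/Wan2.1-I2V-14B-480P 仓库中的对应权重;假设 `input_image` 接收单张 PIL.Image,参数名取自上表,具体取值格式请以表中链接的示例脚本为准):

```python
from PIL import Image

# 假设:pipe 已按「快速开始」的方式构建,且加载的是 Wan2.1-I2V-14B-480P 的权重
# 假设:input_image 接收单张 PIL.Image(参数名取自上表「额外参数」一列)
first_frame = Image.open("first_frame.jpg")
video = pipe(
    prompt="一只活泼的小狗在绿茵茵的草地上迅速奔跑",
    input_image=first_frame,
    seed=0, tiled=True,
)
save_video(video, "i2v.mp4", fps=15, quality=5)
```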

</details>

## 创新成果

DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。

<details>

<summary>Spectral Evolution Search: 用于奖励对齐图像生成的高效推理阶段缩放</summary>

- 论文:[Spectral Evolution Search: Efficient Inference-Time Scaling for Reward-Aligned Image Generation](https://arxiv.org/abs/2602.03208)
- 代码样例:coming soon

|FLUX.1-dev|FLUX.1-dev + SES|Qwen-Image|Qwen-Image + SES|
|-|-|-|-|

</details>

<details>

<summary>VIRAL: 基于 DiT 模型的类比视觉上下文推理</summary>

- 论文:[VIRAL: Visual In-Context Reasoning via Analogy in Diffusion Transformers](https://arxiv.org/abs/2602.03210)
- 代码样例:[/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-ICEdit.py)
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)

|Example 1|Example 2|Query|Output|
|-|-|-|-|

</details>

<details>

<summary>AttriCtrl: 图像生成模型的属性强度控制</summary>

- 论文:[AttriCtrl: Fine-Grained Control of Aesthetic Attribute Intensity in Diffusion Models](https://arxiv.org/abs/2508.02151)
- 代码样例:[/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py](/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py)
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev)

|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
|-|-|-|-|-|

</details>

<details>

<summary>AutoLoRA: 自动化的 LoRA 检索和融合</summary>

- 论文:[AutoLoRA: Automatic LoRA Retrieval and Fine-Grained Gated Fusion for Text-to-Image Generation](https://arxiv.org/abs/2508.02107)
- 代码样例:[/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)
- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)

||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
|-|-|-|-|-|
|
||||||
|
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |||||
|
||||||
|
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |||||
|
||||||
|
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |||||
|
||||||
|
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |||||
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
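The fusion half of AutoLoRA reduces to a gated sum of LoRA deltas: each retrieved LoRA contributes its low-rank update `B @ A`, scaled by a gate. A minimal torch sketch of that idea (the gate values here are illustrative; in the paper they come from a learned, fine-grained gating network rather than a single softmax):

```python
import torch

def gated_lora_fusion(W, loras, gates):
    """W: (out, in) base weight; loras: list of (B, A) with B: (out, r), A: (r, in);
    gates: per-LoRA scalars, e.g. produced by a learned gating network."""
    delta = sum(g * (B @ A) for g, (B, A) in zip(gates, loras))
    return W + delta

out_dim, in_dim, rank = 64, 32, 4
W = torch.randn(out_dim, in_dim)
loras = [(torch.randn(out_dim, rank), torch.randn(rank, in_dim)) for _ in range(4)]
gates = torch.softmax(torch.randn(4), dim=0)  # illustrative gating over 4 retrieved LoRAs
W_fused = gated_lora_fusion(W, loras, gates)
```
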
<details>

<summary>Nexus-Gen: Unified Image Understanding, Generation, and Editing</summary>

- Details page: https://github.com/modelscope/Nexus-Gen
- Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
- Models: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
- Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
- Online demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)



</details>

<details>

<summary>ArtAug: Aesthetic Enhancement for Image Generation Models</summary>

- Details page: [./examples/ArtAug/](./examples/ArtAug/)
- Paper: [ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
- Models: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
- Online demo: [ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)

|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
|-|-|

</details>

<details>

<summary>EliGen: Precise Entity-Level Image Region Control</summary>

- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
- Example code: [/examples/flux/model_inference/FLUX.1-dev-EliGen.py](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)
- Models: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
- Online demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
- Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)

|Entity control regions|Generated image|
|-|-|

</details>

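EliGen's regional attention comes down to an attention mask: image tokens inside an entity's region may attend to that entity's prompt tokens, and tokens outside may not. A toy sketch of building such a mask (token layout and shapes are illustrative, not the model's actual geometry):

```python
import torch

def regional_attention_mask(region_masks, tokens_per_entity):
    """region_masks: (n_entities, n_image_tokens) booleans marking each entity's area.
    Returns an (n_image_tokens, n_entities * tokens_per_entity) boolean mask where
    True means 'this image token may attend to this text token'."""
    n_entities, n_img = region_masks.shape
    mask = torch.zeros(n_img, n_entities * tokens_per_entity, dtype=torch.bool)
    for i in range(n_entities):
        cols = slice(i * tokens_per_entity, (i + 1) * tokens_per_entity)
        mask[:, cols] = region_masks[i][:, None]
    return mask

# Two entities over a 4-token "image": entity 0 owns tokens 0-1, entity 1 owns 2-3.
masks = torch.tensor([[1, 1, 0, 0], [0, 0, 1, 1]], dtype=torch.bool)
print(regional_attention_mask(masks, tokens_per_entity=3))
```
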
<details>

<summary>ExVideo: Extended Training for Video Generation Models</summary>

- Project page: [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
- Paper: [ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
- Example code: see the [legacy version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/ExVideo)
- Models: [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)

https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc

</details>

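ExVideo post-tunes a small set of temporal parameters to stretch a pretrained video model to longer clips; the legacy loader further down this page registers it as a patch on `SVDUNet` with `{"add_positional_conv": 128}`. The sketch below illustrates the general idea of extending a learned temporal position table by interpolation before post-tuning; it is illustrative, not the paper's exact mechanism:

```python
import torch
import torch.nn.functional as F

def extend_temporal_positions(pos_table: torch.Tensor, new_len: int) -> torch.Tensor:
    """pos_table: (old_len, dim) learned temporal positional embeddings.
    Returns (new_len, dim) by linear interpolation, a common starting point
    before post-tuning the extended parameters on long clips."""
    x = pos_table.t().unsqueeze(0)               # (1, dim, old_len)
    y = F.interpolate(x, size=new_len, mode="linear", align_corners=True)
    return y.squeeze(0).t()                      # (new_len, dim)

table_25 = torch.randn(25, 1024)    # e.g. SVD's native 25-frame range
table_128 = extend_temporal_positions(table_25, 128)  # ExVideo-style 128 frames
```
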
<details>

<summary>Diffutoon: High-Resolution Anime-Style Video Rendering</summary>

- Project page: [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
- Paper: [Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
- Example code: see the [legacy version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/Diffutoon)

https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd

</details>

<details>

<summary>DiffSynth: The Original Version of This Project</summary>

- Project page: [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
- Paper: [DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
- Example code: see the [legacy version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/diffsynth)

https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea

</details>

@@ -1,6 +1 @@
from .data import *
from .core import *
from .models import *
from .prompters import *
from .schedulers import *
from .pipelines import *
from .controlnets import *

@@ -0,0 +1,2 @@
from .model_configs import MODEL_CONFIGS
from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS
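
The new `diffsynth/configs` package exposes the registries as plain data; a consumer resolves the dotted `model_class` strings at load time. A minimal sketch of that resolution step (the surrounding loader code is assumed, not shown in this diff):

```python
import importlib

def resolve_model_class(dotted_path: str):
    """Resolve a registry entry's dotted 'model_class' string, e.g.
    'diffsynth.models.qwen_image_dit.QwenImageDiT', to the class object."""
    module_name, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)

# Assuming MODEL_CONFIGS aggregates entries like those defined below:
# cfg = next(c for c in MODEL_CONFIGS if c["model_hash"] == detected_hash)
# model_cls = resolve_model_class(cfg["model_class"])
```
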
@@ -1,243 +0,0 @@
from typing_extensions import Literal, TypeAlias

from ..models.sd_text_encoder import SDTextEncoder
from ..models.sd_unet import SDUNet
from ..models.sd_vae_encoder import SDVAEEncoder
from ..models.sd_vae_decoder import SDVAEDecoder

from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from ..models.sdxl_unet import SDXLUNet
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models.sdxl_vae_encoder import SDXLVAEEncoder

from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from ..models.sd3_dit import SD3DiT
from ..models.sd3_vae_decoder import SD3VAEDecoder
from ..models.sd3_vae_encoder import SD3VAEEncoder

from ..models.sd_controlnet import SDControlNet

from ..models.sd_motion import SDMotionModel
from ..models.sdxl_motion import SDXLMotionModel

from ..models.svd_image_encoder import SVDImageEncoder
from ..models.svd_unet import SVDUNet
from ..models.svd_vae_decoder import SVDVAEDecoder
from ..models.svd_vae_encoder import SVDVAEEncoder

from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder

from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from ..models.hunyuan_dit import HunyuanDiT


model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
    (None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
    (None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
    (None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
    (None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
    (None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
    (None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
    (None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
    (None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
    (None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
    (None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
    (None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
    (None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
    (None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
    (None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
    (None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
    (None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
    (None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
    (None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
    (None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
    (None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
    (None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
    (None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
    (None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
    (None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
]

huggingface_model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (architecture_in_huggingface_config, huggingface_lib, model_name)
    ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder"),
    ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator"),
    ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt"),
]

patch_model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
    ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
]

preset_models_on_huggingface = {
    "HunyuanDiT": [
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
    ],
    "stable-video-diffusion-img2vid-xt": [
        ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
    ],
    "ExVideo-SVD-128f-v1": [
        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
    ],
}

preset_models_on_modelscope = {
    # Hunyuan DiT
    "HunyuanDiT": [
        ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
        ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
        ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
        ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
    ],
    # Stable Video Diffusion
    "stable-video-diffusion-img2vid-xt": [
        ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
    ],
    # ExVideo
    "ExVideo-SVD-128f-v1": [
        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
    ],
    # Stable Diffusion
    "StableDiffusion_v15": [
        ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
    ],
    "DreamShaper_8": [
        ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
    ],
    "AingDiffusion_v12": [
        ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
    ],
    "Flat2DAnimerge_v45Sharp": [
        ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
    ],
    # Textual Inversion
    "TextualInversion_VeryBadImageNegative_v1.3": [
        ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
    ],
    # Stable Diffusion XL
    "StableDiffusionXL_v1": [
        ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
    ],
    "BluePencilXL_v200": [
        ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
    ],
    "StableDiffusionXL_Turbo": [
        ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
    ],
    # Stable Diffusion 3
    "StableDiffusion3": [
        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
    ],
    "StableDiffusion3_without_T5": [
        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
    ],
    # ControlNet
    "ControlNet_v11f1p_sd15_depth": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    "ControlNet_v11p_sd15_softedge": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
    ],
    "ControlNet_v11f1e_sd15_tile": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
    ],
    "ControlNet_v11p_sd15_lineart": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
        ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
    ],
    # AnimateDiff
    "AnimateDiff_v2": [
        ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
    ],
    "AnimateDiff_xl_beta": [
        ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
    ],
    # RIFE
    "RIFE": [
        ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
    ],
    # Beautiful Prompt
    "BeautifulPrompt": [
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
    ],
    # Translator
    "opus-mt-zh-en": [
        ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
        ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
    ],
    # IP-Adapter
    "IP-Adapter-SD": [
        ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
        ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
    ],
    "IP-Adapter-SDXL": [
        ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
        ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
    ],
    # Kolors
    "Kolors": [
        ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
        ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
    ],
    "SDXL-vae-fp16-fix": [
        ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
    ],
}

Preset_model_id: TypeAlias = Literal[
    "HunyuanDiT",
    "stable-video-diffusion-img2vid-xt",
    "ExVideo-SVD-128f-v1",
    "StableDiffusion_v15",
    "DreamShaper_8",
    "AingDiffusion_v12",
    "Flat2DAnimerge_v45Sharp",
    "TextualInversion_VeryBadImageNegative_v1.3",
    "StableDiffusionXL_v1",
    "BluePencilXL_v200",
    "StableDiffusionXL_Turbo",
    "ControlNet_v11f1p_sd15_depth",
    "ControlNet_v11p_sd15_softedge",
    "ControlNet_v11f1e_sd15_tile",
    "ControlNet_v11p_sd15_lineart",
    "AnimateDiff_v2",
    "AnimateDiff_xl_beta",
    "RIFE",
    "BeautifulPrompt",
    "opus-mt-zh-en",
    "IP-Adapter-SD",
    "IP-Adapter-SDXL",
    "StableDiffusion3",
    "StableDiffusion3_without_T5",
    "Kolors",
    "SDXL-vae-fp16-fix",
]
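
Both the deleted `model_loader_configs` above and the `MODEL_CONFIGS` entries below key architectures by a hash over the checkpoint's state-dict keys, so a file can be identified without trusting its name. A plausible reconstruction of that detection step (the exact hashing scheme is an assumption, not copied from DiffSynth):

```python
import hashlib

def state_dict_keys_hash(state_dict, with_shape=False):
    """Hash a checkpoint's parameter names (optionally with shapes) so it can
    be matched against a registry entry's model_hash. Reconstruction only."""
    keys = sorted(
        f"{k}:{tuple(v.shape)}" if with_shape else k
        for k, v in state_dict.items()
    )
    return hashlib.md5(",".join(keys).encode()).hexdigest()

def detect_model(state_dict, configs):
    # Return every registry entry matching the checkpoint; a single file can
    # map to several models (e.g. a VAE file yields both encoder and decoder).
    h = state_dict_keys_hash(state_dict, with_shape=True)
    return [c for c in configs if c["model_hash"] == h]
```
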

diffsynth/configs/model_configs.py (new file, 722 lines)
@@ -0,0 +1,722 @@
qwen_image_series = [
    {
        # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
        "model_hash": "0319a1cb19835fb510907dd3367c95ff",
        "model_name": "qwen_image_dit",
        "model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
    },
    {
        # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors")
        "model_hash": "8004730443f55db63092006dd9f7110e",
        "model_name": "qwen_image_text_encoder",
        "model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
        "model_hash": "ed4ea5824d55ec3107b09815e318123a",
        "model_name": "qwen_image_vae",
        "model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors")
        "model_hash": "073bce9cf969e317e5662cd570c3e79c",
        "model_name": "qwen_image_blockwise_controlnet",
        "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors")
        "model_hash": "a9e54e480a628f0b956a688a81c33bab",
        "model_name": "qwen_image_blockwise_controlnet",
        "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
        "extra_kwargs": {"additional_in_dim": 4},
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
        "model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8",
        "model_name": "siglip2_image_encoder",
        "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors")
        "model_hash": "5722b5c873720009de96422993b15682",
        "model_name": "dinov3_image_encoder",
        "model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
    },
    {
        # Example:
        "model_hash": "a166c33455cdbd89c0888a3645ca5c0f",
        "model_name": "qwen_image_image2lora_coarse",
        "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
    },
    {
        # Example:
        "model_hash": "a5476e691767a4da6d3a6634a10f7408",
        "model_name": "qwen_image_image2lora_fine",
        "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
        "extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64}
    },
    {
        # Example:
        "model_hash": "0aad514690602ecaff932c701cb4b0bb",
        "model_name": "qwen_image_image2lora_style",
        "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
        "extra_kwargs": {"compress_dim": 64, "use_residual": False}
    },
    {
        # Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
        "model_hash": "8dc8cda05de16c73afa755e2c1ce2839",
        "model_name": "qwen_image_dit",
        "model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
        "extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True}
    },
    {
        # Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
        "model_hash": "44b39ddc499e027cfb24f7878d7416b9",
        "model_name": "qwen_image_vae",
        "model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
        "extra_kwargs": {"image_channels": 4}
    },
]

wan_series = [
    {
        # Example: ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors")
        "model_hash": "5ec04e02b42d2580483ad69f4e76346a",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth")
        "model_hash": "9c8818c2cbea55eca56c7b447df170da",
        "model_name": "wan_video_text_encoder",
        "model_class": "diffsynth.models.wan_video_text_encoder.WanTextEncoder",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth")
        "model_hash": "ccc42284ea13e1ad04693284c7a09be6",
        "model_name": "wan_video_vae",
        "model_class": "diffsynth.models.wan_video_vae.WanVideoVAE",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors")
        "model_hash": "8b27900f680d7251ce44e2dc8ae1ffef",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel",
    },
    {
        # Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
        "model_hash": "5f90e66a0672219f12d9a626c8c21f61",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTFromDiffusers"
    },
    {
        # Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
        "model_hash": "5f90e66a0672219f12d9a626c8c21f61",
        "model_name": "wan_video_vap",
        "model_class": "diffsynth.models.wan_video_mot.MotWanModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_mot.WanVideoMotStateDictConverter"
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
        "model_hash": "5941c53e207d62f20f9025686193c40b",
        "model_name": "wan_video_image_encoder",
        "model_class": "diffsynth.models.wan_video_image_encoder.WanImageEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_image_encoder.WanImageEncoderStateDictConverter"
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors")
        "model_hash": "dbd5ec76bbf977983f972c151d545389",
        "model_name": "wan_video_motion_controller",
        "model_class": "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "9269f8db9040a9d860eaca435be61814",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "3ef3b1f8e1dab83d5b71fd7b617f859f",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_image_pos_emb': True}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "349723183fc063b2bfc10bb2835cf677",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "6d6ccde6845b95ad9114ab993d917893",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "efa44cddf936c70abd0ea28b6cbe946c",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "6bfcfb3b342cb286ce886889d519a77e",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "ac6a5aa74f4a0aab6f64eb9a72f19901",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "70ddad9d3a133785da5ea371aae09504",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': True}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "b61c605c2adbd23124d152ed28e049ae",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "26bde73488a92e64cc20b0a7485b9e5b",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True}
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "aafcfd9672c3a2456dc46e1cb6e52c70",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "a61453409b67cd3246cf0c3bebad47ba",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "a61453409b67cd3246cf0c3bebad47ba",
        "model_name": "wan_video_vace",
        "model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "7a513e1f257a861512b1afd387a8ecd9",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "7a513e1f257a861512b1afd387a8ecd9",
        "model_name": "wan_video_vace",
        "model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
        "extra_kwargs": {'vace_layers': (0, 5, 10, 15, 20, 25, 30, 35), 'vace_in_dim': 96, 'patch_size': (1, 2, 2), 'has_image_input': False, 'dim': 5120, 'num_heads': 40, 'ffn_dim': 13824, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter"
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
        "model_name": "wan_video_animate_adapter",
        "model_class": "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_animate_adapter.WanAnimateAdapterStateDictConverter"
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
        "model_hash": "47dbeab5e560db3180adf51dc0232fb1",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24, 'require_clip_embedding': False}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
        "model_hash": "2267d489f0ceb9f21836532952852ee5",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 52, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True, 'require_clip_embedding': False},
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
        "model_hash": "5b013604280dd715f8457c6ed6d6a626",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'require_clip_embedding': False}
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "966cffdcc52f9c46c391768b27637614",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit_s2v.WanS2VModel",
        "extra_kwargs": {'dim': 5120, 'in_dim': 16, 'ffn_dim': 13824, 'out_dim': 16, 'text_dim': 4096, 'freq_dim': 256, 'eps': 1e-06, 'patch_size': (1, 2, 2), 'num_heads': 40, 'num_layers': 40, 'cond_dim': 16, 'audio_dim': 1024, 'num_audio_token': 4}
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "1f5ab7703c6fc803fdded85ff040c316",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True}
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth")
        "model_hash": "e1de6c02cdac79f8b739f4d3698cd216",
        "model_name": "wan_video_vae",
        "model_class": "diffsynth.models.wan_video_vae.WanVideoVAE38",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors")
        "model_hash": "06be60f3a4526586d8431cd038a71486",
        "model_name": "wans2v_audio_encoder",
        "model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
    },
]

flux_series = [
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
        "model_hash": "a29710fea6dddb0314663ee823598e50",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
    },
    {
        # Supported due to historical reasons.
        "model_hash": "605c56eab23e9e2af863ad8f0813a25d",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
        "model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
        "model_name": "flux_text_encoder_clip",
        "model_class": "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_clip.FluxTextEncoderClipStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors")
        "model_hash": "22540b49eaedbc2f2784b2091a234c7c",
        "model_name": "flux_text_encoder_t5",
        "model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
        "model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
        "model_name": "flux_vae_encoder",
        "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
        "model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
        "model_name": "flux_vae_decoder",
        "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors")
        "model_hash": "d02f41c13549fa5093d3521f62a5570a",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "extra_kwargs": {'input_dim': 196, 'num_blocks': 8},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
        "model_hash": "0629116fce1472503a66992f96f3eb1a",
        "model_name": "flux_value_controller",
        "model_class": "diffsynth.models.flux_value_control.SingleValueEncoder",
    },
    {
        # Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors")
        "model_hash": "52357cb26250681367488a8954c271e8",
        "model_name": "flux_controlnet",
        "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
        "extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4},
    },
    {
        # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors")
        "model_hash": "78d18b9101345ff695f312e7e62538c0",
        "model_name": "flux_controlnet",
        "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
        "extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}},
    },
    {
        # Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors")
        "model_hash": "b001c89139b5f053c715fe772362dd2a",
        "model_name": "flux_controlnet",
        "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
        "extra_kwargs": {"num_single_blocks": 0},
    },
    {
        # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin")
        "model_hash": "c07c0f04f5ff55e86b4e937c7a40d481",
        "model_name": "infiniteyou_image_projector",
        "model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors")
        "model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16",
        "model_name": "flux_controlnet",
        "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
        "extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10},
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors")
        "model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab",
        "model_name": "flux_lora_encoder",
        "model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors")
        "model_hash": "30143afb2dea73d1ac580e0787628f8c",
        "model_name": "flux_lora_patcher",
        "model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors")
        "model_hash": "2bd19e845116e4f875a0a048e27fc219",
        "model_name": "nexus_gen_llm",
        "model_class": "diffsynth.models.nexus_gen.NexusGenAutoregressiveModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
        "model_hash": "63c969fd37cce769a90aa781fbff5f81",
        "model_name": "nexus_gen_editing_adapter",
        "model_class": "diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
        "model_hash": "63c969fd37cce769a90aa781fbff5f81",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
        "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
        "model_name": "nexus_gen_generation_adapter",
        "model_class": "diffsynth.models.nexus_gen_projector.NexusGenAdapter",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
        "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin")
        "model_hash": "4daaa66cc656a8fe369908693dad0a35",
        "model_name": "flux_ipadapter",
        "model_class": "diffsynth.models.flux_ipadapter.FluxIpAdapter",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors")
        "model_hash": "04d8c1e20a1f1b25f7434f111992a33f",
        "model_name": "siglip_vision_model",
        "model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
        "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
        "model_name": "step1x_connector",
        "model_class": "diffsynth.models.step1x_connector.Qwen2Connector",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.step1x_connector.Qwen2ConnectorStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
        "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
        "extra_kwargs": {"disable_guidance_embedder": True},
    },
    {
        # Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors")
        "model_hash": "3394f306c4cbf04334b712bf5aaed95f",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
    },
]

flux2_series = [
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors")
        "model_hash": "28fca3d8e5bf2a2d1271748a773f6757",
        "model_name": "flux2_text_encoder",
        "model_class": "diffsynth.models.flux2_text_encoder.Flux2TextEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux2_text_encoder.Flux2TextEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors")
        "model_hash": "d38e1d5c5aec3b0a11e79327ac6e3b0f",
        "model_name": "flux2_dit",
        "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
        "model_hash": "c54288e3ee12ca215898840682337b95",
        "model_name": "flux2_vae",
        "model_class": "diffsynth.models.flux2_vae.Flux2VAE",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
        "model_hash": "3bde7b817fec8143028b6825a63180df",
        "model_name": "flux2_dit",
        "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
        "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors")
        "model_hash": "9195f3ea256fcd0ae6d929c203470754",
        "model_name": "z_image_text_encoder",
        "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
        "extra_kwargs": {"model_size": "8B"},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors")
        "model_hash": "39c6fc48f07bebecedbbaa971ff466c8",
        "model_name": "flux2_dit",
        "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
        "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
    },
]

z_image_series = [
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
        "model_hash": "fc3a8a1247fe185ce116ccbe0e426c28",
        "model_name": "z_image_dit",
        "model_class": "diffsynth.models.z_image_dit.ZImageDiT",
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors")
        "model_hash": "0f050f62a88876fea6eae0a18dac5a2e",
        "model_name": "z_image_text_encoder",
        "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
        "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
        "model_name": "flux_vae_encoder",
        "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers",
        "extra_kwargs": {"use_conv_attention": False},
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
        "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
        "model_name": "flux_vae_decoder",
        "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
        "extra_kwargs": {"use_conv_attention": False},
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors")
        "model_hash": "aa3563718e5c3ecde3dfbb020ca61180",
        "model_name": "z_image_dit",
        "model_class": "diffsynth.models.z_image_dit.ZImageDiT",
        "extra_kwargs": {"siglip_feat_dim": 1152},
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors")
        "model_hash": "89d48e420f45cff95115a9f3e698d44a",
        "model_name": "siglip_vision_model_428m",
        "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
    },
    {
        # Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
        "model_hash": "1677708d40029ab380a95f6c731a57d7",
        "model_name": "z_image_controlnet",
        "model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
    },
    {
        # Example: ???
        "model_hash": "9510cb8cd1dd34ee0e4f111c24905510",
        "model_name": "z_image_image2lora_style",
        "model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
        "extra_kwargs": {"compress_dim": 128},
    },
    {
        # Example: ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors")
        "model_hash": "1392adecee344136041e70553f875f31",
        "model_name": "z_image_text_encoder",
        "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
        "extra_kwargs": {"model_size": "0.6B"},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
    },
]


"""
|
||||||
|
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
|
||||||
|
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
|
||||||
|
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
|
||||||
|
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
|
||||||
|
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
|
||||||
|
and avoid redundant memory usage when users only want to use part of the model.
|
||||||
|
"""
|
||||||
|
ltx2_series = [
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_dit",
        "model_class": "diffsynth.models.ltx2_dit.LTXModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors")
        "model_hash": "c567aaa37d5ed7454c73aa6024458661",
        "model_name": "ltx2_dit",
        "model_class": "diffsynth.models.ltx2_dit.LTXModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_video_vae_encoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors")
        "model_hash": "7f7e904a53260ec0351b05f32153754b",
        "model_name": "ltx2_video_vae_encoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_video_vae_decoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors")
        "model_hash": "dc6029ca2825147872b45e35a2dc3a97",
        "model_name": "ltx2_video_vae_decoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_audio_vae_decoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors")
        "model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
        "model_name": "ltx2_audio_vae_decoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_audio_vocoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors")
        "model_hash": "f471360f6b24bef702ab73133d9f8bb9",
        "model_name": "ltx2_audio_vocoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_audio_vae_encoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
        "model_hash": "29338f3b95e7e312a3460a482e4f4554",
        "model_name": "ltx2_audio_vae_encoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_text_encoder_post_modules",
        "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors")
        "model_hash": "981629689c8be92a712ab3c5eb4fc3f6",
        "model_name": "ltx2_text_encoder_post_modules",
        "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors")
        "model_hash": "33917f31c4a79196171154cca39f165e",
        "model_name": "ltx2_text_encoder",
        "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "c79c458c6e99e0e14d47e676761732d2",
        "model_name": "ltx2_latent_upsampler",
        "model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
    },
]

MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series
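
For reference, a minimal sketch of how one of these registered checkpoints can be requested, assuming the ModelConfig interface shown in the inline example comments (the import path below is a hypothetical assumption):

# Either the official single-file checkpoint or a repackaged per-module file can be
# requested; the loader presumably matches the downloaded weights against
# MODEL_CONFIGS by hash to pick the model class and state dict converter.
from diffsynth import ModelConfig  # hypothetical import path

official = ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
repackaged = ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors")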

diffsynth/configs/vram_management_module_maps.py (new file, 246 lines)
@@ -0,0 +1,246 @@
flux_general_vram_config = {
    "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
    "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    "diffsynth.models.general_modules.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    "diffsynth.models.flux_lora_encoder.LoRALayerBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
    "diffsynth.models.flux_lora_patcher.LoraMerger": "diffsynth.core.vram.layers.AutoWrappedModule",
}


VRAM_MANAGEMENT_MODULE_MAPS = {
    "diffsynth.models.qwen_image_dit.QwenImageDiT": {
        "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionPatchEmbed": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.qwen_image_vae.QwenImageVAE": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.qwen_image_vae.QwenImageRMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.qwen_image_controlnet.BlockWiseControlBlock": {
        "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": {
        "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": {
        "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": {
        "diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_animate_adapter.ConvLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_animate_adapter.FusedLeakyReLU": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_animate_adapter.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_dit_s2v.WanS2VModel": {
        "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit_s2v.WanS2VDiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit_s2v.CausalAudioEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_dit.WanModel": {
        "diffsynth.models.wan_video_dit.MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
        "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_image_encoder.WanImageEncoder": {
        "diffsynth.models.wan_video_image_encoder.VisionTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_mot.MotWanModel": {
        "diffsynth.models.wan_video_mot.MotWanAttentionBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.wan_video_text_encoder.WanTextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_text_encoder.T5RelativeEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_text_encoder.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_vace.VaceWanModel": {
        "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_vae.WanVideoVAE": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_vae.WanVideoVAE38": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wav2vec.WanS2VAudioEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.longcat_video_dit.RMSNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.longcat_video_dit.LayerNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux_dit.FluxDiT": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "diffsynth.models.flux_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip": flux_general_vram_config,
    "diffsynth.models.flux_vae.FluxVAEEncoder": flux_general_vram_config,
    "diffsynth.models.flux_vae.FluxVAEDecoder": flux_general_vram_config,
    "diffsynth.models.flux_controlnet.FluxControlNet": flux_general_vram_config,
    "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector": flux_general_vram_config,
    "diffsynth.models.flux_ipadapter.FluxIpAdapter": flux_general_vram_config,
    "diffsynth.models.flux_lora_patcher.FluxLoraPatcher": flux_general_vram_config,
    "diffsynth.models.step1x_connector.Qwen2Connector": flux_general_vram_config,
    "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder": flux_general_vram_config,
    "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.t5.modeling_t5.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.t5.modeling_t5.T5DenseActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.t5.modeling_t5.T5DenseGatedActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M": {
        "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.siglip.modeling_siglip.SiglipEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.MultiheadAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux2_dit.Flux2DiT": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux2_text_encoder.Flux2TextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.mistral.modeling_mistral.MistralRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux2_vae.Flux2VAE": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.z_image_text_encoder.ZImageTextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.z_image_dit.ZImageDiT": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.z_image_controlnet.ZImageControlNet": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": {
        "transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.ltx2_dit.LTXModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": {
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": {
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": {
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": {
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_audio_vae.LTX2Vocoder": {
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
}
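
A hedged sketch of how a map like this could be consumed: walk a model's submodules and, where a child's fully qualified class name appears in the per-model map, replace it with the named wrapper class. The wrap_model and _resolve helpers below are illustrative assumptions, not the library's actual API, and the wrapper constructor is assumed to take the wrapped module:

import importlib

def _resolve(path: str):
    # Resolve a dotted "pkg.module.ClassName" string to the class object.
    module_name, class_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)

def wrap_model(model, module_map: dict):
    # Hypothetical helper: recursively swap matching children for their wrappers.
    for name, child in model.named_children():
        cls_path = f"{type(child).__module__}.{type(child).__name__}"
        if cls_path in module_map:
            setattr(model, name, _resolve(module_map[cls_path])(child))
        else:
            wrap_model(child, module_map)
    return model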

@@ -1,2 +0,0 @@ (deleted file)
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
from .processors import Annotator

@@ -1,53 +0,0 @@ (deleted file)
import torch
import numpy as np
from .processors import Processor_id


class ControlNetConfigUnit:
    def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
        self.processor_id = processor_id
        self.model_path = model_path
        self.scale = scale


class ControlNetUnit:
    def __init__(self, processor, model, scale=1.0):
        self.processor = processor
        self.model = model
        self.scale = scale


class MultiControlNetManager:
    def __init__(self, controlnet_units=[]):
        self.processors = [unit.processor for unit in controlnet_units]
        self.models = [unit.model for unit in controlnet_units]
        self.scales = [unit.scale for unit in controlnet_units]

    def process_image(self, image, processor_id=None):
        if processor_id is None:
            processed_image = [processor(image) for processor in self.processors]
        else:
            processed_image = [self.processors[processor_id](image)]
        processed_image = torch.concat([
            torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
            for image_ in processed_image
        ], dim=0)
        return processed_image

    def __call__(
        self,
        sample, timestep, encoder_hidden_states, conditionings,
        tiled=False, tile_size=64, tile_stride=32
    ):
        res_stack = None
        for conditioning, model, scale in zip(conditionings, self.models, self.scales):
            res_stack_ = model(
                sample, timestep, encoder_hidden_states, conditioning,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
            )
            res_stack_ = [res * scale for res in res_stack_]
            if res_stack is None:
                res_stack = res_stack_
            else:
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
        return res_stack

@@ -1,51 +0,0 @@ (deleted file)
from typing_extensions import Literal, TypeAlias
import warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from controlnet_aux.processor import (
        CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
    )


Processor_id: TypeAlias = Literal[
    "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
]

class Annotator:
    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
        if processor_id == "canny":
            self.processor = CannyDetector()
        elif processor_id == "depth":
            self.processor = MidasDetector.from_pretrained(model_path).to(device)
        elif processor_id == "softedge":
            self.processor = HEDdetector.from_pretrained(model_path).to(device)
        elif processor_id == "lineart":
            self.processor = LineartDetector.from_pretrained(model_path).to(device)
        elif processor_id == "lineart_anime":
            self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
        elif processor_id == "openpose":
            self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
        elif processor_id == "tile":
            self.processor = None
        else:
            raise ValueError(f"Unsupported processor_id: {processor_id}")

        self.processor_id = processor_id
        self.detect_resolution = detect_resolution

    def __call__(self, image):
        width, height = image.size
        if self.processor_id == "openpose":
            kwargs = {
                "include_body": True,
                "include_hand": True,
                "include_face": True
            }
        else:
            kwargs = {}
        if self.processor is not None:
            detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
            image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
        image = image.resize((width, height))
        return image
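
For reference, the removed Annotator was used roughly like this (the input file name is hypothetical; "canny" needs no pretrained weights, while the other processor ids load detectors from model_path):

from PIL import Image
annotator = Annotator("canny")
control_image = annotator(Image.open("input.png").convert("RGB"))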

diffsynth/core/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from .attention import *
from .data import *
from .gradient import *
from .loader import *
from .vram import *
from .device import *

diffsynth/core/attention/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .attention import attention_forward

diffsynth/core/attention/attention.py (new file, 121 lines)
@@ -0,0 +1,121 @@
import torch, os
from einops import rearrange


try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False

try:
    import xformers.ops as xops
    XFORMERS_AVAILABLE = True
except ModuleNotFoundError:
    XFORMERS_AVAILABLE = False


def initialize_attention_priority():
    if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None:
        return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower()
    elif FLASH_ATTN_3_AVAILABLE:
        return "flash_attention_3"
    elif FLASH_ATTN_2_AVAILABLE:
        return "flash_attention_2"
    elif SAGE_ATTN_AVAILABLE:
        return "sage_attention"
    elif XFORMERS_AVAILABLE:
        return "xformers"
    else:
        return "torch"


ATTENTION_IMPLEMENTATION = initialize_attention_priority()


def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
    dims = {} if dims is None else dims
    if q_pattern != required_in_pattern:
        q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims)
    if k_pattern != required_in_pattern:
        k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
    if v_pattern != required_in_pattern:
        v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
    return q, k, v


def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
    dims = {} if dims is None else dims
    if out_pattern != required_out_pattern:
        out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims)
    return out


def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
    required_in_pattern, required_out_pattern = "b n s d", "b n s d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out


def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    required_in_pattern, required_out_pattern = "b s n d", "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale)
    if isinstance(out, tuple):
        out = out[0]
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out


def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    required_in_pattern, required_out_pattern = "b s n d", "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out


def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    required_in_pattern, required_out_pattern = "b n s d", "b n s d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = sageattn(q, k, v, sm_scale=scale)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out


def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    required_in_pattern, required_out_pattern = "b s n d", "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = xops.memory_efficient_attention(q, k, v, scale=scale)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out


def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
    if compatibility_mode or (attn_mask is not None):
        return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)
    else:
        if ATTENTION_IMPLEMENTATION == "flash_attention_3":
            return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
        elif ATTENTION_IMPLEMENTATION == "flash_attention_2":
            return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
        elif ATTENTION_IMPLEMENTATION == "sage_attention":
            return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
        elif ATTENTION_IMPLEMENTATION == "xformers":
            return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
        else:
            return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
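
attention_forward accepts q/k/v in any einops layout named by the pattern strings and dispatches to the fastest installed backend; passing an attention mask or compatibility_mode=True always routes to torch SDPA. A minimal check (assuming no accelerated attention package is installed, so both calls fall through to torch SDPA):

q = torch.randn(2, 8, 128, 64)  # "b n s d": (batch, heads, seq, head_dim)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)
out = attention_forward(q, k, v)                           # auto-selected backend
ref = attention_forward(q, k, v, compatibility_mode=True)  # forced torch SDPA
assert out.shape == (2, 8, 128, 64)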

diffsynth/core/data/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .unified_dataset import UnifiedDataset

diffsynth/core/data/operators.py (new file, 237 lines)
@@ -0,0 +1,237 @@
import torch, torchvision, imageio, os
import imageio.v3 as iio
from PIL import Image


class DataProcessingPipeline:
    def __init__(self, operators=None):
        self.operators: list[DataProcessingOperator] = [] if operators is None else operators

    def __call__(self, data):
        for operator in self.operators:
            data = operator(data)
        return data

    def __rshift__(self, pipe):
        if isinstance(pipe, DataProcessingOperator):
            pipe = DataProcessingPipeline([pipe])
        return DataProcessingPipeline(self.operators + pipe.operators)


class DataProcessingOperator:
    def __call__(self, data):
        raise NotImplementedError("DataProcessingOperator cannot be called directly.")

    def __rshift__(self, pipe):
        if isinstance(pipe, DataProcessingOperator):
            pipe = DataProcessingPipeline([pipe])
        return DataProcessingPipeline([self]).__rshift__(pipe)


class DataProcessingOperatorRaw(DataProcessingOperator):
    def __call__(self, data):
        return data


class ToInt(DataProcessingOperator):
    def __call__(self, data):
        return int(data)


class ToFloat(DataProcessingOperator):
    def __call__(self, data):
        return float(data)


class ToStr(DataProcessingOperator):
    def __init__(self, none_value=""):
        self.none_value = none_value

    def __call__(self, data):
        if data is None: data = self.none_value
        return str(data)


class LoadImage(DataProcessingOperator):
    def __init__(self, convert_RGB=True, convert_RGBA=False):
        self.convert_RGB = convert_RGB
        self.convert_RGBA = convert_RGBA

    def __call__(self, data: str):
        image = Image.open(data)
        if self.convert_RGB: image = image.convert("RGB")
        if self.convert_RGBA: image = image.convert("RGBA")
        return image


class ImageCropAndResize(DataProcessingOperator):
    def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1):
        self.height = height
        self.width = width
        self.max_pixels = max_pixels
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor

    def crop_and_resize(self, image, target_height, target_width):
        width, height = image.size
        scale = max(target_width / width, target_height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
        return image

    def get_height_width(self, image):
        if self.height is None or self.width is None:
            width, height = image.size
            if width * height > self.max_pixels:
                scale = (width * height / self.max_pixels) ** 0.5
                height, width = int(height / scale), int(width / scale)
            height = height // self.height_division_factor * self.height_division_factor
            width = width // self.width_division_factor * self.width_division_factor
        else:
            height, width = self.height, self.width
        return height, width

    def __call__(self, data: Image.Image):
        image = self.crop_and_resize(data, *self.get_height_width(data))
        return image


class ToList(DataProcessingOperator):
    def __call__(self, data):
        return [data]


class LoadVideo(DataProcessingOperator):
    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # frame_processor is built into the video loader for efficiency.
        self.frame_processor = frame_processor

    def get_num_frames(self, reader):
        num_frames = self.num_frames
        if int(reader.count_frames()) < num_frames:
            num_frames = int(reader.count_frames())
            while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
                num_frames -= 1
        return num_frames

    def __call__(self, data: str):
        reader = imageio.get_reader(data)
        num_frames = self.get_num_frames(reader)
        frames = []
        for frame_id in range(num_frames):
            frame = reader.get_data(frame_id)
            frame = Image.fromarray(frame)
            frame = self.frame_processor(frame)
            frames.append(frame)
        reader.close()
        return frames


class SequencialProcess(DataProcessingOperator):
    def __init__(self, operator=lambda x: x):
        self.operator = operator

    def __call__(self, data):
        return [self.operator(i) for i in data]


class LoadGIF(DataProcessingOperator):
    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # frame_processor is built into the GIF loader for efficiency.
        self.frame_processor = frame_processor

    def get_num_frames(self, path):
        num_frames = self.num_frames
        images = iio.imread(path, mode="RGB")
        if len(images) < num_frames:
            num_frames = len(images)
            while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
                num_frames -= 1
        return num_frames

    def __call__(self, data: str):
        num_frames = self.get_num_frames(data)
        frames = []
        images = iio.imread(data, mode="RGB")
        for img in images:
            frame = Image.fromarray(img)
            frame = self.frame_processor(frame)
            frames.append(frame)
            if len(frames) >= num_frames:
                break
        return frames


class RouteByExtensionName(DataProcessingOperator):
    def __init__(self, operator_map):
        self.operator_map = operator_map

    def __call__(self, data: str):
        file_ext_name = data.split(".")[-1].lower()
        for ext_names, operator in self.operator_map:
            if ext_names is None or file_ext_name in ext_names:
                return operator(data)
        raise ValueError(f"Unsupported file: {data}")


class RouteByType(DataProcessingOperator):
    def __init__(self, operator_map):
        self.operator_map = operator_map

    def __call__(self, data):
        for dtype, operator in self.operator_map:
            if dtype is None or isinstance(data, dtype):
                return operator(data)
        raise ValueError(f"Unsupported data: {data}")


class LoadTorchPickle(DataProcessingOperator):
    def __init__(self, map_location="cpu"):
        self.map_location = map_location

    def __call__(self, data):
        return torch.load(data, map_location=self.map_location, weights_only=False)


class ToAbsolutePath(DataProcessingOperator):
    def __init__(self, base_path=""):
        self.base_path = base_path

    def __call__(self, data):
        return os.path.join(self.base_path, data)


class LoadAudio(DataProcessingOperator):
    def __init__(self, sr=16000):
        self.sr = sr

    def __call__(self, data: str):
        import librosa
        input_audio, sample_rate = librosa.load(data, sr=self.sr)
        return input_audio


class LoadAudioWithTorchaudio(DataProcessingOperator):
    def __init__(self, duration=5):
        self.duration = duration

    def __call__(self, data: str):
        import torchaudio
        waveform, sample_rate = torchaudio.load(data)
        target_samples = int(self.duration * sample_rate)
        current_samples = waveform.shape[-1]
        if current_samples > target_samples:
            waveform = waveform[..., :target_samples]
        elif current_samples < target_samples:
            padding = target_samples - current_samples
            waveform = torch.nn.functional.pad(waveform, (0, padding))
        return waveform, sample_rate
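
Operators compose left to right with ">>", yielding a DataProcessingPipeline that applies each stage in turn. A short sketch (the dataset root and file name are hypothetical):

pipeline = (
    ToAbsolutePath("data/images")
    >> LoadImage()
    >> ImageCropAndResize(max_pixels=1920 * 1080, height_division_factor=16, width_division_factor=16)
)
image = pipeline("example.jpg")  # PIL image, downscaled to at most 1080p with 16-aligned sides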

diffsynth/core/data/unified_dataset.py (new file, 116 lines)
@@ -0,0 +1,116 @@
from .operators import *
|
||||||
|
import torch, json, pandas
|
||||||
|
|
||||||
|
|
||||||
|
class UnifiedDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_path=None, metadata_path=None,
|
||||||
|
repeat=1,
|
||||||
|
data_file_keys=tuple(),
|
||||||
|
main_data_operator=lambda x: x,
|
||||||
|
special_operator_map=None,
|
||||||
|
max_data_items=None,
|
||||||
|
):
|
||||||
|
self.base_path = base_path
|
||||||
|
self.metadata_path = metadata_path
|
||||||
|
self.repeat = repeat
|
||||||
|
self.data_file_keys = data_file_keys
|
||||||
|
self.main_data_operator = main_data_operator
|
||||||
|
self.cached_data_operator = LoadTorchPickle()
|
||||||
|
self.special_operator_map = {} if special_operator_map is None else special_operator_map
|
||||||
|
self.max_data_items = max_data_items
|
||||||
|
self.data = []
|
||||||
|
self.cached_data = []
|
||||||
|
self.load_from_cache = metadata_path is None
|
||||||
|
self.load_metadata(metadata_path)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def default_image_operator(
|
||||||
|
base_path="",
|
||||||
|
max_pixels=1920*1080, height=None, width=None,
|
||||||
|
height_division_factor=16, width_division_factor=16,
|
||||||
|
):
|
||||||
|
return RouteByType(operator_map=[
|
||||||
|
(str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
|
||||||
|
(list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
|
||||||
|
])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def default_video_operator(
|
||||||
|
base_path="",
|
||||||
|
max_pixels=1920*1080, height=None, width=None,
|
||||||
|
height_division_factor=16, width_division_factor=16,
|
||||||
|
num_frames=81, time_division_factor=4, time_division_remainder=1,
|
||||||
|
):
|
||||||
|
return RouteByType(operator_map=[
|
||||||
|
(str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
|
||||||
|
(("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
|
||||||
|
(("gif",), LoadGIF(
|
||||||
|
num_frames, time_division_factor, time_division_remainder,
|
||||||
|
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
|
||||||
|
)),
|
||||||
|
(("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
|
||||||
|
num_frames, time_division_factor, time_division_remainder,
|
||||||
|
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
|
||||||
|
)),
|
||||||
|
])),
|
||||||
|
])
|
||||||
|
|
||||||
|
def search_for_cached_data_files(self, path):
|
||||||
|
for file_name in os.listdir(path):
|
||||||
|
subpath = os.path.join(path, file_name)
|
||||||
|
if os.path.isdir(subpath):
|
||||||
|
self.search_for_cached_data_files(subpath)
|
||||||
|
elif subpath.endswith(".pth"):
|
||||||
|
self.cached_data.append(subpath)
|
||||||
|
|
||||||
|
def load_metadata(self, metadata_path):
|
||||||
|
if metadata_path is None:
|
||||||
|
print("No metadata_path. Searching for cached data files.")
|
||||||
|
self.search_for_cached_data_files(self.base_path)
|
||||||
|
print(f"{len(self.cached_data)} cached data files found.")
|
||||||
|
elif metadata_path.endswith(".json"):
|
||||||
|
with open(metadata_path, "r") as f:
|
||||||
|
metadata = json.load(f)
|
||||||
|
self.data = metadata
|
||||||
|
elif metadata_path.endswith(".jsonl"):
|
||||||
|
metadata = []
|
||||||
|
with open(metadata_path, 'r') as f:
|
||||||
|
for line in f:
|
||||||
|
metadata.append(json.loads(line.strip()))
|
||||||
|
self.data = metadata
|
||||||
|
else:
|
||||||
|
metadata = pandas.read_csv(metadata_path)
|
||||||
|
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||||
|
|
||||||
|
def __getitem__(self, data_id):
|
||||||
|
if self.load_from_cache:
|
||||||
|
data = self.cached_data[data_id % len(self.cached_data)]
|
||||||
|
data = self.cached_data_operator(data)
|
||||||
|
else:
|
||||||
|
data = self.data[data_id % len(self.data)].copy()
|
||||||
|
for key in self.data_file_keys:
|
||||||
|
if key in data:
|
||||||
|
if key in self.special_operator_map:
|
||||||
|
data[key] = self.special_operator_map[key](data[key])
|
||||||
|
elif key in self.data_file_keys:
|
||||||
|
data[key] = self.main_data_operator(data[key])
|
||||||
|
return data
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
if self.max_data_items is not None:
|
||||||
|
return self.max_data_items
|
||||||
|
elif self.load_from_cache:
|
||||||
|
return len(self.cached_data) * self.repeat
|
||||||
|
else:
|
||||||
|
return len(self.data) * self.repeat
|
||||||
|
|
||||||
|
def check_data_equal(self, data1, data2):
|
||||||
|
# Debug only
|
||||||
|
if len(data1) != len(data2):
|
||||||
|
return False
|
||||||
|
for k in data1:
|
||||||
|
if data1[k] != data2[k]:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
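Usage sketch for UnifiedDataset (illustrative, not part of the diff; the dataset root, metadata file, and column names below are hypothetical):

# metadata.csv is assumed to have columns "image" and "prompt".
dataset = UnifiedDataset(
    base_path="data/",
    metadata_path="data/metadata.csv",
    data_file_keys=("image",),
    main_data_operator=UnifiedDataset.default_image_operator(base_path="data/", height=512, width=512),
)
sample = dataset[0]  # {"image": cropped/resized PIL.Image, "prompt": str}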
2
diffsynth/core/device/__init__.py
Normal file
@@ -0,0 +1,2 @@
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE
107
diffsynth/core/device/npu_compatible_device.py
Normal file
@@ -0,0 +1,107 @@
import importlib
import torch
from typing import Any


def is_torch_npu_available():
    return importlib.util.find_spec("torch_npu") is not None


IS_CUDA_AVAILABLE = torch.cuda.is_available()
IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available()

if IS_NPU_AVAILABLE:
    import torch_npu

    torch.npu.config.allow_internal_format = False


def get_device_type() -> str:
    """Get the device type of the current machine; currently only CPU, CUDA, and NPU are supported."""
    if IS_CUDA_AVAILABLE:
        device = "cuda"
    elif IS_NPU_AVAILABLE:
        device = "npu"
    else:
        device = "cpu"
    return device


def get_torch_device() -> Any:
    """Get the torch namespace for the current device type, e.g. torch.cuda or torch.npu."""
    device_name = get_device_type()
    try:
        return getattr(torch, device_name)
    except AttributeError:
        print(f"Device namespace '{device_name}' not found in torch; falling back to 'torch.cuda'.")
        return torch.cuda


def get_device_id() -> int:
    """Get the current device id based on the device type."""
    return get_torch_device().current_device()


def get_device_name() -> str:
    """Get the current device name based on the device type."""
    return f"{get_device_type()}:{get_device_id()}"


def synchronize() -> None:
    """Execute the torch synchronize operation."""
    get_torch_device().synchronize()


def empty_cache() -> None:
    """Execute the torch empty-cache operation."""
    get_torch_device().empty_cache()


def get_nccl_backend() -> str:
    """Return the distributed communication backend for the current device type."""
    if IS_CUDA_AVAILABLE:
        return "nccl"
    elif IS_NPU_AVAILABLE:
        return "hccl"
    else:
        raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")


def enable_high_precision_for_bf16():
    """Use high-precision accumulation dtypes for matmul and reduction."""
    if IS_CUDA_AVAILABLE:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
    if IS_NPU_AVAILABLE:
        torch.npu.matmul.allow_tf32 = False
        torch.npu.matmul.allow_bf16_reduced_precision_reduction = False


def parse_device_type(device):
    if isinstance(device, str):
        if device.startswith("cuda"):
            return "cuda"
        elif device.startswith("npu"):
            return "npu"
        else:
            return "cpu"
    elif isinstance(device, torch.device):
        return device.type


def parse_nccl_backend(device_type):
    if device_type == "cuda":
        return "nccl"
    elif device_type == "npu":
        return "hccl"
    else:
        raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.")


def get_available_device_type():
    return get_device_type()
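Usage sketch (illustrative; assumes a CUDA or NPU machine, since parse_nccl_backend raises on CPU):

import torch.distributed as dist

device_type = get_available_device_type()  # "cuda", "npu", or "cpu"
backend = parse_nccl_backend(device_type)  # "nccl" on CUDA, "hccl" on NPU
dist.init_process_group(backend=backend)
print(get_device_name())                   # e.g. "cuda:0"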
1
diffsynth/core/gradient/__init__.py
Normal file
@@ -0,0 +1 @@
from .gradient_checkpoint import gradient_checkpoint_forward
34
diffsynth/core/gradient/gradient_checkpoint.py
Normal file
@@ -0,0 +1,34 @@
import torch


def create_custom_forward(module):
    def custom_forward(*inputs, **kwargs):
        return module(*inputs, **kwargs)
    return custom_forward


def gradient_checkpoint_forward(
    model,
    use_gradient_checkpointing,
    use_gradient_checkpointing_offload,
    *args,
    **kwargs,
):
    if use_gradient_checkpointing_offload:
        with torch.autograd.graph.save_on_cpu():
            model_output = torch.utils.checkpoint.checkpoint(
                create_custom_forward(model),
                *args,
                **kwargs,
                use_reentrant=False,
            )
    elif use_gradient_checkpointing:
        model_output = torch.utils.checkpoint.checkpoint(
            create_custom_forward(model),
            *args,
            **kwargs,
            use_reentrant=False,
        )
    else:
        model_output = model(*args, **kwargs)
    return model_output
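Usage sketch (illustrative): activations are recomputed in the backward pass instead of stored, and with offload enabled the checkpointed tensors are additionally kept in CPU memory.

import torch

block = torch.nn.Linear(64, 64)
x = torch.randn(2, 64, requires_grad=True)
y = gradient_checkpoint_forward(block, True, False, x)  # checkpointing on, offload off
y.sum().backward()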
3
diffsynth/core/loader/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .file import load_state_dict, hash_state_dict_keys, hash_model_file
from .model import load_model, load_model_with_disk_offload
from .config import ModelConfig
119
diffsynth/core/loader/config.py
Normal file
@@ -0,0 +1,119 @@
import torch, glob, os
from typing import Optional, Union, Dict
from dataclasses import dataclass
from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download


@dataclass
class ModelConfig:
    path: Union[str, list[str]] = None
    model_id: str = None
    origin_file_pattern: Union[str, list[str]] = None
    download_source: str = None
    local_model_path: str = None
    skip_download: bool = None
    offload_device: Optional[Union[str, torch.device]] = None
    offload_dtype: Optional[torch.dtype] = None
    onload_device: Optional[Union[str, torch.device]] = None
    onload_dtype: Optional[torch.dtype] = None
    preparing_device: Optional[Union[str, torch.device]] = None
    preparing_dtype: Optional[torch.dtype] = None
    computation_device: Optional[Union[str, torch.device]] = None
    computation_dtype: Optional[torch.dtype] = None
    clear_parameters: bool = False
    state_dict: Dict[str, torch.Tensor] = None

    def check_input(self):
        if self.path is None and self.model_id is None:
            raise ValueError("""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")

    def parse_original_file_pattern(self):
        if self.origin_file_pattern in [None, "", "./"]:
            return "*"
        elif self.origin_file_pattern.endswith("/"):
            return self.origin_file_pattern + "*"
        else:
            return self.origin_file_pattern

    def parse_download_source(self):
        if self.download_source is None:
            if os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') is not None:
                return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE')
            else:
                return "modelscope"
        else:
            return self.download_source

    def parse_skip_download(self):
        if self.skip_download is None:
            if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None:
                if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true":
                    return True
                elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false":
                    return False
            else:
                return False
        else:
            return self.skip_download

    def download(self):
        origin_file_pattern = self.parse_original_file_pattern()
        downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
        download_source = self.parse_download_source()
        if download_source.lower() == "modelscope":
            snapshot_download(
                self.model_id,
                local_dir=os.path.join(self.local_model_path, self.model_id),
                allow_file_pattern=origin_file_pattern,
                ignore_file_pattern=downloaded_files,
                local_files_only=False
            )
        elif download_source.lower() == "huggingface":
            hf_snapshot_download(
                self.model_id,
                local_dir=os.path.join(self.local_model_path, self.model_id),
                allow_patterns=origin_file_pattern,
                ignore_patterns=downloaded_files,
                local_files_only=False
            )
        else:
            raise ValueError("`download_source` should be `modelscope` or `huggingface`.")

    def require_downloading(self):
        if self.path is not None:
            return False
        skip_download = self.parse_skip_download()
        return not skip_download

    def reset_local_model_path(self):
        if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:
            self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH')
        elif self.local_model_path is None:
            self.local_model_path = "./models"

    def download_if_necessary(self):
        self.check_input()
        self.reset_local_model_path()
        if self.require_downloading():
            self.download()
        if self.path is None:
            if self.origin_file_pattern in [None, "", "./"]:
                self.path = os.path.join(self.local_model_path, self.model_id)
            else:
                self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
                if isinstance(self.path, list) and len(self.path) == 1:
                    self.path = self.path[0]

    def vram_config(self):
        return {
            "offload_device": self.offload_device,
            "offload_dtype": self.offload_dtype,
            "onload_device": self.onload_device,
            "onload_dtype": self.onload_dtype,
            "preparing_device": self.preparing_device,
            "preparing_dtype": self.preparing_dtype,
            "computation_device": self.computation_device,
            "computation_dtype": self.computation_dtype,
        }
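Usage sketch (illustrative; the model id and file pattern are hypothetical):

config = ModelConfig(model_id="org/model", origin_file_pattern="transformer/")
config.download_if_necessary()  # downloads into ./models by default and resolves config.path
print(config.path)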
130
diffsynth/core/loader/file.py
Normal file
@@ -0,0 +1,130 @@
from safetensors import safe_open
import torch, hashlib


def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
    if isinstance(file_path, list):
        state_dict = {}
        for file_path_ in file_path:
            state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
    else:
        if verbose >= 1:
            print(f"Loading file [started]: {file_path}")
        if file_path.endswith(".safetensors"):
            state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
        else:
            state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
        # If the state dict is loaded into CPU memory, `pin_memory=True` makes a later `model.to("cuda")` faster.
        if pin_memory:
            for i in state_dict:
                state_dict[i] = state_dict[i].pin_memory()
        if verbose >= 1:
            print(f"Loading file [done]: {file_path}")
    return state_dict


def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
    state_dict = {}
    with safe_open(file_path, framework="pt", device=str(device)) as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)
            if torch_dtype is not None:
                state_dict[k] = state_dict[k].to(torch_dtype)
    return state_dict


def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
    state_dict = torch.load(file_path, map_location=device, weights_only=True)
    if len(state_dict) == 1:
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        elif "module" in state_dict:
            state_dict = state_dict["module"]
        elif "model_state" in state_dict:
            state_dict = state_dict["model_state"]
    if torch_dtype is not None:
        for i in state_dict:
            if isinstance(state_dict[i], torch.Tensor):
                state_dict[i] = state_dict[i].to(torch_dtype)
    return state_dict


def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    keys = []
    for key, value in state_dict.items():
        if isinstance(key, str):
            if isinstance(value, torch.Tensor):
                if with_shape:
                    shape = "_".join(map(str, list(value.shape)))
                    keys.append(key + ":" + shape)
                keys.append(key)
            elif isinstance(value, dict):
                keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
    keys.sort()
    keys_str = ",".join(keys)
    return keys_str


def hash_state_dict_keys(state_dict, with_shape=True):
    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    keys_str = keys_str.encode(encoding="UTF-8")
    return hashlib.md5(keys_str).hexdigest()


def load_keys_dict(file_path):
    if isinstance(file_path, list):
        state_dict = {}
        for file_path_ in file_path:
            state_dict.update(load_keys_dict(file_path_))
        return state_dict
    if file_path.endswith(".safetensors"):
        return load_keys_dict_from_safetensors(file_path)
    else:
        return load_keys_dict_from_bin(file_path)


def load_keys_dict_from_safetensors(file_path):
    keys_dict = {}
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for k in f.keys():
            keys_dict[k] = f.get_slice(k).get_shape()
    return keys_dict


def convert_state_dict_to_keys_dict(state_dict):
    keys_dict = {}
    for k, v in state_dict.items():
        if isinstance(v, torch.Tensor):
            keys_dict[k] = list(v.shape)
        else:
            keys_dict[k] = convert_state_dict_to_keys_dict(v)
    return keys_dict


def load_keys_dict_from_bin(file_path):
    state_dict = load_state_dict_from_bin(file_path)
    keys_dict = convert_state_dict_to_keys_dict(state_dict)
    return keys_dict


def convert_keys_dict_to_single_str(state_dict, with_shape=True):
    keys = []
    for key, value in state_dict.items():
        if isinstance(key, str):
            if isinstance(value, dict):
                keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape))
            else:
                if with_shape:
                    shape = "_".join(map(str, list(value)))
                    keys.append(key + ":" + shape)
                keys.append(key)
    keys.sort()
    keys_str = ",".join(keys)
    return keys_str


def hash_model_file(path, with_shape=True):
    keys_dict = load_keys_dict(path)
    keys_str = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape)
    keys_str = keys_str.encode(encoding="UTF-8")
    return hashlib.md5(keys_str).hexdigest()
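Usage sketch (illustrative; "model.safetensors" is a hypothetical file). Both hashes are md5 digests over the sorted "key:shape" strings, so hash_model_file can fingerprint a safetensors checkpoint from shape metadata alone, without loading the tensor data:

state_dict = load_state_dict("model.safetensors")
print(hash_state_dict_keys(state_dict))
print(hash_model_file("model.safetensors"))  # same fingerprint; shapes read via get_slice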
105
diffsynth/core/loader/model.py
Normal file
@@ -0,0 +1,105 @@
from ..vram.initialization import skip_model_initialization
from ..vram.disk_map import DiskMap
from ..vram.layers import enable_vram_management
from .file import load_state_dict
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import ContextManagers


def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
    config = {} if config is None else config
    with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
        model = model_class(**config)
    # What is `module_map`?
    # It is a module mapping table for VRAM management.
    if module_map is not None:
        devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
        device = [d for d in devices if d != "disk"][0]
        dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
        dtype = [d for d in dtypes if d != "disk"][0]
        if vram_config["offload_device"] != "disk":
            if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
            if state_dict_converter is not None:
                state_dict = state_dict_converter(state_dict)
            else:
                state_dict = {i: state_dict[i] for i in state_dict}
            model.load_state_dict(state_dict, assign=True)
            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
        else:
            disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
    else:
        # Why do we use `DiskMap`?
        # Sometimes a model file contains multiple models,
        # and DiskMap can load only the parameters of a single model,
        # avoiding the need to load all parameters in the file.
        if state_dict is not None:
            pass
        elif use_disk_map:
            state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
        else:
            state_dict = load_state_dict(path, torch_dtype, device)
        # Why do we use `state_dict_converter`?
        # Some models are saved in complex formats,
        # and we need to convert the state dict into the appropriate format.
        if state_dict_converter is not None:
            state_dict = state_dict_converter(state_dict)
        else:
            state_dict = {i: state_dict[i] for i in state_dict}
        # Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
        # Because at this stage, model parameters are partitioned across multiple GPUs.
        # Loading them directly could lead to excessive GPU memory consumption.
        if is_deepspeed_zero3_enabled():
            from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
            _load_state_dict_into_zero3_model(model, state_dict)
        else:
            model.load_state_dict(state_dict, assign=True)
        # Why do we call `to()`?
        # Because some models override the behavior of `to()`,
        # especially those from libraries like Transformers.
        model = model.to(dtype=torch_dtype, device=device)
    if hasattr(model, "eval"):
        model = model.eval()
    return model


def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
    if isinstance(path, str):
        path = [path]
    config = {} if config is None else config
    with skip_model_initialization():
        model = model_class(**config)
    if hasattr(model, "eval"):
        model = model.eval()
    disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
    vram_config = {
        "offload_dtype": "disk",
        "offload_device": "disk",
        "onload_dtype": "disk",
        "onload_device": "disk",
        "preparing_dtype": torch.float8_e4m3fn,
        "preparing_device": device,
        "computation_dtype": torch_dtype,
        "computation_device": device,
    }
    enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
    return model


def get_init_context(torch_dtype, device):
    if is_deepspeed_zero3_enabled():
        from transformers.modeling_utils import set_zero3_state
        import deepspeed
        # Why do we use `deepspeed.zero.Init`?
        # It lets weight partitioning happen on the CPU side,
        # and the partitioned weights are then loaded onto the accelerator.
        init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
    else:
        # Why do we use `skip_model_initialization`?
        # It skips the random initialization of model parameters,
        # thereby speeding up model loading and avoiding excessive memory usage.
        init_contexts = [skip_model_initialization()]

    return init_contexts
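Usage sketch (illustrative; MyModel stands for any torch.nn.Module class and the checkpoint path is hypothetical):

model = load_model(
    MyModel, "model.safetensors",
    torch_dtype=torch.bfloat16, device="cuda",
)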
30
diffsynth/core/npu_patch/npu_fused_operator.py
Normal file
@@ -0,0 +1,30 @@
import torch
from ..device.npu_compatible_device import get_device_type
try:
    import torch_npu
except ImportError:
    pass


def rms_norm_forward_npu(self, hidden_states):
    "NPU fused RMS-norm operator for RMSNorm.forward from diffsynth/models/general_modules.py"
    if hidden_states.dtype != self.weight.dtype:
        hidden_states = hidden_states.to(self.weight.dtype)
    return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0]


def rms_norm_forward_transformers_npu(self, hidden_states):
    "NPU fused RMS-norm operator for transformers"
    if hidden_states.dtype != self.weight.dtype:
        hidden_states = hidden_states.to(self.weight.dtype)
    return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]


def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor):
    "NPU fused RoPE operator for Zimage"
    with torch.amp.autocast(get_device_type(), enabled=False):
        freqs_cis = freqs_cis.unsqueeze(2)
        cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1)
        cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)
        sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)
        return torch_npu.npu_rotary_mul(x_in, cos, sin, rotary_mode="interleave").to(x_in)
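These functions are written as drop-in replacements for the corresponding forward methods; a patch sketch (illustrative, requires an NPU build; the import location is taken from the docstring above):

from diffsynth.models.general_modules import RMSNorm

RMSNorm.forward = rms_norm_forward_npu  # route RMS norm through torch_npu.npu_rms_norm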
2
diffsynth/core/vram/__init__.py
Normal file
@@ -0,0 +1,2 @@
from .initialization import skip_model_initialization
from .layers import *
93
diffsynth/core/vram/disk_map.py
Normal file
@@ -0,0 +1,93 @@
from safetensors import safe_open
import torch, os


class SafetensorsCompatibleTensor:
    def __init__(self, tensor):
        self.tensor = tensor

    def get_shape(self):
        return list(self.tensor.shape)


class SafetensorsCompatibleBinaryLoader:
    def __init__(self, path, device):
        print("Detected non-safetensors files, which may cause slower loading. It's recommended to convert them to safetensors files.")
        self.state_dict = torch.load(path, weights_only=True, map_location=device)

    def keys(self):
        return self.state_dict.keys()

    def get_tensor(self, name):
        return self.state_dict[name]

    def get_slice(self, name):
        return SafetensorsCompatibleTensor(self.state_dict[name])


class DiskMap:

    def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
        self.path = path if isinstance(path, list) else [path]
        self.device = device
        self.torch_dtype = torch_dtype
        if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None:
            self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE'))
        else:
            self.buffer_size = buffer_size
        self.files = []
        self.flush_files()
        self.name_map = {}
        for file_id, file in enumerate(self.files):
            for name in file.keys():
                self.name_map[name] = file_id
        self.rename_dict = self.fetch_rename_dict(state_dict_converter)

    def flush_files(self):
        if len(self.files) == 0:
            for path in self.path:
                if path.endswith(".safetensors"):
                    self.files.append(safe_open(path, framework="pt", device=str(self.device)))
                else:
                    self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
        else:
            for i, path in enumerate(self.path):
                if path.endswith(".safetensors"):
                    self.files[i] = safe_open(path, framework="pt", device=str(self.device))
        self.num_params = 0

    def __getitem__(self, name):
        if self.rename_dict is not None: name = self.rename_dict[name]
        file_id = self.name_map[name]
        param = self.files[file_id].get_tensor(name)
        if self.torch_dtype is not None and isinstance(param, torch.Tensor):
            param = param.to(self.torch_dtype)
        if isinstance(param, torch.Tensor) and param.device.type == "cpu":
            param = param.clone()
        if isinstance(param, torch.Tensor):
            self.num_params += param.numel()
            if self.num_params > self.buffer_size:
                self.flush_files()
        return param

    def fetch_rename_dict(self, state_dict_converter):
        if state_dict_converter is None:
            return None
        state_dict = {}
        for file in self.files:
            for name in file.keys():
                state_dict[name] = name
        state_dict = state_dict_converter(state_dict)
        return state_dict

    def __iter__(self):
        if self.rename_dict is not None:
            return self.rename_dict.__iter__()
        else:
            return self.name_map.__iter__()

    def __contains__(self, x):
        if self.rename_dict is not None:
            return x in self.rename_dict
        else:
            return x in self.name_map
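Usage sketch (illustrative; the checkpoint path is hypothetical). DiskMap behaves like a lazy state dict: names are known up front, tensors are read from disk only on access, and the file handles are reopened once the read budget is exceeded:

state_dict = DiskMap("model.safetensors", device="cpu", torch_dtype=torch.bfloat16)
for name in state_dict:
    tensor = state_dict[name]  # loads only this tensor
    break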
21
diffsynth/core/vram/initialization.py
Normal file
@@ -0,0 +1,21 @@
import torch
from contextlib import contextmanager


@contextmanager
def skip_model_initialization(device=torch.device("meta")):

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    old_register_parameter = torch.nn.Module.register_parameter
    torch.nn.Module.register_parameter = register_empty_parameter
    try:
        yield
    finally:
        torch.nn.Module.register_parameter = old_register_parameter
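Usage sketch (illustrative): inside the context, newly registered parameters are moved to the meta device, so construction allocates no memory and skips random initialization; real weights are attached later via load_state_dict(..., assign=True).

with skip_model_initialization():
    model = torch.nn.Linear(4096, 4096)
print(model.weight.device)  # meta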
479
diffsynth/core/vram/layers.py
Normal file
@@ -0,0 +1,479 @@
import torch, copy
from typing import Union
from .initialization import skip_model_initialization
from .disk_map import DiskMap
from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE


class AutoTorchModule(torch.nn.Module):

    def __init__(
        self,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
    ):
        super().__init__()
        self.set_dtype_and_device(
            offload_dtype,
            offload_device,
            onload_dtype,
            onload_device,
            preparing_dtype,
            preparing_device,
            computation_dtype,
            computation_device,
            vram_limit,
        )
        self.state = 0
        self.name = ""
        self.computation_device_type = parse_device_type(self.computation_device)

    def set_dtype_and_device(
        self,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
    ):
        # Devices default to the computation device, dtypes to the computation dtype.
        self.offload_dtype = offload_dtype or computation_dtype
        self.offload_device = offload_device or computation_device
        self.onload_dtype = onload_dtype or computation_dtype
        self.onload_device = onload_device or computation_device
        self.preparing_dtype = preparing_dtype or computation_dtype
        self.preparing_device = preparing_device or computation_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.vram_limit = vram_limit

    def cast_to(self, weight, dtype, device):
        r = torch.empty_like(weight, dtype=dtype, device=device)
        r.copy_(weight)
        return r

    def check_free_vram(self):
        device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
        gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
        used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
        return used_memory < self.vram_limit

    def offload(self):
        if self.state != 0:
            self.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        if self.state != 1:
            self.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def param_name(self, name):
        if self.name == "":
            return name
        else:
            return self.name + "." + name


class AutoWrappedModule(AutoTorchModule):

    def __init__(
        self,
        module: torch.nn.Module,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
        name: str = "",
        disk_map: DiskMap = None,
        **kwargs
    ):
        super().__init__(
            offload_dtype,
            offload_device,
            onload_dtype,
            onload_device,
            preparing_dtype,
            preparing_device,
            computation_dtype,
            computation_device,
            vram_limit,
        )
        self.module = module
        if offload_dtype == "disk":
            self.name = name
            self.disk_map = disk_map
            self.required_params = [name for name, _ in self.module.named_parameters()]
            self.disk_offload = True
        else:
            self.disk_offload = False

    def load_from_disk(self, torch_dtype, device, copy_module=False):
        if copy_module:
            module = copy.deepcopy(self.module)
        else:
            module = self.module
        state_dict = {}
        for name in self.required_params:
            param = self.disk_map[self.param_name(name)]
            param = param.to(dtype=torch_dtype, device=device)
            state_dict[name] = param
        module.load_state_dict(state_dict, assign=True)
        module.to(dtype=torch_dtype, device=device)
        return module

    def offload_to_disk(self, model: torch.nn.Module):
        for buf in model.buffers():
            # If some parameters are registered as buffers (not in the state dict),
            # we cannot offload this module directly; recurse into its children instead.
            for children in model.children():
                self.offload_to_disk(children)
            break
        else:
            model.to("meta")

    def offload(self):
        # offload / onload / preparing -> offload
        if self.state != 0:
            if self.disk_offload:
                self.offload_to_disk(self.module)
            else:
                self.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        # offload / onload / preparing -> onload
        if self.state < 1:
            if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
                self.load_from_disk(self.onload_dtype, self.onload_device)
            elif self.onload_device != "disk":
                self.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def preparing(self):
        # onload / preparing -> preparing
        if self.state != 2:
            if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
                self.load_from_disk(self.preparing_dtype, self.preparing_device)
            elif self.preparing_device != "disk":
                self.to(dtype=self.preparing_dtype, device=self.preparing_device)
            self.state = 2

    def cast_to(self, module, dtype, device):
        return copy.deepcopy(module).to(dtype=dtype, device=device)

    def computation(self):
        # onload / preparing -> computation (temporary)
        if self.state == 2:
            torch_dtype, device = self.preparing_dtype, self.preparing_device
        else:
            torch_dtype, device = self.onload_dtype, self.onload_device
        if torch_dtype == self.computation_dtype and device == self.computation_device:
            module = self.module
        elif self.disk_offload and device == "disk":
            module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
        else:
            module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)
        return module

    def forward(self, *args, **kwargs):
        if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
            self.preparing()
        module = self.computation()
        return module(*args, **kwargs)

    def __getattr__(self, name):
        if name in self.__dict__ or name == "module":
            return super().__getattr__(name)
        else:
            return getattr(self.module, name)


class AutoWrappedNonRecurseModule(AutoWrappedModule):

    def __init__(
        self,
        module: torch.nn.Module,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
        name: str = "",
        disk_map: DiskMap = None,
        **kwargs
    ):
        super().__init__(
            module,
            offload_dtype,
            offload_device,
            onload_dtype,
            onload_device,
            preparing_dtype,
            preparing_device,
            computation_dtype,
            computation_device,
            vram_limit,
            name,
            disk_map,
            **kwargs
        )
        if self.disk_offload:
            self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)]

    def load_from_disk(self, torch_dtype, device, copy_module=False):
        if copy_module:
            module = copy.deepcopy(self.module)
        else:
            module = self.module
        state_dict = {}
        for name in self.required_params:
            param = self.disk_map[self.param_name(name)]
            param = param.to(dtype=torch_dtype, device=device)
            state_dict[name] = param
        module.load_state_dict(state_dict, assign=True, strict=False)
        return module

    def offload_to_disk(self, model: torch.nn.Module):
        for name in self.required_params:
            getattr(self, name).to("meta")

    def cast_to(self, module, dtype, device):
        # Parameter casting is implemented in the model architecture.
        return module

    def __getattr__(self, name):
        if name in self.__dict__ or name == "module":
            return super().__getattr__(name)
        else:
            return getattr(self.module, name)


class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
    def __init__(
        self,
        module: torch.nn.Linear,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
        name: str = "",
        disk_map: DiskMap = None,
        **kwargs
    ):
        with skip_model_initialization():
            super().__init__(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
            )
        self.set_dtype_and_device(
            offload_dtype,
            offload_device,
            onload_dtype,
            onload_device,
            preparing_dtype,
            preparing_device,
            computation_dtype,
            computation_device,
            vram_limit,
        )
        self.weight = module.weight
        self.bias = module.bias
        self.state = 0
        self.name = name
        self.lora_A_weights = []
        self.lora_B_weights = []
        self.lora_merger = None
        self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
        self.computation_device_type = parse_device_type(self.computation_device)

        if offload_dtype == "disk":
            self.disk_map = disk_map
            self.disk_offload = True
        else:
            self.disk_offload = False

    def fp8_linear(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor = None,
    ) -> torch.Tensor:
        device = input.device
        origin_dtype = input.dtype
        origin_shape = input.shape
        input = input.reshape(-1, origin_shape[-1])

        x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
        fp8_max = 448.0
        # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
        # To avoid overflow and ensure numerical compatibility during FP8 computation,
        # we scale down the input by 2.0 in advance.
        # This scaling will be compensated later during the final result scaling.
        if self.computation_dtype == torch.float8_e4m3fnuz:
            fp8_max = fp8_max / 2.0
        scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
        scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
        input = input / (scale_a + 1e-8)
        input = input.to(self.computation_dtype)
        weight = weight.to(self.computation_dtype)
        if bias is not None:  # guard: Linear layers without a bias pass bias=None
            bias = bias.to(torch.bfloat16)

        result = torch._scaled_mm(
            input,
            weight.T,
            scale_a=scale_a,
            scale_b=scale_b.T,
            bias=bias,
            out_dtype=origin_dtype,
        )
        new_shape = origin_shape[:-1] + result.shape[-1:]
        result = result.reshape(new_shape)
        return result

    def load_from_disk(self, torch_dtype, device, assign=True):
        weight = self.disk_map[self.name + ".weight"].to(dtype=torch_dtype, device=device)
        bias = None if self.bias is None else self.disk_map[self.name + ".bias"].to(dtype=torch_dtype, device=device)
        if assign:
            state_dict = {"weight": weight}
            if bias is not None: state_dict["bias"] = bias
            self.load_state_dict(state_dict, assign=True)
        return weight, bias

    def offload(self):
        # offload / onload / preparing -> offload
        if self.state != 0:
            if self.disk_offload:
                self.to("meta")
            else:
                self.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        # offload / onload / preparing -> onload
        if self.state < 1:
            if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
                self.load_from_disk(self.onload_dtype, self.onload_device)
            elif self.onload_device != "disk":
                self.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def preparing(self):
        # onload / preparing -> preparing
        if self.state != 2:
            if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
                self.load_from_disk(self.preparing_dtype, self.preparing_device)
            elif self.preparing_device != "disk":
                self.to(dtype=self.preparing_dtype, device=self.preparing_device)
            self.state = 2

    def computation(self):
        # onload / preparing -> computation (temporary)
        if self.state == 2:
            torch_dtype, device = self.preparing_dtype, self.preparing_device
        else:
            torch_dtype, device = self.onload_dtype, self.onload_device
        if torch_dtype == self.computation_dtype and device == self.computation_device:
            weight, bias = self.weight, self.bias
        elif self.disk_offload and device == "disk":
            weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False)
        else:
            weight = self.cast_to(self.weight, self.computation_dtype, self.computation_device)
            bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device)
        return weight, bias

    def linear_forward(self, x, weight, bias):
        if self.enable_fp8:
            out = self.fp8_linear(x, weight, bias)
        else:
            out = torch.nn.functional.linear(x, weight, bias)
        return out

    def lora_forward(self, x, out):
        if self.lora_merger is None:
            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
                out = out + x @ lora_A.T @ lora_B.T
        else:
            lora_output = []
            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
                lora_output.append(x @ lora_A.T @ lora_B.T)
            lora_output = torch.stack(lora_output)
            out = self.lora_merger(out, lora_output)
        return out

    def forward(self, x, *args, **kwargs):
        if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
            self.preparing()
        weight, bias = self.computation()
        out = self.linear_forward(x, weight, bias)
        if len(self.lora_A_weights) > 0:
            out = self.lora_forward(x, out)
        return out


def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs):
    if isinstance(model, AutoWrappedNonRecurseModule):
        model = model.module
    for name, module in model.named_children():
        layer_name = name if name_prefix == "" else name_prefix + "." + name
        for source_module, target_module in module_map.items():
            if isinstance(module, source_module):
                module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs)
                if isinstance(module_, AutoWrappedNonRecurseModule):
                    enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
                setattr(model, name, module_)
                break
        else:
            enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)


def fill_vram_config(model, vram_config):
    vram_config_ = vram_config.copy()
    vram_config_["onload_dtype"] = vram_config["computation_dtype"]
    vram_config_["onload_device"] = vram_config["computation_device"]
    vram_config_["preparing_dtype"] = vram_config["computation_dtype"]
    vram_config_["preparing_device"] = vram_config["computation_device"]
    for k in vram_config:
        if vram_config[k] != vram_config_[k]:
            print(f"No fine-grained VRAM configuration is provided for {model.__class__.__name__}. [`onload`, `preparing`, `computation`] will be the same state. `vram_config` is set to {vram_config_}")
            break
    return vram_config_


def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs):
    for source_module, target_module in module_map.items():
        # If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly.
        if isinstance(model, source_module):
            vram_config = fill_vram_config(model, vram_config)
            model = target_module(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
            break
    else:
        enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
    # `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled.
    model.vram_management_enabled = True
    return model
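Usage sketch (illustrative): wrap every Linear layer so weights rest in CPU RAM and are cast to the GPU only for the actual matmul.

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
vram_config = {
    "offload_dtype": torch.bfloat16, "offload_device": "cpu",
    "onload_dtype": torch.bfloat16, "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16, "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16, "computation_device": "cuda",
}
model = enable_vram_management(model, {torch.nn.Linear: AutoWrappedLinear}, vram_config)
for layer in model.modules():
    if isinstance(layer, AutoTorchModule):
        layer.onload()  # offload() later returns the layer to the offload device/dtype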
@@ -1 +0,0 @@
from .video import VideoData, save_video, save_frames
@@ -1,35 +0,0 @@
import torch, os
from torchvision import transforms
import pandas as pd
from PIL import Image


class TextImageDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
        self.steps_per_epoch = steps_per_epoch
        metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
        self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()
        self.image_processor = transforms.Compose(
            [
                transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
                transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __getitem__(self, index):
        data_id = torch.randint(0, len(self.path), (1,))[0]
        data_id = (data_id + index) % len(self.path)  # For fixed seed.
        text = self.text[data_id]
        image = Image.open(self.path[data_id]).convert("RGB")
        image = self.image_processor(image)
        return {"text": text, "image": image}

    def __len__(self):
        return self.steps_per_epoch
6
diffsynth/diffusion/__init__.py
Normal file
@@ -0,0 +1,6 @@
from .flow_match import FlowMatchScheduler
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger
from .runner import launch_training_task, launch_data_process_task
from .parsers import *
from .loss import *
459
diffsynth/diffusion/base_pipeline.py
Normal file
@@ -0,0 +1,459 @@
from PIL import Image
import torch
import numpy as np
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput
from ..core.device import get_device_name, IS_NPU_AVAILABLE


class PipelineUnit:
    def __init__(
        self,
        seperate_cfg: bool = False,
        take_over: bool = False,
        input_params: tuple[str] = None,
        output_params: tuple[str] = None,
        input_params_posi: dict[str, str] = None,
        input_params_nega: dict[str, str] = None,
        onload_model_names: tuple[str] = None
    ):
        self.seperate_cfg = seperate_cfg
        self.take_over = take_over
        self.input_params = input_params
        self.output_params = output_params
        self.input_params_posi = input_params_posi
        self.input_params_nega = input_params_nega
        self.onload_model_names = onload_model_names

    def fetch_input_params(self):
        params = []
        if self.input_params is not None:
            for param in self.input_params:
                params.append(param)
        if self.input_params_posi is not None:
            for _, param in self.input_params_posi.items():
                params.append(param)
        if self.input_params_nega is not None:
            for _, param in self.input_params_nega.items():
                params.append(param)
        params = sorted(list(set(params)))
        return params

    def fetch_output_params(self):
        params = []
        if self.output_params is not None:
            for param in self.output_params:
                params.append(param)
        return params

    def process(self, pipe, **kwargs) -> dict:
        return {}

    def post_process(self, pipe, **kwargs) -> dict:
        return {}


class BasePipeline(torch.nn.Module):

    def __init__(
        self,
        device=get_device_type(), torch_dtype=torch.float16,
        height_division_factor=64, width_division_factor=64,
        time_division_factor=None, time_division_remainder=None,
    ):
        super().__init__()
        # The device and torch_dtype are used for the storage of intermediate variables, not models.
        self.device = device
        self.torch_dtype = torch_dtype
        self.device_type = parse_device_type(device)
        # The following parameters are used for shape checks.
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # VRAM management
        self.vram_management_enabled = False
        # Pipeline Unit Runner
        self.unit_runner = PipelineUnitRunner()
        # LoRA Loader
        self.lora_loader = GeneralLoRALoader

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def check_resize_height_width(self, height, width, num_frames=None):
        # Shape check
        if height % self.height_division_factor != 0:
            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
            print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
        if width % self.width_division_factor != 0:
            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
            print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
        if num_frames is None:
            return height, width
        else:
            if num_frames % self.time_division_factor != self.time_division_remainder:
                num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
                print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
            return height, width, num_frames

    def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
        # Transform a PIL.Image to torch.Tensor
        image = torch.Tensor(np.array(image, dtype=np.float32))
        image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        image = image * ((max_value - min_value) / 255) + min_value
        image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
        return image

    def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
        # Transform a list of PIL.Image to torch.Tensor
        video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
        video = torch.stack(video, dim=pattern.index("T") // 2)
        return video

    def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
        # Transform a torch.Tensor to PIL.Image
        if pattern != "H W C":
            vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
        image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
        image = image.to(device="cpu", dtype=torch.uint8)
        image = Image.fromarray(image.numpy())
        return image

    def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
        # Transform a torch.Tensor to a list of PIL.Image
        if pattern != "T H W C":
            vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
        video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
        return video

    def load_models_to_device(self, model_names):
        if self.vram_management_enabled:
            # offload models
            for name, model in self.named_children():
                if name not in model_names:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        if hasattr(model, "offload"):
                            model.offload()
                        else:
                            for module in model.modules():
                                if hasattr(module, "offload"):
                                    module.offload()
            getattr(torch, self.device_type).empty_cache()
            # onload models
            for name, model in self.named_children():
                if name in model_names:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        if hasattr(model, "onload"):
                            model.onload()
                        else:
                            for module in model.modules():
                                if hasattr(module, "onload"):
                                    module.onload()

    def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
        # Initialize Gaussian noise
        generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
        noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        return noise

    def get_vram(self):
        device = self.device if not IS_NPU_AVAILABLE else get_device_name()
        return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)

    def get_module(self, model, name):
        if "." in name:
            name, suffix = name[:name.index(".")], name[name.index(".") + 1:]
            if name.isdigit():
                return self.get_module(model[int(name)], suffix)
            else:
                return self.get_module(getattr(model, name), suffix)
        else:
            return getattr(model, name)

    def freeze_except(self, model_names):
        self.eval()
        self.requires_grad_(False)
        for name in model_names:
            module = self.get_module(self, name)
            if module is None:
                print(f"No {name} models in the pipeline. We cannot enable training on the model. If this occurs during the data processing stage, it is normal.")
                continue
            module.train()
            module.requires_grad_(True)

    def blend_with_mask(self, base, addition, mask):
        return base * (1 - mask) + addition * mask

    def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
        timestep = scheduler.timesteps[progress_id]
        if inpaint_mask is not None:
            noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
            noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
        latents_next = scheduler.step(noise_pred, timestep, latents)
        return latents_next

    def split_pipeline_units(self, model_names: list[str]):
        return PipelineUnitGraph().split_pipeline_units(self.units, model_names)

    def flush_vram_management_device(self, device):
        for module in self.modules():
            if isinstance(module, AutoTorchModule):
                module.offload_device = device
                module.onload_device = device
                module.preparing_device = device
                module.computation_device = device

    def load_lora(
        self,
        module: torch.nn.Module,
        lora_config: Union[ModelConfig, str] = None,
        alpha=1,
        hotload=None,
        state_dict=None,
        verbose=1,
    ):
        if state_dict is None:
            if isinstance(lora_config, str):
                lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
            else:
                lora_config.download_if_necessary()
                lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
        else:
            lora = state_dict
        lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
        lora = lora_loader.convert_state_dict(lora)
        if hotload is None:
            hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
        if hotload:
            if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
                raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.")
            updated_num = 0
            for _, module in module.named_modules():
                if isinstance(module, AutoWrappedLinear):
                    name = module.name
                    lora_a_name = f'{name}.lora_A.weight'
                    lora_b_name = f'{name}.lora_B.weight'
                    if lora_a_name in lora and lora_b_name in lora:
                        updated_num += 1
                        module.lora_A_weights.append(lora[lora_a_name] * alpha)
                        module.lora_B_weights.append(lora[lora_b_name])
            if verbose >= 1:
                print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
        else:
            lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)

    def clear_lora(self, verbose=1):
        cleared_num = 0
        for name, module in self.named_modules():
            if isinstance(module, AutoWrappedLinear):
                if hasattr(module, "lora_A_weights"):
                    if len(module.lora_A_weights) > 0:
                        cleared_num += 1
                    module.lora_A_weights.clear()
                if hasattr(module, "lora_B_weights"):
                    module.lora_B_weights.clear()
        if verbose >= 1:
            print(f"{cleared_num} LoRA layers are cleared.")

    def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
        model_pool = ModelPool()
        for model_config in model_configs:
            model_config.download_if_necessary()
            vram_config = model_config.vram_config()
            vram_config["computation_dtype"] = vram_config["computation_dtype"] or self.torch_dtype
            vram_config["computation_device"] = vram_config["computation_device"] or self.device
            model_pool.auto_load_model(
                model_config.path,
                vram_config=vram_config,
                vram_limit=vram_limit,
                clear_parameters=model_config.clear_parameters,
                state_dict=model_config.state_dict,
            )
        return model_pool

    def check_vram_management_state(self):
        vram_management_enabled = False
        for module in self.children():
            if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"):
                vram_management_enabled = True
        return vram_management_enabled

    def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
        if inputs_shared.get("positive_only_lora", None) is not None:
            self.clear_lora(verbose=0)
            self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
        noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
        if cfg_scale != 1.0:
            if inputs_shared.get("positive_only_lora", None) is not None:
                self.clear_lora(verbose=0)
            noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
            if isinstance(noise_pred_posi, tuple):
                # Separately handle different output types of latents, e.g. video and audio latents.
                noise_pred = tuple(
                    n_nega + cfg_scale * (n_posi - n_nega)
                    for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
                )
            else:
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
        else:
            noise_pred = noise_pred_posi
        return noise_pred


class PipelineUnitGraph:
    def __init__(self):
        pass

    def build_edges(self, units: list[PipelineUnit]):
        # Establish dependencies between units
        # to search for subsequent related computation units.
        last_compute_unit_id = {}
        edges = []
        for unit_id, unit in enumerate(units):
            for input_param in unit.fetch_input_params():
                if input_param in last_compute_unit_id:
                    edges.append((last_compute_unit_id[input_param], unit_id))
            for output_param in unit.fetch_output_params():
                last_compute_unit_id[output_param] = unit_id
        return edges

    def build_chains(self, units: list[PipelineUnit]):
        # Establish updating chains for each variable
        # to track their computation process.
        params = sum([unit.fetch_input_params() + unit.fetch_output_params() for unit in units], [])
        params = sorted(list(set(params)))
        chains = {param: [] for param in params}
        for unit_id, unit in enumerate(units):
            for param in unit.fetch_output_params():
                chains[param].append(unit_id)
        return chains

    def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]):
        # Search for units that directly participate in the model's computation.
        related_unit_ids = []
        for unit_id, unit in enumerate(units):
            for model_name in model_names:
                if unit.onload_model_names is not None and model_name in unit.onload_model_names:
                    related_unit_ids.append(unit_id)
                    break
        return related_unit_ids

    def search_related_unit_ids(self, edges, start_unit_ids, direction="target"):
        # Search for subsequent related computation units.
        related_unit_ids = [unit_id for unit_id in start_unit_ids]
        while True:
            neighbors = []
            for source, target in edges:
                if direction == "target" and source in related_unit_ids and target not in related_unit_ids:
                    neighbors.append(target)
                elif direction == "source" and source not in related_unit_ids and target in related_unit_ids:
                    neighbors.append(source)
            neighbors = sorted(list(set(neighbors)))
            if len(neighbors) == 0:
                break
            else:
                related_unit_ids.extend(neighbors)
        related_unit_ids = sorted(list(set(related_unit_ids)))
        return related_unit_ids

    def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids):
        # If the input parameters of this subgraph are updated outside the subgraph,
        # search for the units where these updates occur.
        first_compute_unit_id = {}
        for unit_id in related_unit_ids:
            for param in units[unit_id].fetch_input_params():
                if param not in first_compute_unit_id:
                    first_compute_unit_id[param] = unit_id
        updating_unit_ids = []
        for param in first_compute_unit_id:
            unit_id = first_compute_unit_id[param]
            chain = chains[param]
            if unit_id in chain and chain.index(unit_id) != len(chain) - 1:
                for unit_id_ in chain[chain.index(unit_id) + 1:]:
                    if unit_id_ not in related_unit_ids:
                        updating_unit_ids.append(unit_id_)
        related_unit_ids.extend(updating_unit_ids)
        related_unit_ids = sorted(list(set(related_unit_ids)))
        return related_unit_ids

    def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]):
        # Split the computation graph,
        # separating all model-related computations.
        related_unit_ids = self.search_direct_unit_ids(units, model_names)
        edges = self.build_edges(units)
        chains = self.build_chains(units)
        while True:
            num_related_unit_ids = len(related_unit_ids)
            related_unit_ids = self.search_related_unit_ids(edges, related_unit_ids, "target")
            related_unit_ids = self.search_updating_unit_ids(units, chains, related_unit_ids)
            if len(related_unit_ids) == num_related_unit_ids:
                break
            else:
                num_related_unit_ids = len(related_unit_ids)
        related_units = [units[i] for i in related_unit_ids]
        unrelated_units = [units[i] for i in range(len(units)) if i not in related_unit_ids]
        return related_units, unrelated_units


class PipelineUnitRunner:
    def __init__(self):
        pass

    def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict, dict]:
        if unit.take_over:
            # Let the pipeline unit take over this function.
            inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
        elif unit.seperate_cfg:
            # Positive side
            processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
            if unit.input_params is not None:
                for name in unit.input_params:
                    processor_inputs[name] = inputs_shared.get(name)
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_posi.update(processor_outputs)
            # Negative side
            if inputs_shared["cfg_scale"] != 1:
                processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
                if unit.input_params is not None:
                    for name in unit.input_params:
                        processor_inputs[name] = inputs_shared.get(name)
                processor_outputs = unit.process(pipe, **processor_inputs)
                inputs_nega.update(processor_outputs)
            else:
                inputs_nega.update(processor_outputs)
        else:
            processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_shared.update(processor_outputs)
        return inputs_shared, inputs_posi, inputs_nega
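
To make the tensor conventions above concrete, the sketch below (a minimal example, not part of the diff) round-trips a PIL image through `preprocess_image` and `vae_output_to_image` on a bare `BasePipeline`, which in practice is always subclassed:

    from PIL import Image
    import torch
    from diffsynth.diffusion.base_pipeline import BasePipeline

    pipe = BasePipeline(device="cpu", torch_dtype=torch.float32)
    image = Image.new("RGB", (64, 64), color=(128, 0, 255))
    tensor = pipe.preprocess_image(image)        # (1, 3, 64, 64), values in [-1, 1]
    restored = pipe.vae_output_to_image(tensor)  # inverse mapping back to a PIL.Image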
236  diffsynth/diffusion/flow_match.py  Normal file
@@ -0,0 +1,236 @@
import torch, math
from typing_extensions import Literal


class FlowMatchScheduler():

    def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
        self.set_timesteps_fn = {
            "FLUX.1": FlowMatchScheduler.set_timesteps_flux,
            "Wan": FlowMatchScheduler.set_timesteps_wan,
            "Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
            "FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
            "Z-Image": FlowMatchScheduler.set_timesteps_z_image,
            "LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
            "Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
        }.get(template, FlowMatchScheduler.set_timesteps_flux)
        self.num_train_timesteps = 1000

    @staticmethod
    def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
        sigma_min = 0.003 / 1.002
        sigma_max = 1.0
        shift = 3 if shift is None else shift
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None):
        sigma_min = 0.0
        sigma_max = 1.0
        shift = 5 if shift is None else shift
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9):
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        b = base_shift - m * base_seq_len
        mu = image_seq_len * m + b
        return mu

    @staticmethod
    def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
        sigma_min = 0.0
        sigma_max = 1.0
        num_train_timesteps = 1000
        shift_terminal = 0.02
        # Sigmas
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        # Mu
        if exponential_shift_mu is not None:
            mu = exponential_shift_mu
        elif dynamic_shift_len is not None:
            mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len)
        else:
            mu = 0.8
        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
        # Shift terminal
        one_minus_z = 1 - sigmas
        scale_factor = one_minus_z[-1] / (1 - shift_terminal)
        sigmas = 1 - (one_minus_z / scale_factor)
        # Timesteps
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
        sigma_min = 0.0
        sigma_max = 1.0
        num_train_timesteps = 1000
        base_shift = math.log(3)
        max_shift = math.log(3)
        # Sigmas
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        # Mu
        if exponential_shift_mu is not None:
            mu = exponential_shift_mu
        elif dynamic_shift_len is not None:
            mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift)
        else:
            mu = 0.8
        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
        # Timesteps
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def compute_empirical_mu(image_seq_len, num_steps):
        a1, b1 = 8.73809524e-05, 1.89833333
        a2, b2 = 0.00016927, 0.45666666

        if image_seq_len > 4300:
            mu = a2 * image_seq_len + b2
            return float(mu)

        m_200 = a2 * image_seq_len + b2
        m_10 = a1 * image_seq_len + b1

        a = (m_200 - m_10) / 190.0
        b = m_200 - 200.0 * a
        mu = a * num_steps + b

        return float(mu)

    @staticmethod
    def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):
        sigma_min = 1 / num_inference_steps
        sigma_max = 1.0
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
        if dynamic_shift_len is None:
            # If you ask me why I set mu=0.8,
            # I can only say that it yields better training results.
            mu = 0.8
        else:
            mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
        sigma_min = 0.0
        sigma_max = 1.0
        shift = 3 if shift is None else shift
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_train_timesteps
        if target_timesteps is not None:
            target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device)
            for timestep in target_timesteps:
                timestep_id = torch.argmin((timesteps - timestep).abs())
                timesteps[timestep_id] = timestep
        return sigmas, timesteps

    @staticmethod
    def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
        num_train_timesteps = 1000
        if special_case == "stage2":
            sigmas = torch.Tensor([0.909375, 0.725, 0.421875])
        elif special_case == "ditilled_stage1":
            sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875])
        else:
            dynamic_shift_len = dynamic_shift_len or 4096
            sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image(
                image_seq_len=dynamic_shift_len,
                base_seq_len=1024,
                max_seq_len=4096,
                base_shift=0.95,
                max_shift=2.05,
            )
            sigma_min = 0.0
            sigma_max = 1.0
            sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
            sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
            sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1))
            # Shift terminal
            one_minus_z = 1.0 - sigmas
            scale_factor = one_minus_z[-1] / (1 - terminal)
            sigmas = 1.0 - (one_minus_z / scale_factor)
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    def set_training_weight(self):
        steps = 1000
        x = self.timesteps
        y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
        y_shifted = y - y.min()
        bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
        if len(self.timesteps) != 1000:
            # This is an empirical formula.
            bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
            bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
        self.linear_timesteps_weights = bsmntw_weighing

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
        self.sigmas, self.timesteps = self.set_timesteps_fn(
            num_inference_steps=num_inference_steps,
            denoising_strength=denoising_strength,
            **kwargs,
        )
        if training:
            self.set_training_weight()
            self.training = True
        else:
            self.training = False

    def step(self, model_output, timestep, sample, to_final=False, **kwargs):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        if to_final or timestep_id + 1 >= len(self.timesteps):
            sigma_ = 0
        else:
            sigma_ = self.sigmas[timestep_id + 1]
        prev_sample = sample + model_output * (sigma_ - sigma)
        return prev_sample

    def return_to_timestep(self, timestep, sample, sample_stablized):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        model_output = (sample - sample_stablized) / sigma
        return model_output

    def add_noise(self, original_samples, noise, timestep):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample

    def training_target(self, sample, noise, timestep):
        target = noise - sample
        return target

    def training_weight(self, timestep):
        timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
        weights = self.linear_timesteps_weights[timestep_id]
        return weights
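
The scheduler is driven in two phases: `set_timesteps` builds the sigma schedule for the chosen template, and `step` performs one Euler update `sample + model_output * (sigma_next - sigma)` of the flow ODE. A minimal sampling sketch (not part of the diff), where `velocity_model` is a hypothetical stand-in for a DiT velocity predictor:

    import torch
    from diffsynth.diffusion.flow_match import FlowMatchScheduler

    scheduler = FlowMatchScheduler(template="FLUX.1")
    scheduler.set_timesteps(num_inference_steps=30, shift=3)

    latents = torch.randn(1, 16, 64, 64)
    for timestep in scheduler.timesteps:
        velocity = velocity_model(latents, timestep)           # hypothetical model call
        latents = scheduler.step(velocity, timestep, latents)  # one Euler step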
43  diffsynth/diffusion/logger.py  Normal file
@@ -0,0 +1,43 @@
import os, torch
from accelerate import Accelerator


class ModelLogger:
    def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x: x):
        self.output_path = output_path
        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
        self.state_dict_converter = state_dict_converter
        self.num_steps = 0

    def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):
        self.num_steps += 1
        if save_steps is not None and self.num_steps % save_steps == 0:
            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")

    def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
        accelerator.wait_for_everyone()
        state_dict = accelerator.get_state_dict(model)
        if accelerator.is_main_process:
            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
            state_dict = self.state_dict_converter(state_dict)
            os.makedirs(self.output_path, exist_ok=True)
            path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
            accelerator.save(state_dict, path, safe_serialization=True)

    def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
        if save_steps is not None and self.num_steps % save_steps != 0:
            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")

    def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
        accelerator.wait_for_everyone()
        state_dict = accelerator.get_state_dict(model)
        if accelerator.is_main_process:
            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
            state_dict = self.state_dict_converter(state_dict)
            os.makedirs(self.output_path, exist_ok=True)
            path = os.path.join(self.output_path, file_name)
            accelerator.save(state_dict, path, safe_serialization=True)
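
`ModelLogger` persists only trainable weights: each save asks the unwrapped model for `export_trainable_state_dict`, optionally strips a checkpoint prefix, and writes safetensors through `accelerator.save`. A minimal sketch of the hooks in use, assuming `model` is a `DiffusionTrainingModule` (defined later in this diff) and `total_steps` is a placeholder:

    from accelerate import Accelerator
    from diffsynth.diffusion.logger import ModelLogger

    accelerator = Accelerator()
    logger = ModelLogger(output_path="./ckpt", remove_prefix_in_ckpt="pipe.dit.")
    for step in range(total_steps):
        ...                                                     # forward/backward/optimizer
        logger.on_step_end(accelerator, model, save_steps=500)  # step-500.safetensors, ...
    logger.on_training_end(accelerator, model, save_steps=500)  # flush a partial interval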
156  diffsynth/diffusion/loss.py  Normal file
@@ -0,0 +1,156 @@
from .base_pipeline import BasePipeline
import torch


def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
    max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
    min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))

    timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
    timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)

    noise = torch.randn_like(inputs["input_latents"])
    inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
    training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)

    if "first_frame_latents" in inputs:
        inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]

    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)

    if "first_frame_latents" in inputs:
        noise_pred = noise_pred[:, :, 1:]
        training_target = training_target[:, :, 1:]

    loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
    loss = loss * pipe.scheduler.training_weight(timestep)
    return loss


def FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs):
    max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
    min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))

    timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
    timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)

    # Video
    noise = torch.randn_like(inputs["input_latents"])
    inputs["video_latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
    training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)

    # Audio
    if inputs.get("audio_input_latents") is not None:
        audio_noise = torch.randn_like(inputs["audio_input_latents"])
        inputs["audio_latents"] = pipe.scheduler.add_noise(inputs["audio_input_latents"], audio_noise, timestep)
        training_target_audio = pipe.scheduler.training_target(inputs["audio_input_latents"], audio_noise, timestep)

    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep)

    loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
    loss = loss * pipe.scheduler.training_weight(timestep)
    if inputs.get("audio_input_latents") is not None:
        loss_audio = torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float())
        loss_audio = loss_audio * pipe.scheduler.training_weight(timestep)
        loss = loss + loss_audio
    return loss


def DirectDistillLoss(pipe: BasePipeline, **inputs):
    pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
    pipe.scheduler.training = True
    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
        timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
        noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
        inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
    loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
    return loss


class TrajectoryImitationLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, device):
        import lpips  # TODO: remove it
        self.loss_fn = lpips.LPIPS(net='alex').to(device)
        self.initialized = True

    def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        trajectory = [inputs_shared["latents"].clone()]

        pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)

            trajectory.append(inputs_shared["latents"].clone())
        return pipe.scheduler.timesteps, trajectory

    def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        loss = 0
        pipe.scheduler.set_timesteps(num_inference_steps, training=True)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)

            progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
            inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]

            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )

            sigma = pipe.scheduler.sigmas[progress_id]
            sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]
            if progress_id + 1 >= len(pipe.scheduler.timesteps):
                latents_ = trajectory_teacher[-1]
            else:
                progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
                latents_ = trajectory_teacher[progress_id_teacher]

            target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
            loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
        return loss

    def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        inputs_shared["latents"] = trajectory_teacher[0]
        pipe.scheduler.set_timesteps(num_inference_steps)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)

        image_pred = pipe.vae_decoder(inputs_shared["latents"])
        image_real = pipe.vae_decoder(trajectory_teacher[-1])
        loss = self.loss_fn(image_pred.float(), image_real.float())
        return loss

    def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
        if not self.initialized:
            self.initialize(pipe.device)
        with torch.no_grad():
            pipe.scheduler.set_timesteps(8)
            timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)
            timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
        loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
        loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
        loss = loss_1 + loss_2
        return loss
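
Both SFT losses implement rectified flow matching: `add_noise` forms x_t = (1 - sigma) * x_0 + sigma * noise, and `training_target` regresses the constant velocity noise - x_0, so transporting x_t by the true velocity recovers x_0 exactly. A quick sanity check of that identity (not part of the diff):

    import torch

    sigma = 0.7
    x0, eps = torch.randn(4), torch.randn(4)
    xt = (1 - sigma) * x0 + sigma * eps        # add_noise
    v = eps - x0                               # training_target
    assert torch.allclose(xt - sigma * v, x0)  # integrating v back over sigma recovers x0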
70  diffsynth/diffusion/parsers.py  Normal file
@@ -0,0 +1,70 @@
import argparse


def add_dataset_base_config(parser: argparse.ArgumentParser):
    parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
    parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
    parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
    parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
    parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
    return parser


def add_image_size_config(parser: argparse.ArgumentParser):
    parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
    return parser


def add_video_size_config(parser: argparse.ArgumentParser):
    parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
    return parser


def add_model_config(parser: argparse.ArgumentParser):
    parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
    parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
    parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
    parser.add_argument("--fp8_models", default=None, help="Models with FP8 precision, comma-separated.")
    parser.add_argument("--offload_models", default=None, help="Models with offload, comma-separated. Only used in split training.")
    return parser


def add_training_config(parser: argparse.ArgumentParser):
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
    parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
    parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
    parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
    return parser


def add_output_config(parser: argparse.ArgumentParser):
    parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
    parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
    parser.add_argument("--save_steps", type=int, default=None, help="Interval (in steps) between checkpoint saves. If None, checkpoints will be saved every epoch.")
    return parser


def add_lora_config(parser: argparse.ArgumentParser):
    parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
    parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
    parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
    parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
    parser.add_argument("--preset_lora_path", type=str, default=None, help="Path to the preset LoRA checkpoint. If provided, this LoRA will be fused to the base model.")
    parser.add_argument("--preset_lora_model", type=str, default=None, help="Which model the preset LoRA is fused to.")
    return parser


def add_gradient_config(parser: argparse.ArgumentParser):
    parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
    parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
    return parser


def add_general_config(parser: argparse.ArgumentParser):
    parser = add_dataset_base_config(parser)
    parser = add_model_config(parser)
    parser = add_training_config(parser)
    parser = add_output_config(parser)
    parser = add_lora_config(parser)
    parser = add_gradient_config(parser)
    return parser
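
These helpers all compose onto a single `argparse.ArgumentParser`, so a training script only needs `add_general_config` plus any task-specific flags (note the image/video size configs are not part of `add_general_config` and must be added separately). A minimal sketch:

    import argparse
    from diffsynth.diffusion.parsers import add_general_config

    parser = add_general_config(argparse.ArgumentParser())
    args = parser.parse_args([
        "--dataset_base_path", "data/my_dataset",  # hypothetical path
        "--lora_base_model", "dit",
        "--lora_rank", "64",
    ])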
72  diffsynth/diffusion/runner.py  Normal file
@@ -0,0 +1,72 @@
import os, torch
from tqdm import tqdm
from accelerate import Accelerator
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger


def launch_training_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 1,
    save_steps: int = None,
    num_epochs: int = 1,
    args = None,
):
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs

    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
    model.to(device=accelerator.device)
    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)

    for epoch_id in range(num_epochs):
        for data in tqdm(dataloader):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
                scheduler.step()
        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
    model_logger.on_training_end(accelerator, model, save_steps)


def launch_data_process_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    num_workers: int = 8,
    args = None,
):
    if args is not None:
        num_workers = args.dataset_num_workers

    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
    model.to(device=accelerator.device)
    model, dataloader = accelerator.prepare(model, dataloader)

    for data_id, data in enumerate(tqdm(dataloader)):
        with accelerator.accumulate(model):
            with torch.no_grad():
                folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
                os.makedirs(folder, exist_ok=True)
                save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
                data = model(data)
                torch.save(data, save_path)
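
`launch_training_task` owns the optimizer, LR scheduler, and accelerate preparation; the caller supplies only a dataset, a training module, and a logger. A minimal wiring sketch; `MyDataset` and `MyTrainingModule` are hypothetical stand-ins, and the dataset must expose a `load_from_cache` attribute, as checked in the loop above:

    from accelerate import Accelerator
    from diffsynth.diffusion import ModelLogger, launch_training_task

    accelerator = Accelerator(gradient_accumulation_steps=1)
    dataset = MyDataset(...)                 # hypothetical; defines .load_from_cache
    model = MyTrainingModule(...)            # hypothetical DiffusionTrainingModule subclass
    logger = ModelLogger("./ckpt", remove_prefix_in_ckpt="pipe.dit.")
    launch_training_task(accelerator, dataset, model, logger,
                         learning_rate=1e-4, num_epochs=1)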
263  diffsynth/diffusion/training_module.py  Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
import torch, json, os
|
||||||
|
from ..core import ModelConfig, load_state_dict
|
||||||
|
from ..utils.controlnet import ControlNetInput
|
||||||
|
from peft import LoraConfig, inject_adapter_in_model
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionTrainingModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
def to(self, *args, **kwargs):
|
||||||
|
for name, model in self.named_children():
|
||||||
|
model.to(*args, **kwargs)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def trainable_modules(self):
|
||||||
|
trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
|
||||||
|
return trainable_modules
|
||||||
|
|
||||||
|
|
||||||
|
def trainable_param_names(self):
|
||||||
|
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
|
||||||
|
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
|
||||||
|
return trainable_param_names
|
||||||
|
|
||||||
|
|
||||||
|
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
|
||||||
|
if lora_alpha is None:
|
||||||
|
lora_alpha = lora_rank
|
||||||
|
if isinstance(target_modules, list) and len(target_modules) == 1:
|
||||||
|
target_modules = target_modules[0]
|
||||||
|
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
|
||||||
|
model = inject_adapter_in_model(lora_config, model)
|
||||||
|
if upcast_dtype is not None:
|
||||||
|
for param in model.parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
param.data = param.to(upcast_dtype)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def mapping_lora_state_dict(self, state_dict):
|
||||||
|
new_state_dict = {}
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if "lora_A.weight" in key or "lora_B.weight" in key:
|
||||||
|
new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
|
||||||
|
new_state_dict[new_key] = value
|
||||||
|
elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
|
||||||
|
new_state_dict[key] = value
|
||||||
|
return new_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def export_trainable_state_dict(self, state_dict, remove_prefix=None):
|
||||||
|
trainable_param_names = self.trainable_param_names()
|
||||||
|
state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
|
||||||
|
if remove_prefix is not None:
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name.startswith(remove_prefix):
|
||||||
|
name = name[len(remove_prefix):]
|
||||||
|
state_dict_[name] = param
|
||||||
|
state_dict = state_dict_
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def transfer_data_to_device(self, data, device, torch_float_dtype=None):
|
||||||
|
if data is None:
|
||||||
|
return data
|
||||||
|
elif isinstance(data, torch.Tensor):
|
||||||
|
data = data.to(device)
|
||||||
|
if torch_float_dtype is not None and data.dtype in [torch.float, torch.float16, torch.bfloat16]:
|
||||||
|
data = data.to(torch_float_dtype)
|
||||||
|
return data
|
||||||
|
elif isinstance(data, tuple):
|
||||||
|
data = tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
|
||||||
|
return data
|
||||||
|
elif isinstance(data, list):
|
||||||
|
data = list(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
|
||||||
|
return data
|
||||||
|
elif isinstance(data, dict):
|
||||||
|
data = {i: self.transfer_data_to_device(data[i], device, torch_float_dtype) for i in data}
|
||||||
|
return data
|
||||||
|
else:
|
||||||
|
return data
|
||||||
|
|
||||||
    def parse_vram_config(self, fp8=False, offload=False, device="cpu"):
        # Build the VRAM-management kwargs for a ModelConfig. FP8 stores weights in
        # float8 and computes in bfloat16; offload keeps weights on disk until needed.
        if fp8:
            return {
                "offload_dtype": torch.float8_e4m3fn,
                "offload_device": device,
                "onload_dtype": torch.float8_e4m3fn,
                "onload_device": device,
                "preparing_dtype": torch.float8_e4m3fn,
                "preparing_device": device,
                "computation_dtype": torch.bfloat16,
                "computation_device": device,
            }
        elif offload:
            return {
                "offload_dtype": "disk",
                "offload_device": "disk",
                "onload_dtype": "disk",
                "onload_device": "disk",
                "preparing_dtype": torch.bfloat16,
                "preparing_device": device,
                "computation_dtype": torch.bfloat16,
                "computation_device": device,
                "clear_parameters": True,
            }
        else:
            return {}


    def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"):
        fp8_models = [] if fp8_models is None else fp8_models.split(",")
        offload_models = [] if offload_models is None else offload_models.split(",")
        model_configs = []
        if model_paths is not None:
            model_paths = json.loads(model_paths)
            for path in model_paths:
                vram_config = self.parse_vram_config(
                    fp8=path in fp8_models,
                    offload=path in offload_models,
                    device=device,
                )
                model_configs.append(ModelConfig(path=path, **vram_config))
        if model_id_with_origin_paths is not None:
            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
            for model_id_with_origin_path in model_id_with_origin_paths:
                vram_config = self.parse_vram_config(
                    fp8=model_id_with_origin_path in fp8_models,
                    offload=model_id_with_origin_path in offload_models,
                    device=device,
                )
                config = self.parse_path_or_model_id(model_id_with_origin_path)
                model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))
        return model_configs


    def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):
        if model_id_with_origin_path is None:
            return default_value
        elif os.path.exists(model_id_with_origin_path):
            return ModelConfig(path=model_id_with_origin_path)
        else:
            if ":" not in model_id_with_origin_path:
                raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id:origin_file_pattern`.")
            # Split on the last ":" so the model id itself may contain ":".
            split_id = model_id_with_origin_path.rfind(":")
            model_id = model_id_with_origin_path[:split_id]
            origin_file_pattern = model_id_with_origin_path[split_id + 1:]
            return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
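
    # Examples of the two accepted forms (the model id below is hypothetical):
    #   "models/dit.safetensors"
    #       -> ModelConfig(path="models/dit.safetensors")
    #   "org/some-model:diffusion_pytorch_model*.safetensors"
    #       -> ModelConfig(model_id="org/some-model",
    #                      origin_file_pattern="diffusion_pytorch_model*.safetensors")
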
    def auto_detect_lora_target_modules(
        self,
        model: torch.nn.Module,
        search_for_linear=False,
        linear_detector=lambda x: min(x.weight.shape) >= 512,
        block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
        name_prefix="",
    ):
        # Walk the model recursively. Once a block list (e.g., a stack of repeated
        # transformer blocks) is found, search inside it for large linear layers.
        lora_target_modules = []
        if search_for_linear:
            for name, module in model.named_modules():
                module_name = name_prefix + ("." if name_prefix != "" else "") + name
                if isinstance(module, torch.nn.Linear) and linear_detector(module):
                    lora_target_modules.append(module_name)
        else:
            for name, module in model.named_children():
                module_name = name_prefix + ("." if name_prefix != "" else "") + name
                lora_target_modules += self.auto_detect_lora_target_modules(
                    module,
                    search_for_linear=block_list_detector(module),
                    linear_detector=linear_detector,
                    block_list_detector=block_list_detector,
                    name_prefix=module_name,
                )
        return lora_target_modules
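
    # A toy illustration of the search strategy above (shapes chosen arbitrarily):
    # linear layers are only collected once the recursion enters a ModuleList
    # with more than one element, i.e. a stack of repeated blocks.
    #
    #   blocks = torch.nn.ModuleList([
    #       torch.nn.ModuleDict({"to_q": torch.nn.Linear(1024, 1024)}) for _ in range(2)
    #   ])
    #   net = torch.nn.ModuleDict({"blocks": blocks, "head": torch.nn.Linear(1024, 8)})
    #   self.auto_detect_lora_target_modules(net)
    #   # -> ["blocks.0.to_q", "blocks.1.to_q"]; "head" is never tested because
    #   #    it does not sit inside a block list.
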
    def parse_lora_target_modules(self, model, lora_target_modules):
        if lora_target_modules == "":
            print("No LoRA target modules specified. The framework will automatically search for them.")
            lora_target_modules = self.auto_detect_lora_target_modules(model)
            print(f"LoRA will be patched at {lora_target_modules}.")
        else:
            lora_target_modules = lora_target_modules.split(",")
        return lora_target_modules


    def switch_pipe_to_training_mode(
        self,
        pipe,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
        preset_lora_path=None, preset_lora_model=None,
        task="sft",
    ):
        # Scheduler
        pipe.scheduler.set_timesteps(1000, training=True)

        # Freeze untrainable models
        pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))

        # Preset LoRA
        if preset_lora_path is not None:
            pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path)

        # FP8
        # FP8 relies on a model-specific memory management scheme.
        # It is delegated to the subclass.

        # Add LoRA to the base model
        if lora_base_model is not None and not task.endswith(":data_process"):
            if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None:
                print(f"No {lora_base_model} model in the pipeline, so LoRA cannot be patched onto it. If this occurs during the data processing stage, it is normal.")
                return
            model = self.add_lora_to_model(
                getattr(pipe, lora_base_model),
                target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
                lora_rank=lora_rank,
                upcast_dtype=pipe.torch_dtype,
            )
            if lora_checkpoint is not None:
                state_dict = load_state_dict(lora_checkpoint)
                state_dict = self.mapping_lora_state_dict(state_dict)
                load_result = model.load_state_dict(state_dict, strict=False)
                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
                if len(load_result.unexpected_keys) > 0:
                    print(f"Warning: LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result.unexpected_keys}")
            setattr(pipe, lora_base_model, model)


    def split_pipeline_units(self, task, pipe, trainable_models=None, lora_base_model=None):
        models_require_backward = []
        if trainable_models is not None:
            models_require_backward += trainable_models.split(",")
        if lora_base_model is not None:
            models_require_backward += [lora_base_model]
        if task.endswith(":data_process"):
            # The data processing stage only runs the units that require no gradients.
            _, pipe.units = pipe.split_pipeline_units(models_require_backward)
        elif task.endswith(":train"):
            pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
        return pipe


    def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
        # The longer "blockwise_controlnet_" prefix must be checked before "controlnet_".
        controlnet_keys_map = (
            ("blockwise_controlnet_", "blockwise_controlnet_inputs"),
            ("controlnet_", "controlnet_inputs"),
        )
        controlnet_inputs = {}
        for extra_input in extra_inputs:
            for prefix, name in controlnet_keys_map:
                if extra_input.startswith(prefix):
                    if name not in controlnet_inputs:
                        controlnet_inputs[name] = {}
                    controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input]
                    break
            else:
                inputs_shared[extra_input] = data[extra_input]
        for name, params in controlnet_inputs.items():
            inputs_shared[name] = [ControlNetInput(**params)]
        return inputs_shared
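
    # A worked example for `parse_extra_inputs` (the field names are illustrative):
    # with extra_inputs = ["input_image", "controlnet_image", "controlnet_scale"],
    # the "controlnet_" entries are grouped into a single ControlNetInput:
    #   inputs_shared["input_image"]       = data["input_image"]
    #   inputs_shared["controlnet_inputs"] = [ControlNetInput(image=..., scale=...)]
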
@@ -1,118 +0,0 @@
import torch
from einops import repeat
from PIL import Image
import numpy as np


class ResidualDenseBlock(torch.nn.Module):

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Dense connections: each conv sees the input and all previous outputs.
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Residual scaling
        return x5 * 0.2 + x


class RRDB(torch.nn.Module):

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return out * 0.2 + x


class RRDBNet(torch.nn.Module):

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
        self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        feat = x
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample (x4 overall: two nearest-neighbor x2 steps)
        feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
        feat = self.lrelu(self.conv_up1(feat))
        feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
        feat = self.lrelu(self.conv_up2(feat))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
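
# Note on the upsampling above: `repeat(feat, "B C H W -> B C (H 2) (W 2)")`
# duplicates every pixel into a 2x2 block, i.e. nearest-neighbor x2 upsampling.
# A minimal standalone check of the equivalence:
#
#   import torch
#   from einops import repeat
#   x = torch.arange(4.).reshape(1, 1, 2, 2)
#   a = repeat(x, "B C H W -> B C (H 2) (W 2)")
#   b = torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
#   assert torch.equal(a, b)
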
class ESRGAN(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @staticmethod
    def from_pretrained(model_path):
        model = RRDBNet()
        state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
        model.load_state_dict(state_dict)
        model.eval()
        return ESRGAN(model)

    def process_image(self, image):
        # PIL image (RGB) -> float CHW tensor in [0, 1]
        image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
        return image

    def process_images(self, images):
        images = [self.process_image(image) for image in images]
        images = torch.stack(images)
        return images

    def decode_images(self, images):
        # float NCHW tensor -> list of PIL images
        images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
        images = [Image.fromarray(image) for image in images]
        return images

    @torch.no_grad()
    def upscale(self, images, batch_size=4, progress_bar=lambda x: x):
        # Preprocess
        input_tensor = self.process_images(images)

        # Run the super-resolution network in batches
        output_tensor = []
        for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
            batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
            batch_input_tensor = input_tensor[batch_id: batch_id_]
            batch_input_tensor = batch_input_tensor.to(
                device=self.model.conv_first.weight.device,
                dtype=self.model.conv_first.weight.dtype)
            batch_output_tensor = self.model(batch_input_tensor)
            output_tensor.append(batch_output_tensor.cpu())

        # Concatenate the batches
        output_tensor = torch.concat(output_tensor, dim=0)

        # To images
        output_images = self.decode_images(output_tensor)
        return output_images
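
# Usage sketch for the wrapper above (the checkpoint path is a placeholder;
# RealESRGAN-style weights stored under a "params_ema" key are assumed):
#
#   from PIL import Image
#   esrgan = ESRGAN.from_pretrained("models/RealESRGAN_x4plus.pth")
#   upscaled = esrgan.upscale([Image.open("input.png").convert("RGB")])
#   upscaled[0].save("output.png")
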
@@ -1,63 +0,0 @@
from .runners.fast import TableManager, PyramidPatchMatcher
from PIL import Image
import numpy as np
import cupy as cp


class FastBlendSmoother:
    def __init__(self):
        self.batch_size = 8
        self.window_size = 64
        self.ebsynth_config = {
            "minimum_patch_size": 5,
            "threads_per_block": 8,
            "num_iter": 5,
            "gpu_id": 0,
            "guide_weight": 10.0,
            "initialize": "identity",
            "tracking_window_size": 0,
        }

    @staticmethod
    def from_model_manager(model_manager):
        # TODO: fetch GPU ID from model_manager
        return FastBlendSmoother()

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
        frames_guide = [np.array(frame) for frame in frames_guide]
        frames_style = [np.array(frame) for frame in frames_style]
        table_manager = TableManager()
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # left part
        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
        table_l = table_manager.remapping_table_to_blending_table(table_l)
        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
        # right part
        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
        table_r = table_manager.remapping_table_to_blending_table(table_r)
        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
        # merge: the middle frame appears in both window sums, so one copy is subtracted
        frames = []
        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
            weight_m = -1
            weight = weight_l + weight_m + weight_r
            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
            frames.append(frame)
        frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
        return frames

    def __call__(self, rendered_frames, original_frames=None, **kwargs):
        frames = self.run(
            original_frames, rendered_frames,
            self.batch_size, self.window_size, self.ebsynth_config
        )
        # Release cupy memory pools so the GPU memory is returned to the system.
        mempool = cp.get_default_memory_pool()
        pinned_mempool = cp.get_default_pinned_memory_pool()
        mempool.free_all_blocks()
        pinned_mempool.free_all_blocks()
        return frames
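
# Usage sketch (the frame lists are placeholders; a CUDA device is required
# because the patch matcher runs on cupy):
#
#   smoother = FastBlendSmoother()
#   smoothed = smoother(rendered_frames, original_frames=guide_frames)
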
@@ -1,397 +0,0 @@
from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
from .data import VideoData, get_video_fps, save_video, search_for_images
import os
import gradio as gr


def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
    frames_guide = VideoData(video_guide, video_guide_folder)
    frames_style = VideoData(video_style, video_style_folder)
    message = ""
    if len(frames_guide) < len(frames_style):
        message += f"The frame counts do not match. Only the first {len(frames_guide)} frames of the style video will be used.\n"
        frames_style.set_length(len(frames_guide))
    elif len(frames_guide) > len(frames_style):
        message += f"The frame counts do not match. Only the first {len(frames_style)} frames of the guide video will be used.\n"
        frames_guide.set_length(len(frames_style))
    height_guide, width_guide = frames_guide.shape()
    height_style, width_style = frames_style.shape()
    if height_guide != height_style or width_guide != width_style:
        message += f"The frame shapes do not match. The frames in the style video will be resized to (height: {height_guide}, width: {width_guide}).\n"
        frames_style.set_shape(height_guide, width_guide)
    return frames_guide, frames_style, message


def smooth_video(
    video_guide,
    video_guide_folder,
    video_style,
    video_style_folder,
    mode,
    window_size,
    batch_size,
    tracking_window_size,
    output_path,
    fps,
    minimum_patch_size,
    num_iter,
    guide_weight,
    initialize,
    progress=None,
):
    # input
    frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
    if len(message) > 0:
        print(message)
    # output
    if output_path == "":
        if video_style is None:
            output_path = os.path.join(video_style_folder, "output")
        else:
            output_path = os.path.join(os.path.split(video_style)[0], "output")
        os.makedirs(output_path, exist_ok=True)
        print("No valid output_path. Your video will be saved here:", output_path)
    elif not os.path.exists(output_path):
        os.makedirs(output_path, exist_ok=True)
        print("Your video will be saved here:", output_path)
    frames_path = os.path.join(output_path, "frames")
    video_path = os.path.join(output_path, "video.mp4")
    os.makedirs(frames_path, exist_ok=True)
    # process
    if mode == "Fast" or mode == "Balanced":
        tracking_window_size = 0
    ebsynth_config = {
        "minimum_patch_size": minimum_patch_size,
        "threads_per_block": 8,
        "num_iter": num_iter,
        "gpu_id": 0,
        "guide_weight": guide_weight,
        "initialize": initialize,
        "tracking_window_size": tracking_window_size,
    }
    if mode == "Fast":
        FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    elif mode == "Balanced":
        BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    elif mode == "Accurate":
        AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    # output
    try:
        fps = int(fps)
    except (TypeError, ValueError):
        fps = get_video_fps(video_style) if video_style is not None else 30
    print("Fps:", fps)
    print("Saving video...")
    video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
    print("Success!")
    print("Your frames are here:", frames_path)
    print("Your video is here:", video_path)
    return output_path, fps, video_path


class KeyFrameMatcher:
    def __init__(self):
        pass

    def extract_number_from_filename(self, file_name):
        # Collect every maximal run of digits in the file name as an integer.
        result = []
        number = -1
        for i in file_name:
            if ord(i) >= ord("0") and ord(i) <= ord("9"):
                if number == -1:
                    number = 0
                number = number * 10 + ord(i) - ord("0")
            else:
                if number != -1:
                    result.append(number)
                    number = -1
        if number != -1:
            result.append(number)
        result = tuple(result)
        return result

    def extract_number_from_filenames(self, file_names):
        # Find the right-most digit position that uniquely distinguishes all files.
        numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
        min_length = min(len(i) for i in numbers)
        for i in range(min_length - 1, -1, -1):
            if len(set(number[i] for number in numbers)) == len(file_names):
                return [number[i] for number in numbers]
        return list(range(len(file_names)))

    def match_using_filename(self, file_names_a, file_names_b):
        file_names_b_set = set(file_names_b)
        matched_file_name = []
        for file_name in file_names_a:
            if file_name not in file_names_b_set:
                matched_file_name.append(None)
            else:
                matched_file_name.append(file_name)
        return matched_file_name

    def match_using_numbers(self, file_names_a, file_names_b):
        numbers_a = self.extract_number_from_filenames(file_names_a)
        numbers_b = self.extract_number_from_filenames(file_names_b)
        numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
        matched_file_name = []
        for number in numbers_a:
            if number in numbers_b_dict:
                matched_file_name.append(numbers_b_dict[number])
            else:
                matched_file_name.append(None)
        return matched_file_name

    def match_filenames(self, file_names_a, file_names_b):
        # Prefer exact filename matches; fall back to matching by extracted numbers.
        matched_file_name = self.match_using_filename(file_names_a, file_names_b)
        if sum([i is not None for i in matched_file_name]) > 0:
            return matched_file_name
        matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
        return matched_file_name
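
# Worked example: `extract_number_from_filename` collects each maximal digit
# run as one integer, so "frame_0012_v3.png" -> (12, 3). This lets
# `match_using_numbers` pair files whose numbering agrees even when the
# surrounding text differs (e.g. "guide_12.png" vs "style12.jpg").
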
def detect_frames(frames_path, keyframes_path):
    if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
        return "Please input the directories of the guide video and the rendered frames"
    elif not os.path.exists(frames_path):
        return "Please input the directory of the guide video"
    elif not os.path.exists(keyframes_path):
        return "Please input the directory of the rendered frames"
    frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
    keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
    if len(frames) == 0:
        return f"No images detected in {frames_path}"
    if len(keyframes) == 0:
        return f"No images detected in {keyframes_path}"
    matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
    max_filename_length = max([len(i) for i in frames])
    if sum([i is not None for i in matched_keyframes]) == 0:
        message = ""
        for frame, matched_keyframe in zip(frames, matched_keyframes):
            message += frame + " " * (max_filename_length - len(frame) + 1)
            message += "--> No matched keyframes\n"
    else:
        message = ""
        for frame, matched_keyframe in zip(frames, matched_keyframes):
            message += frame + " " * (max_filename_length - len(frame) + 1)
            if matched_keyframe is None:
                message += "--> [to be rendered]\n"
            else:
                message += f"--> {matched_keyframe}\n"
    return message


def check_input_for_interpolating(frames_path, keyframes_path):
    # search for images
    frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
    keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
    # match frames
    matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
    file_list = [file_name for file_name in matched_keyframes if file_name is not None]
    index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
    frames_guide = VideoData(None, frames_path)
    frames_style = VideoData(None, keyframes_path, file_list=file_list)
    # match shape
    message = ""
    height_guide, width_guide = frames_guide.shape()
    height_style, width_style = frames_style.shape()
    if height_guide != height_style or width_guide != width_style:
        message += f"The frame shapes do not match. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide}).\n"
        frames_style.set_shape(height_guide, width_guide)
    return frames_guide, frames_style, index_style, message


def interpolate_video(
    frames_path,
    keyframes_path,
    output_path,
    fps,
    batch_size,
    tracking_window_size,
    minimum_patch_size,
    num_iter,
    guide_weight,
    initialize,
    progress=None,
):
    # input
    frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
    if len(message) > 0:
        print(message)
    # output
    if output_path == "":
        output_path = os.path.join(keyframes_path, "output")
        os.makedirs(output_path, exist_ok=True)
        print("No valid output_path. Your video will be saved here:", output_path)
    elif not os.path.exists(output_path):
        os.makedirs(output_path, exist_ok=True)
        print("Your video will be saved here:", output_path)
    output_frames_path = os.path.join(output_path, "frames")
    output_video_path = os.path.join(output_path, "video.mp4")
    os.makedirs(output_frames_path, exist_ok=True)
    # process
    ebsynth_config = {
        "minimum_patch_size": minimum_patch_size,
        "threads_per_block": 8,
        "num_iter": num_iter,
        "gpu_id": 0,
        "guide_weight": guide_weight,
        "initialize": initialize,
        "tracking_window_size": tracking_window_size
    }
    if len(index_style) == 1:
        InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
    else:
        InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
    try:
        fps = int(fps)
    except (TypeError, ValueError):
        fps = 30
    print("Fps:", fps)
    print("Saving video...")
    video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
    print("Success!")
    print("Your frames are here:", output_frames_path)
    print("Your video is here:", video_path)
    return output_path, fps, video_path


def on_ui_tabs():
    with gr.Blocks(analytics_enabled=False) as ui_component:
        with gr.Tab("Blend"):
            gr.Markdown("""
# Blend

Given a guide video and a style video, this algorithm makes the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see an example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolutions.
""")
            with gr.Row():
                with gr.Column():
                    with gr.Tab("Guide video"):
                        video_guide = gr.Video(label="Guide video")
                    with gr.Tab("Guide video (images format)"):
                        video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
                with gr.Column():
                    with gr.Tab("Style video"):
                        video_style = gr.Video(label="Style video")
                    with gr.Tab("Style video (images format)"):
                        video_style_folder = gr.Textbox(label="Style video (images format)", value="")
                with gr.Column():
                    output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of the style video")
                    fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
                    video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
                    btn = gr.Button(value="Blend")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Settings")
                    mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
                    window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
                    batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
                    tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
                    gr.Markdown("## Advanced Settings")
                    minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
                    num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
                    guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
                    initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
                with gr.Column():
                    gr.Markdown("""
# Reference

* Output directory: the directory to save the video.
* Inference mode

|Mode|Time|Memory|Quality|Frame by frame output|Description|
|-|-|-|-|-|-|
|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
|Balanced|■■|■|■■|Yes|Blend the frames naively.|
|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. Performance is best when [batch size] >= [sliding window size] * 2 + 1.|

* Sliding window size: the algorithm blends the frames within a sliding window. If the size is n, each frame is blended with the previous n frames and the next n frames. A large sliding window can make the video fluent but sometimes blurry.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size (only for accurate mode): the size of the window in which the algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
    * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
    * Number of iterations: the number of iterations of patch matching. (Default: 5)
    * Guide weight: a parameter that determines how strongly the motion features are applied to the style video. (Default: 10)
    * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
""")
            btn.click(
                smooth_video,
                inputs=[
                    video_guide,
                    video_guide_folder,
                    video_style,
                    video_style_folder,
                    mode,
                    window_size,
                    batch_size,
                    tracking_window_size,
                    output_path,
                    fps,
                    minimum_patch_size,
                    num_iter,
                    guide_weight,
                    initialize
                ],
                outputs=[output_path, fps, video_output]
            )
        with gr.Tab("Interpolate"):
            gr.Markdown("""
# Interpolate

Given a guide video and some rendered keyframes, this algorithm renders the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see an example. The algorithm is experimental and has only been tested for 512*512 resolution.
""")
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        with gr.Column():
                            video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
                        with gr.Column():
                            rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
                    with gr.Row():
                        detected_frames = gr.Textbox(label="Detected frames", value="Please input the directories of the guide video and the rendered frames", lines=9, max_lines=9, interactive=False)
                    video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
                    rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
                with gr.Column():
                    output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of the rendered keyframes")
                    fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
                    video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
                    btn_ = gr.Button(value="Interpolate")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Settings")
                    batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
                    tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
                    gr.Markdown("## Advanced Settings")
                    minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
                    num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
                    guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
                    initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
                with gr.Column():
                    gr.Markdown("""
# Reference

* Output directory: the directory to save the video.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size: the size of the window in which the algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
    * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than the one used in blending. (Default: 15)**
    * Number of iterations: the number of iterations of patch matching. (Default: 5)
    * Guide weight: a parameter that determines how strongly the motion features are applied to the style video. (Default: 10)
    * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
""")
            btn_.click(
                interpolate_video,
                inputs=[
                    video_guide_folder_,
                    rendered_keyframes_,
                    output_path_,
                    fps_,
                    batch_size_,
                    tracking_window_size_,
                    minimum_patch_size_,
                    num_iter_,
                    guide_weight_,
                    initialize_,
                ],
                outputs=[output_path_, fps_, video_output_]
            )

    return [(ui_component, "FastBlend", "FastBlend_ui")]
@@ -1,119 +0,0 @@
import cupy as cp


remapping_kernel = cp.RawKernel(r'''
extern "C" __global__
void remap(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_style,
    const int* nnf,
    float* target_style
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    if (x >= height or y >= width) return;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
    const int min_px = x < r ? -x : -r;
    const int max_px = x + r > height - 1 ? height - 1 - x : r;
    const int min_py = y < r ? -y : -r;
    const int max_py = y + r > width - 1 ? width - 1 - y : r;
    int num = 0;
    for (int px = min_px; px <= max_px; px++){
        for (int py = min_py; py <= max_py; py++){
            const int nid = (x + px) * width + y + py;
            const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
            const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
            if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width) continue;
            const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
            num++;
            for (int c = 0; c < channel; c++){
                target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
            }
        }
    }
    for (int c = 0; c < channel; c++){
        target_style[z + pid * channel + c] /= num;
    }
}
''', 'remap')


patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source,
    const int* nnf,
    const float* target,
    float* error
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    if (x >= height or y >= width) return;
    const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
    const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
    float e = 0;
    for (int px = -r; px <= r; px++){
        for (int py = -r; py <= r; py++){
            const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
            const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
            for (int c = 0; c < channel; c++){
                const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
                e += diff * diff;
            }
        }
    }
    error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'patch_error')


pairwise_patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void pairwise_patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_a,
    const int* nnf_a,
    const float* source_b,
    const int* nnf_b,
    float* error
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    if (x >= height or y >= width) return;
    const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
    const int x_a = nnf_a[z_nnf + 0];
    const int y_a = nnf_a[z_nnf + 1];
    const int x_b = nnf_b[z_nnf + 0];
    const int y_b = nnf_b[z_nnf + 1];
    float e = 0;
    for (int px = -r; px <= r; px++){
        for (int py = -r; py <= r; py++){
            const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
            const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
            for (int c = 0; c < channel; c++){
                const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
                e += diff * diff;
            }
        }
    }
    error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'pairwise_patch_error')
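
# Launch sketch for the kernels above (shapes are illustrative). Images are
# NHWC float32 arrays padded by `pad_size` on each spatial side, NNFs are
# (batch, height, width, 2) int32 arrays, and the z grid dimension indexes
# the batch:
#
#   height, width, channel, patch_size, pad_size, batch = 64, 64, 3, 5, 2, 1
#   tpb = 8  # threads per block along each spatial axis
#   grid = ((height + tpb - 1) // tpb, (width + tpb - 1) // tpb, batch)
#   block = (tpb, tpb)
#   source = cp.zeros((batch, height + 2 * pad_size, width + 2 * pad_size, channel), dtype=cp.float32)
#   target = cp.zeros_like(source)
#   nnf = cp.zeros((batch, height, width, 2), dtype=cp.int32)
#   remapping_kernel(grid, block, (height, width, channel, patch_size, pad_size, source, nnf, target))
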
@@ -1,146 +0,0 @@
import imageio, os
import numpy as np
from PIL import Image


def read_video(file_name):
    reader = imageio.get_reader(file_name)
    video = []
    for frame in reader:
        frame = np.array(frame)
        video.append(frame)
    reader.close()
    return video


def get_video_fps(file_name):
    reader = imageio.get_reader(file_name)
    fps = reader.get_meta_data()["fps"]
    reader.close()
    return fps


def save_video(frames_path, video_path, num_frames, fps):
    writer = imageio.get_writer(video_path, fps=fps, quality=9)
    for i in range(num_frames):
        frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
        writer.append_data(frame)
    writer.close()
    return video_path


class LowMemoryVideo:
    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        return self.reader.count_frames()

    def __getitem__(self, item):
        return np.array(self.reader.get_data(item))

    def __del__(self):
        self.reader.close()


def split_file_name(file_name):
    # Build a natural-sort key: digit runs become integers, other characters
    # are kept one by one.
    result = []
    number = -1
    for i in file_name:
        if ord(i) >= ord("0") and ord(i) <= ord("9"):
            if number == -1:
                number = 0
            number = number * 10 + ord(i) - ord("0")
        else:
            if number != -1:
                result.append(number)
                number = -1
            result.append(i)
    if number != -1:
        result.append(number)
    result = tuple(result)
    return result
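
# For example:
#   split_file_name("img2.png")  -> ("i", "m", "g", 2, ".", "p", "n", "g")
#   split_file_name("img10.png") -> ("i", "m", "g", 10, ".", "p", "n", "g")
# so sorting by these keys places img2.png before img10.png, unlike a plain
# lexicographic sort. (Names whose keys mix an int and a str at the same
# position would raise a TypeError under Python 3 comparison rules.)
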
def search_for_images(folder):
    file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
    file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
    file_list = [i[1] for i in sorted(file_list)]
    file_list = [os.path.join(folder, i) for i in file_list]
    return file_list


def read_images(folder):
    file_list = search_for_images(folder)
    frames = [np.array(Image.open(i)) for i in file_list]
    return frames


class LowMemoryImageFolder:
    def __init__(self, folder, file_list=None):
        if file_list is None:
            self.file_list = search_for_images(folder)
        else:
            self.file_list = [os.path.join(folder, file_name) for file_name in file_list]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        return np.array(Image.open(self.file_list[item]))

    def __del__(self):
        pass


class VideoData:
    def __init__(self, video_file, image_folder, **kwargs):
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None
        self.height = None
        self.width = None

    def raw_data(self):
        frames = []
        for i in range(self.__len__()):
            frames.append(self.__getitem__(i))
        return frames

    def set_length(self, length):
        self.length = length

    def set_shape(self, height, width):
        self.height = height
        self.width = width

    def __len__(self):
        if self.length is None:
            return len(self.data)
        else:
            return self.length

    def shape(self):
        if self.height is not None and self.width is not None:
            return self.height, self.width
        else:
            height, width, _ = self.__getitem__(0).shape
            return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        height, width, _ = frame.shape
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = Image.fromarray(frame).resize((self.width, self.height))
                frame = np.array(frame)
        return frame

    def __del__(self):
        pass
@@ -1,298 +0,0 @@
|
|||||||
from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
|
|
||||||
import numpy as np
|
|
||||||
import cupy as cp
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
|
|
||||||
class PatchMatcher:
|
|
||||||
def __init__(
|
|
||||||
self, height, width, channel, minimum_patch_size,
|
|
||||||
threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
|
|
||||||
random_search_steps=3, random_search_range=4,
|
|
||||||
use_mean_target_style=False, use_pairwise_patch_error=False,
|
|
||||||
tracking_window_size=0
|
|
||||||
):
|
|
||||||
self.height = height
|
|
||||||
self.width = width
|
|
||||||
self.channel = channel
|
|
||||||
self.minimum_patch_size = minimum_patch_size
|
|
||||||
self.threads_per_block = threads_per_block
|
|
||||||
self.num_iter = num_iter
|
|
||||||
self.gpu_id = gpu_id
|
|
||||||
self.guide_weight = guide_weight
|
|
||||||
self.random_search_steps = random_search_steps
|
|
||||||
self.random_search_range = random_search_range
|
|
||||||
self.use_mean_target_style = use_mean_target_style
|
|
||||||
self.use_pairwise_patch_error = use_pairwise_patch_error
|
|
||||||
self.tracking_window_size = tracking_window_size
|
|
||||||
|
|
||||||
self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
|
|
||||||
self.pad_size = self.patch_size_list[0] // 2
|
|
||||||
self.grid = (
|
|
||||||
(height + threads_per_block - 1) // threads_per_block,
|
|
||||||
(width + threads_per_block - 1) // threads_per_block
|
|
||||||
)
|
|
||||||
self.block = (threads_per_block, threads_per_block)
|
|
||||||
|
|
||||||
def pad_image(self, image):
|
|
||||||
return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
|
|
||||||
|
|
||||||
def unpad_image(self, image):
|
|
||||||
return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
|
|
||||||
|
|
||||||
def apply_nnf_to_image(self, nnf, source):
|
|
||||||
batch_size = source.shape[0]
|
|
||||||
target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
|
|
||||||
remapping_kernel(
|
|
||||||
self.grid + (batch_size,),
|
|
||||||
self.block,
|
|
||||||
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
|
|
||||||
)
|
|
||||||
return target
|
|
||||||
|
|
||||||
def get_patch_error(self, source, nnf, target):
|
|
||||||
batch_size = source.shape[0]
|
|
||||||
error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
|
|
||||||
patch_error_kernel(
|
|
||||||
self.grid + (batch_size,),
|
|
||||||
self.block,
|
|
||||||
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
|
|
||||||
)
|
|
||||||
return error
|
|
||||||
|
|
||||||
def get_pairwise_patch_error(self, source, nnf):
|
|
||||||
batch_size = source.shape[0]//2
|
|
||||||
error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
|
|
||||||
source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
|
|
||||||
source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
|
|
||||||
pairwise_patch_error_kernel(
|
|
||||||
self.grid + (batch_size,),
|
|
||||||
self.block,
|
|
||||||
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
|
|
||||||
)
|
|
||||||
error = error.repeat(2, axis=0)
|
|
||||||
return error
|
|
||||||
|
|
||||||
def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
|
|
||||||
error_guide = self.get_patch_error(source_guide, nnf, target_guide)
|
|
||||||
if self.use_mean_target_style:
|
|
||||||
target_style = self.apply_nnf_to_image(nnf, source_style)
|
|
||||||
target_style = target_style.mean(axis=0, keepdims=True)
|
|
||||||
target_style = target_style.repeat(source_guide.shape[0], axis=0)
|
|
||||||
if self.use_pairwise_patch_error:
|
|
||||||
error_style = self.get_pairwise_patch_error(source_style, nnf)
|
|
||||||
else:
|
|
||||||
error_style = self.get_patch_error(source_style, nnf, target_style)
|
|
||||||
error = error_guide * self.guide_weight + error_style
|
|
||||||
return error
|
|
||||||
|
|
||||||
def clamp_bound(self, nnf):
|
|
||||||
nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
|
|
||||||
nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
|
|
||||||
return nnf
|
|
||||||
|
|
||||||
def random_step(self, nnf, r):
|
|
||||||
batch_size = nnf.shape[0]
|
|
||||||
step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
|
|
||||||
upd_nnf = self.clamp_bound(nnf + step)
|
|
||||||
return upd_nnf
|
|
||||||
|
|
||||||
def neighboor_step(self, nnf, d):
|
|
||||||
if d==0:
|
|
||||||
upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
|
|
||||||
upd_nnf[:, :, :, 0] += 1
|
|
||||||
elif d==1:
|
|
||||||
upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
|
|
||||||
upd_nnf[:, :, :, 1] += 1
|
|
||||||
elif d==2:
|
|
||||||
upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
|
|
||||||
upd_nnf[:, :, :, 0] -= 1
|
|
||||||
elif d==3:
|
|
||||||
upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
|
|
||||||
upd_nnf[:, :, :, 1] -= 1
|
|
||||||
upd_nnf = self.clamp_bound(upd_nnf)
|
|
||||||
return upd_nnf
|
|
||||||
|
|
||||||
def shift_nnf(self, nnf, d):
|
|
||||||
if d>0:
|
|
||||||
d = min(nnf.shape[0], d)
|
|
||||||
upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
|
|
||||||
else:
|
|
||||||
d = max(-nnf.shape[0], d)
|
|
||||||
upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
|
|
||||||
return upd_nnf
|
|
||||||
|
|
||||||
def track_step(self, nnf, d):
|
|
||||||
if self.use_pairwise_patch_error:
|
|
||||||
upd_nnf = cp.zeros_like(nnf)
|
|
||||||
upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
|
|
||||||
upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
|
|
||||||
else:
|
|
||||||
        upd_nnf = self.shift_nnf(nnf, d)
        return upd_nnf

    def C(self, n, m):
        # not used
        c = 1
        for i in range(1, n+1):
            c *= i
        for i in range(1, m+1):
            c //= i
        for i in range(1, n-m+1):
            c //= i
        return c

    def bezier_step(self, nnf, r):
        # not used
        n = r * 2 - 1
        upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
        for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
            if d > 0:
                ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
            elif d < 0:
                ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
            upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
        upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
        return upd_nnf

    def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
        upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
        upd_idx = (upd_err < err)
        nnf[upd_idx] = upd_nnf[upd_idx]
        err[upd_idx] = upd_err[upd_idx]
        return nnf, err

    def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for d in cp.random.permutation(4):
            upd_nnf = self.neighboor_step(nnf, d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for i in range(self.random_search_steps):
            upd_nnf = self.random_step(nnf, self.random_search_range)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for d in range(1, self.tracking_window_size + 1):
            upd_nnf = self.track_step(nnf, d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
            upd_nnf = self.track_step(nnf, -d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
        nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
        nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
        nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
        return nnf, err

    def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
        with cp.cuda.Device(self.gpu_id):
            source_guide = self.pad_image(source_guide)
            target_guide = self.pad_image(target_guide)
            source_style = self.pad_image(source_style)
            for it in range(self.num_iter):
                self.patch_size = self.patch_size_list[it]
                target_style = self.apply_nnf_to_image(nnf, source_style)
                err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
                nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
            target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
        return nnf, target_style

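For orientation: update() above is the standard PatchMatch "keep the better candidate" rule, and iteration() chains the classic propagation / random-search phases with the tracking step. A self-contained NumPy sketch of the masked update on toy data (hypothetical shapes, no CuPy required):

import numpy as np

nnf = np.zeros((4, 4, 2), dtype=np.int32)         # current nearest-neighbour field
err = np.full((4, 4), 10.0, dtype=np.float32)     # current patch errors
upd_nnf = np.ones((4, 4, 2), dtype=np.int32)      # candidate field
upd_err = np.full((4, 4), 5.0, dtype=np.float32)  # candidate errors

upd_idx = upd_err < err          # positions where the candidate is better
nnf[upd_idx] = upd_nnf[upd_idx]  # adopt the better candidates
err[upd_idx] = upd_err[upd_idx]
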
class PyramidPatchMatcher:
    def __init__(
        self, image_height, image_width, channel, minimum_patch_size,
        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
        use_mean_target_style=False, use_pairwise_patch_error=False,
        tracking_window_size=0,
        initialize="identity"
    ):
        maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
        self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
        self.pyramid_heights = []
        self.pyramid_widths = []
        self.patch_matchers = []
        self.minimum_patch_size = minimum_patch_size
        self.num_iter = num_iter
        self.gpu_id = gpu_id
        self.initialize = initialize
        for level in range(self.pyramid_level):
            height = image_height // (2 ** (self.pyramid_level - 1 - level))
            width = image_width // (2 ** (self.pyramid_level - 1 - level))
            self.pyramid_heights.append(height)
            self.pyramid_widths.append(width)
            self.patch_matchers.append(PatchMatcher(
                height, width, channel, minimum_patch_size=minimum_patch_size,
                threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
                use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
                tracking_window_size=tracking_window_size
            ))

    def resample_image(self, images, level):
        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
        images = images.get()
        images_resample = []
        for image in images:
            image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
            images_resample.append(image_resample)
        images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
        return images_resample

    def initialize_nnf(self, batch_size):
        if self.initialize == "random":
            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
            nnf = cp.stack([
                cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
                cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
            ], axis=3)
        elif self.initialize == "identity":
            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
            nnf = cp.stack([
                cp.repeat(cp.arange(height), width).reshape(height, width),
                cp.tile(cp.arange(width), height).reshape(height, width)
            ], axis=2)
            nnf = cp.stack([nnf] * batch_size)
        else:
            raise NotImplementedError()
        return nnf

    def update_nnf(self, nnf, level):
        # upscale: each NNF entry maps to a 2x2 block at the next pyramid level
        nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
        # offset odd rows/columns so the upscaled field stays pixel-accurate;
        # the rows are axis 1 and the columns axis 2 (the original indexed
        # nnf.shape[0], the batch axis, for both)
        nnf[:, [i for i in range(nnf.shape[1]) if i & 1], :, 0] += 1
        nnf[:, :, [i for i in range(nnf.shape[2]) if i & 1], 1] += 1
        # check if scale is 2; otherwise fall back to interpolated resizing
        # (nnf.shape[1] and nnf.shape[2] are already doubled at this point)
        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
        if height != nnf.shape[1] or width != nnf.shape[2]:
            nnf = nnf.get().astype(np.float32)
            nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
            nnf = cp.array(np.stack(nnf), dtype=cp.int32)
            nnf = self.patch_matchers[level].clamp_bound(nnf)
        return nnf

    def apply_nnf_to_image(self, nnf, image):
        with cp.cuda.Device(self.gpu_id):
            image = self.patch_matchers[-1].pad_image(image)
            image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
        return image

    def estimate_nnf(self, source_guide, target_guide, source_style):
        with cp.cuda.Device(self.gpu_id):
            if not isinstance(source_guide, cp.ndarray):
                source_guide = cp.array(source_guide, dtype=cp.float32)
            if not isinstance(target_guide, cp.ndarray):
                target_guide = cp.array(target_guide, dtype=cp.float32)
            if not isinstance(source_style, cp.ndarray):
                source_style = cp.array(source_style, dtype=cp.float32)
            for level in range(self.pyramid_level):
                nnf = self.initialize_nnf(source_guide.shape[0]) if level == 0 else self.update_nnf(nnf, level)
                source_guide_ = self.resample_image(source_guide, level)
                target_guide_ = self.resample_image(target_guide, level)
                source_style_ = self.resample_image(source_style, level)
                nnf, target_style = self.patch_matchers[level].estimate_nnf(
                    source_guide_, target_guide_, source_style_, nnf
                )
        return nnf.get(), target_style.get()
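As a usage sketch (mirroring how the runner classes below call it; sizes and arrays are assumptions, not values from this diff), the pyramid matcher takes batched H x W x 3 float arrays and returns the final NNF together with the remapped style frames:

import numpy as np

matcher = PyramidPatchMatcher(image_height=512, image_width=768, channel=3, minimum_patch_size=5)
source_guide = np.random.rand(1, 512, 768, 3).astype(np.float32) * 255  # hypothetical frames
target_guide = np.random.rand(1, 512, 768, 3).astype(np.float32) * 255
source_style = np.random.rand(1, 512, 768, 3).astype(np.float32) * 255
nnf, target_style = matcher.estimate_nnf(source_guide, target_guide, source_style)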
@@ -1,4 +0,0 @@
from .accurate import AccurateModeRunner
from .fast import FastModeRunner
from .balanced import BalancedModeRunner
from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
@@ -1,35 +0,0 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm


class AccurateModeRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            use_mean_target_style=True,
            **ebsynth_config
        )
        # run
        n = len(frames_style)
        for target in tqdm(range(n), desc=desc):
            l, r = max(target - window_size, 0), min(target + window_size + 1, n)
            remapped_frames = []
            for i in range(l, r, batch_size):
                j = min(i + batch_size, r)
                source_guide = np.stack([frames_guide[source] for source in range(i, j)])
                target_guide = np.stack([frames_guide[target]] * (j - i))
                source_style = np.stack([frames_style[source] for source in range(i, j)])
                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
                remapped_frames.append(target_style)
            frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
            frame = frame.clip(0, 255).astype("uint8")
            if save_path is not None:
                Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
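A minimal invocation sketch for the accurate mode (the frame lists and config values are placeholders, not taken from this diff); every frame inside the window is remapped onto the target and the results are averaged:

frames_guide = [...]  # list of HxWx3 numpy arrays (original video frames)
frames_style = [...]  # list of HxWx3 numpy arrays (stylized frames)
ebsynth_config = {"minimum_patch_size": 5, "gpu_id": 0}  # hypothetical values
AccurateModeRunner().run(frames_guide, frames_style, batch_size=8, window_size=30,
                         ebsynth_config=ebsynth_config, save_path="output")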
@@ -1,46 +0,0 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm


class BalancedModeRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # tasks
        n = len(frames_style)
        tasks = []
        for target in range(n):
            for source in range(target - window_size, target + window_size + 1):
                if source >= 0 and source < n and source != target:
                    tasks.append((source, target))
        # run
        frames = [(None, 1) for i in range(n)]
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: min(batch_id + batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
            target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
            source_style = np.stack([frames_style[source] for source, target in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for (source, target), result in zip(tasks_batch, target_style):
                frame, weight = frames[target]
                if frame is None:
                    frame = frames_style[target]
                frames[target] = (
                    frame * (weight / (weight + 1)) + result / (weight + 1),
                    weight + 1
                )
                if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
                    # write out the blended frame, not the stale pre-update local
                    frame = frames[target][0].clip(0, 255).astype("uint8")
                    if save_path is not None:
                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
                    frames[target] = (None, 1)
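The balanced runner folds each new remap into a running mean with weight 1/(k+1), so once the whole window has been processed the accumulator equals the plain average. A quick check with toy numbers:

frame, weight = 10.0, 1
for result in [20.0, 30.0]:
    frame, weight = frame * (weight / (weight + 1)) + result / (weight + 1), weight + 1
assert abs(frame - (10.0 + 20.0 + 30.0) / 3) < 1e-9  # running mean == plain mean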
@@ -1,141 +0,0 @@
from ..patch_match import PyramidPatchMatcher
import functools, os
import numpy as np
from PIL import Image
from tqdm import tqdm


class TableManager:
    def __init__(self):
        pass

    def task_list(self, n):
        tasks = []
        max_level = 1
        while (1 << max_level) <= n:
            max_level += 1
        for i in range(n):
            j = i
            for level in range(max_level):
                if i & (1 << level):
                    continue
                j |= 1 << level
                if j >= n:
                    break
                meta_data = {
                    "source": i,
                    "target": j,
                    "level": level + 1
                }
                tasks.append(meta_data)
        tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"] - v["level"]))
        return tasks

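task_list remaps each source i onto the targets j reached by filling in i's low zero bits one at a time; after the prefix blending in remapping_table_to_blending_table, the entry at (target j, level L) summarizes the block of 2**L frames ending at j, and tree_query below tiles any window out of O(log n) such blocks. A quick way to inspect the raw schedule (illustrative only):

tasks = TableManager().task_list(8)
print([(t["source"], t["target"], t["level"]) for t in tasks if t["source"] == 0])
# -> [(0, 1, 1), (0, 3, 2), (0, 7, 3)]: each target j closes the span [0, j]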
    def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
        n = len(frames_guide)
        tasks = self.task_list(n)
        remapping_table = [[(frames_style[i], 1)] for i in range(n)]
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: min(batch_id + batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
            target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
            source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for task, result in zip(tasks_batch, target_style):
                target, level = task["target"], task["level"]
                if len(remapping_table[target]) == level:
                    remapping_table[target].append((result, 1))
                else:
                    frame, weight = remapping_table[target][level]
                    remapping_table[target][level] = (
                        frame * (weight / (weight + 1)) + result / (weight + 1),
                        weight + 1
                    )
        return remapping_table

    def remapping_table_to_blending_table(self, table):
        for i in range(len(table)):
            for j in range(1, len(table[i])):
                frame_1, weight_1 = table[i][j - 1]
                frame_2, weight_2 = table[i][j]
                frame = (frame_1 + frame_2) / 2
                weight = weight_1 + weight_2
                table[i][j] = (frame, weight)
        return table

    def tree_query(self, leftbound, rightbound):
        node_list = []
        node_index = rightbound
        while node_index >= leftbound:
            node_level = 0
            while (1 << node_level) & node_index and node_index - (1 << node_level + 1) + 1 >= leftbound:
                node_level += 1
            node_list.append((node_index, node_level))
            node_index -= 1 << node_level
        return node_list

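tree_query decomposes [leftbound, rightbound] into the largest power-of-two blocks ending at each step, matching the blocks that the blending table stores. For instance (toy bounds):

print(TableManager().tree_query(0, 6))
# -> [(6, 0), (5, 1), (3, 2)]: blocks [6,6], [4,5], [0,3] exactly tile [0,6]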
    def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
        n = len(blending_table)
        tasks = []
        frames_result = []
        for target in range(n):
            node_list = self.tree_query(max(target - window_size, 0), target)
            for source, level in node_list:
                if source != target:
                    meta_data = {
                        "source": source,
                        "target": target,
                        "level": level
                    }
                    tasks.append(meta_data)
                else:
                    frames_result.append(blending_table[target][level])
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: min(batch_id + batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
            target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
            source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for task, frame_2 in zip(tasks_batch, target_style):
                source, target, level = task["source"], task["target"], task["level"]
                frame_1, weight_1 = frames_result[target]
                weight_2 = blending_table[source][level][1]
                weight = weight_1 + weight_2
                frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
                frames_result[target] = (frame, weight)
        return frames_result


class FastModeRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
        frames_guide = frames_guide.raw_data()
        frames_style = frames_style.raw_data()
        table_manager = TableManager()
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # left part
        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
        table_l = table_manager.remapping_table_to_blending_table(table_l)
        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
        # right part
        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
        table_r = table_manager.remapping_table_to_blending_table(table_r)
        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
        # merge
        frames = []
        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
            weight_m = -1
            weight = weight_l + weight_m + weight_r
            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
            frames.append(frame)
        frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
        if save_path is not None:
            for target, frame in enumerate(frames):
                Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
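Both window sums include the style frame itself with weight 1, so merging the left and right passes would count it twice; weight_m = -1 cancels one copy. A quick arithmetic check under these assumptions:

weight_l, weight_r, weight_m = 4, 4, -1  # hypothetical: 3 neighbours per side plus the frame itself
weight = weight_l + weight_m + weight_r
assert weight == 7                        # 3 left + 3 right + the frame, counted once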
@@ -1,121 +0,0 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm


class InterpolationModeRunner:
    def __init__(self):
        pass

    def get_index_dict(self, index_style):
        index_dict = {}
        for i, index in enumerate(index_style):
            index_dict[index] = i
        return index_dict

    def get_weight(self, l, m, r):
        weight_l, weight_r = abs(m - r), abs(m - l)
        if weight_l + weight_r == 0:
            weight_l, weight_r = 0.5, 0.5
        else:
            weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
        return weight_l, weight_r

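Note the swap in get_weight: the left keyframe is weighted by the distance from m to the right keyframe, so the nearer keyframe dominates. For example (toy indices):

runner = InterpolationModeRunner()
assert runner.get_weight(l=0, m=1, r=4) == (0.75, 0.25)  # m sits near l, so l dominates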
    def get_task_group(self, index_style, n):
        task_group = []
        index_style = sorted(index_style)
        # first frame
        if index_style[0] > 0:
            tasks = []
            for m in range(index_style[0]):
                tasks.append((index_style[0], m, index_style[0]))
            task_group.append(tasks)
        # middle frames
        for l, r in zip(index_style[:-1], index_style[1:]):
            tasks = []
            for m in range(l, r):
                tasks.append((l, m, r))
            task_group.append(tasks)
        # last frame
        tasks = []
        for m in range(index_style[-1], n):
            tasks.append((index_style[-1], m, index_style[-1]))
        task_group.append(tasks)
        return task_group

    def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            use_mean_target_style=False,
            use_pairwise_patch_error=True,
            **ebsynth_config
        )
        # task
        index_dict = self.get_index_dict(index_style)
        task_group = self.get_task_group(index_style, len(frames_guide))
        # run
        for tasks in task_group:
            index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
            for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
                tasks_batch = tasks[batch_id: min(batch_id + batch_size, len(tasks))]
                source_guide, target_guide, source_style = [], [], []
                for l, m, r in tasks_batch:
                    # l -> m
                    source_guide.append(frames_guide[l])
                    target_guide.append(frames_guide[m])
                    source_style.append(frames_style[index_dict[l]])
                    # r -> m
                    source_guide.append(frames_guide[r])
                    target_guide.append(frames_guide[m])
                    source_style.append(frames_style[index_dict[r]])
                source_guide = np.stack(source_guide)
                target_guide = np.stack(target_guide)
                source_style = np.stack(source_style)
                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
                if save_path is not None:
                    for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
                        weight_l, weight_r = self.get_weight(l, m, r)
                        frame = frame_l * weight_l + frame_r * weight_r
                        frame = frame.clip(0, 255).astype("uint8")
                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))


class InterpolationModeSingleFrameRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
        # check input
        tracking_window_size = ebsynth_config["tracking_window_size"]
        if tracking_window_size * 2 >= batch_size:
            raise ValueError("batch_size should be larger than tracking_window_size * 2")
        frame_style = frames_style[0]
        frame_guide = frames_guide[index_style[0]]
        patch_match_engine = PyramidPatchMatcher(
            image_height=frame_style.shape[0],
            image_width=frame_style.shape[1],
            channel=3,
            **ebsynth_config
        )
        # run
        frame_id, n = 0, len(frames_guide)
        for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
            if i + batch_size > n:
                l, r = max(n - batch_size, 0), n
            else:
                l, r = i, i + batch_size
            source_guide = np.stack([frame_guide] * (r - l))
            target_guide = np.stack([frames_guide[i] for i in range(l, r)])
            source_style = np.stack([frame_style] * (r - l))
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for i, frame in zip(range(l, r), target_style):
                if i == frame_id:
                    frame = frame.clip(0, 255).astype("uint8")
                    Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
                    frame_id += 1
                if r < n and r - frame_id <= tracking_window_size:
                    break
@@ -1,242 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image


def warp(tenInput, tenFlow, device):
    backwarp_tenGrid = {}
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
            1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
            1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        backwarp_tenGrid[k] = torch.cat(
            [tenHorizontal, tenVertical], 1).to(device)

    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)


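warp converts pixel-space flow into grid_sample's normalized [-1, 1] coordinates by dividing by (size - 1) / 2; a standalone check of that mapping (toy width):

import torch

w = 9  # hypothetical image width
xs = torch.arange(w, dtype=torch.float32)
normalized = (xs - (w - 1) / 2) / ((w - 1) / 2)
assert normalized[0] == -1.0 and normalized[-1] == 1.0

Note also that backwarp_tenGrid is declared inside warp here, so the grid cache is rebuilt on every call; upstream RIFE keeps it at module level, which is functionally identical but avoids the recomputation.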
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )


class IFBlock(nn.Module):
    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1))
        self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
        self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
        self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))

    def forward(self, x, flow, scale=1):
        x = F.interpolate(x, scale_factor=1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
        flow = F.interpolate(flow, scale_factor=1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
        feat = self.conv0(torch.cat((x, flow), 1))
        feat = self.convblock0(feat) + feat
        feat = self.convblock1(feat) + feat
        feat = self.convblock2(feat) + feat
        feat = self.convblock3(feat) + feat
        flow = self.conv1(feat)
        mask = self.conv2(feat)
        flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
        mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
        return flow, mask


class IFNet(nn.Module):
    def __init__(self):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(7+4, c=90)
        self.block1 = IFBlock(7+4, c=90)
        self.block2 = IFBlock(7+4, c=90)
        self.block_tea = IFBlock(10+4, c=90)

    def forward(self, x, scale_list=[4, 2, 1], training=False):
        if not training:
            channel = x.shape[1] // 2
            img0 = x[:, :channel]
            img1 = x[:, channel:]
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = (x[:, :4]).detach() * 0
        mask = (x[:, :1]).detach() * 0
        block = [self.block0, self.block1, self.block2]
        for i in range(3):
            f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
            f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
            flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
            mask = mask + (m0 + (-m1)) / 2
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], device=x.device)
            warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
            merged.append((warped_img0, warped_img1))
        '''
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, 1:4] * 2 - 1
        '''
        for i in range(3):
            mask_list[i] = torch.sigmoid(mask_list[i])
            merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
        return flow_list, mask_list[2], merged

    @staticmethod
    def state_dict_converter():
        return IFNetStateDictConverter()

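IFNet refines the bidirectional flow coarse-to-fine across its three blocks; f1's flow halves are swapped before fusion because that pass sees the frame pair in reverse order. The interpolater below divides every entry of scale_list by the user's scale, e.g. (values illustrative):

scale = 2.0                                     # hypothetical user setting
scale_list = [4 / scale, 2 / scale, 1 / scale]  # -> [2.0, 1.0, 0.5], a finer pyramid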
class IFNetStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)


class RIFEInterpolater:
    def __init__(self, model, device="cuda"):
        self.model = model
        self.device = device
        # IFNet does not support float16
        self.torch_dtype = torch.float32

    @staticmethod
    def from_model_manager(model_manager):
        return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)

    def process_image(self, image):
        width, height = image.size
        if width % 32 != 0 or height % 32 != 0:
            # round up to a multiple of 32 (IFNet downsamples by up to 32x);
            # without the trailing "* 32" the image would shrink by a factor of 32
            width = (width + 31) // 32 * 32
            height = (height + 31) // 32 * 32
            image = image.resize((width, height))
        image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
        return image

    def process_images(self, images):
        images = [self.process_image(image) for image in images]
        images = torch.stack(images)
        return images

    def decode_images(self, images):
        images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
        images = [Image.fromarray(image) for image in images]
        return images

    def add_interpolated_images(self, images, interpolated_images):
        output_images = []
        for image, interpolated_image in zip(images, interpolated_images):
            output_images.append(image)
            output_images.append(interpolated_image)
        output_images.append(images[-1])
        return output_images

    @torch.no_grad()
    def interpolate_(self, images, scale=1.0):
        input_tensor = self.process_images(images)
        input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
        input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
        flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
        output_images = self.decode_images(merged[2].cpu())
        if output_images[0].size != images[0].size:
            output_images = [image.resize(images[0].size) for image in output_images]
        return output_images

    @torch.no_grad()
    def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x: x):
        # Preprocess
        processed_images = self.process_images(images)

        for iter in range(num_iter):
            # Input
            input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)

            # Interpolate
            output_tensor = []
            for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
                batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
                batch_input_tensor = input_tensor[batch_id: batch_id_]
                batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
                flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
                output_tensor.append(merged[2].cpu())

            # Output
            output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
            processed_images = self.add_interpolated_images(processed_images, output_tensor)
            processed_images = torch.stack(processed_images)

        # To images
        output_images = self.decode_images(processed_images)
        if output_images[0].size != images[0].size:
            output_images = [image.resize(images[0].size) for image in output_images]
        return output_images

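A minimal invocation sketch (model loading elided; the names follow from_model_manager above, and frames is assumed to be a list of PIL images):

interpolater = RIFEInterpolater.from_model_manager(model_manager)
frames_2x = interpolater.interpolate(frames, scale=1.0, batch_size=4, num_iter=1)
# each num_iter pass roughly doubles the frame count by inserting midpoints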
class RIFESmoother(RIFEInterpolater):
    def __init__(self, model, device="cuda"):
        super(RIFESmoother, self).__init__(model, device=device)

    @staticmethod
    def from_model_manager(model_manager):
        return RIFESmoother(model_manager.RIFE, device=model_manager.device)

    def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
        output_tensor = []
        for batch_id in range(0, input_tensor.shape[0], batch_size):
            batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
            batch_input_tensor = input_tensor[batch_id: batch_id_]
            batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
            flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
            output_tensor.append(merged[2].cpu())
        output_tensor = torch.concat(output_tensor, dim=0)
        return output_tensor

    @torch.no_grad()
    def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
        # Preprocess
        processed_images = self.process_images(rendered_frames)

        for iter in range(num_iter):
            # Input
            input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)

            # Interpolate
            output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)

            # Blend
            input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
            output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)

            # Add to frames
            processed_images[1:-1] = output_tensor

        # To images
        output_images = self.decode_images(processed_images)
        if output_images[0].size != rendered_frames[0].size:
            output_images = [image.resize(rendered_frames[0].size) for image in output_images]
        return output_images
@@ -1 +0,0 @@
from .model_manager import *
@@ -1,89 +0,0 @@
import torch
from einops import rearrange


def low_version_attention(query, key, value, attn_bias=None):
    scale = 1 / query.shape[-1] ** 0.5
    query = query * scale
    attn = torch.matmul(query, key.transpose(-2, -1))
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = attn.softmax(-1)
    return attn @ value


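low_version_attention is the manual fallback used when an attention bias is supplied; up to numerical tolerance it should match PyTorch's fused kernel, as this standalone sketch illustrates:

import torch

q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
out_manual = low_version_attention(q, k, v)
out_fused = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out_manual, out_fused, atol=1e-5)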
class Attention(torch.nn.Module):

    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        super().__init__()
        dim_inner = head_dim * num_heads
        kv_dim = kv_dim if kv_dim is not None else q_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)

    def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
        batch_size = q.shape[0]
        ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
        hidden_states = hidden_states + scale * ip_hidden_states
        return hidden_states

    def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        batch_size = encoder_hidden_states.shape[0]

        q = self.to_q(hidden_states)
        k = self.to_k(encoder_hidden_states)
        v = self.to_v(encoder_hidden_states)

        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if qkv_preprocessor is not None:
            q, k, v = qkv_preprocessor(q, k, v)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        if ipadapter_kwargs is not None:
            hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)

        hidden_states = self.to_out(hidden_states)

        return hidden_states

    def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        q = self.to_q(hidden_states)
        k = self.to_k(encoder_hidden_states)
        v = self.to_v(encoder_hidden_states)

        q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
        k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
        v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)

        if attn_mask is not None:
            hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
        else:
            import xformers.ops as xops
            hidden_states = xops.memory_efficient_attention(q, k, v)
        hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)

        hidden_states = hidden_states.to(q.dtype)
        hidden_states = self.to_out(hidden_states)

        return hidden_states

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
        return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
diffsynth/models/dinov3_image_encoder.py (new file, 96 lines)
@@ -0,0 +1,96 @@
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch

from ..core.device.npu_compatible_device import get_device_type


class DINOv3ImageEncoder(DINOv3ViTModel):
    def __init__(self):
        config = DINOv3ViTConfig(
            architectures=["DINOv3ViTModel"],
            attention_dropout=0.0,
            drop_path_rate=0.0,
            dtype="float32",
            hidden_act="silu",
            hidden_size=4096,
            image_size=224,
            initializer_range=0.02,
            intermediate_size=8192,
            key_bias=False,
            layer_norm_eps=1e-05,
            layerscale_value=1.0,
            mlp_bias=True,
            model_type="dinov3_vit",
            num_attention_heads=32,
            num_channels=3,
            num_hidden_layers=40,
            num_register_tokens=4,
            patch_size=16,
            pos_embed_jitter=None,
            pos_embed_rescale=2.0,
            pos_embed_shift=None,
            proj_bias=True,
            query_bias=False,
            rope_theta=100.0,
            transformers_version="4.56.1",
            use_gated_mlp=True,
            value_bias=False
        )
        super().__init__(config)
        self.processor = DINOv3ViTImageProcessorFast(
            crop_size=None,
            data_format="channels_first",
            default_to_square=True,
            device=None,
            disable_grouping=None,
            do_center_crop=None,
            do_convert_rgb=None,
            do_normalize=True,
            do_rescale=True,
            do_resize=True,
            image_mean=[0.485, 0.456, 0.406],
            image_processor_type="DINOv3ViTImageProcessorFast",
            image_std=[0.229, 0.224, 0.225],
            input_data_format=None,
            resample=2,
            rescale_factor=0.00392156862745098,
            return_tensors=None,
            size={"height": 224, "width": 224}
        )

    def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
        bool_masked_pos = None
        head_mask = None

        pixel_values = pixel_values.to(torch_dtype)
        hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        position_embeddings = self.rope_embeddings(pixel_values)

        for i, layer_module in enumerate(self.layer):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            hidden_states = layer_module(
                hidden_states,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
            )

        sequence_output = self.norm(hidden_states)
        pooled_output = sequence_output[:, 0, :]

        return pooled_output
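A usage sketch (file name and device are placeholders): the encoder preprocesses a PIL image internally and returns the pooled CLS embedding, which should have width 4096 given the hidden_size above.

import torch
from PIL import Image

encoder = DINOv3ImageEncoder().to("cuda", dtype=torch.bfloat16)  # hypothetical device
embedding = encoder(Image.open("example.png"), torch_dtype=torch.bfloat16, device="cuda")
print(embedding.shape)  # expected: (1, 4096)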
@@ -1,66 +0,0 @@
from huggingface_hub import hf_hub_download
from modelscope import snapshot_download
import os, shutil
from typing_extensions import Literal, TypeAlias
from typing import List
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id


def download_from_modelscope(model_id, origin_file_path, local_dir):
    os.makedirs(local_dir, exist_ok=True)
    if os.path.basename(origin_file_path) in os.listdir(local_dir):
        print(f"    {os.path.basename(origin_file_path)} already exists in {local_dir}.")
        return
    else:
        print(f"    Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
        snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
        downloaded_file_path = os.path.join(local_dir, origin_file_path)
        target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
        if downloaded_file_path != target_file_path:
            shutil.move(downloaded_file_path, target_file_path)
            shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))


def download_from_huggingface(model_id, origin_file_path, local_dir):
    os.makedirs(local_dir, exist_ok=True)
    if os.path.basename(origin_file_path) in os.listdir(local_dir):
        print(f"    {os.path.basename(origin_file_path)} already exists in {local_dir}.")
        return
    else:
        print(f"    Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
        hf_hub_download(model_id, origin_file_path, local_dir=local_dir)


Preset_model_website: TypeAlias = Literal[
    "HuggingFace",
    "ModelScope",
]
website_to_preset_models = {
    "HuggingFace": preset_models_on_huggingface,
    "ModelScope": preset_models_on_modelscope,
}
website_to_download_fn = {
    "HuggingFace": download_from_huggingface,
    "ModelScope": download_from_modelscope,
}


def download_models(
    model_id_list: List[Preset_model_id] = [],
    downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
    print(f"Downloading models: {model_id_list}")
    downloaded_files = []
    for model_id in model_id_list:
        for website in downloading_priority:
            if model_id in website_to_preset_models[website]:
                for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
                    # Check if the file is downloaded.
                    file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
                    if file_to_download in downloaded_files:
                        continue
                    # Download
                    website_to_download_fn[website](model_id, origin_file_path, local_dir)
                    if os.path.basename(origin_file_path) in os.listdir(local_dir):
                        downloaded_files.append(file_to_download)
    return downloaded_files
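Typical call pattern (the model ID is a placeholder; valid IDs come from the preset tables imported above):

downloaded = download_models(["SomePresetModelID"], downloading_priority=["ModelScope", "HuggingFace"])
print(downloaded)  # local paths of every file fetched or already present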
diffsynth/models/flux2_dit.py (new file, 1050 lines; diff suppressed because it is too large)
diffsynth/models/flux2_text_encoder.py (new file, 58 lines)
@@ -0,0 +1,58 @@
from transformers import Mistral3ForConditionalGeneration, Mistral3Config


class Flux2TextEncoder(Mistral3ForConditionalGeneration):
    def __init__(self):
        config = Mistral3Config(**{
            "architectures": ["Mistral3ForConditionalGeneration"],
            "dtype": "bfloat16",
            "image_token_index": 10,
            "model_type": "mistral3",
            "multimodal_projector_bias": False,
            "projector_hidden_act": "gelu",
            "spatial_merge_size": 2,
            "text_config": {
                "attention_dropout": 0.0,
                "dtype": "bfloat16",
                "head_dim": 128,
                "hidden_act": "silu",
                "hidden_size": 5120,
                "initializer_range": 0.02,
                "intermediate_size": 32768,
                "max_position_embeddings": 131072,
                "model_type": "mistral",
                "num_attention_heads": 32,
                "num_hidden_layers": 40,
                "num_key_value_heads": 8,
                "rms_norm_eps": 1e-05,
                "rope_theta": 1000000000.0,
                "sliding_window": None,
                "use_cache": True,
                "vocab_size": 131072
            },
            "transformers_version": "4.57.1",
            "vision_config": {
                "attention_dropout": 0.0,
                "dtype": "bfloat16",
                "head_dim": 64,
                "hidden_act": "silu",
                "hidden_size": 1024,
                "image_size": 1540,
                "initializer_range": 0.02,
                "intermediate_size": 4096,
                "model_type": "pixtral",
                "num_attention_heads": 16,
                "num_channels": 3,
                "num_hidden_layers": 24,
                "patch_size": 14,
                "rope_theta": 10000.0
            },
            "vision_feature_layer": -1
        })
        super().__init__(config)

    def forward(self, input_ids=None, pixel_values=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, cache_position=None, logits_to_keep=0, image_sizes=None, **kwargs):
        # pass everything by keyword so a parameter-order mismatch with the
        # parent signature cannot silently misbind arguments
        return super().forward(
            input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask,
            position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds,
            labels=labels, use_cache=use_cache, output_attentions=output_attentions,
            output_hidden_states=output_hidden_states, return_dict=return_dict,
            cache_position=cache_position, logits_to_keep=logits_to_keep, image_sizes=image_sizes,
            **kwargs,
        )
diffsynth/models/flux2_vae.py (new file, 2322 lines; diff suppressed because it is too large)
diffsynth/models/flux_controlnet.py (new file, 384 lines)
@@ -0,0 +1,384 @@
import hashlib
import torch
from einops import rearrange, repeat
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
# from .utils import hash_state_dict_keys, init_weights_on_device
from contextlib import contextmanager

def hash_state_dict_keys(state_dict, with_shape=True):
    # NOTE: convert_state_dict_keys_to_single_str is expected from the module
    # that originally provided these helpers (see the commented import above)
    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    keys_str = keys_str.encode(encoding="UTF-8")
    return hashlib.md5(keys_str).hexdigest()

@contextmanager
def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):

    old_register_parameter = torch.nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = torch.nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    def register_empty_buffer(module, name, buffer, persistent=True):
        old_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    if include_buffers:
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}

    try:
        torch.nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch.keys():
            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
        yield
    finally:
        torch.nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = old_register_buffer
        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)

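init_weights_on_device temporarily patches register_parameter (and optionally the tensor constructors) so a module skeleton can be built on a target device, usually meta, without allocating real weights; typical use (a sketch, loading step assumed):

with init_weights_on_device(device=torch.device("meta")):
    controlnet = FluxControlNet()  # no real weight memory allocated here
# real tensors are attached afterwards, e.g. controlnet.load_state_dict(sd, assign=True)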
class FluxControlNet(torch.nn.Module):
    def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
        super().__init__()
        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
        self.time_embedder = TimestepEmbeddings(256, 3072)
        self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
        self.context_embedder = torch.nn.Linear(4096, 3072)
        self.x_embedder = torch.nn.Linear(64, 3072)

        self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
        self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])

        self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
        self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])

        self.mode_dict = mode_dict
        self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
        self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)

    def prepare_image_ids(self, latents):
        batch_size, _, height, width = latents.shape
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
        latent_image_ids = latent_image_ids.reshape(
            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )
        latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)

        return latent_image_ids

    def patchify(self, hidden_states):
        hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
        return hidden_states

    def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
        if len(res_stack) == 0:
            return [torch.zeros_like(hidden_states)] * num_blocks
        interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
        aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
        return aligned_res_stack

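patchify folds each 2x2 spatial patch of the latent into the channel dimension, matching the 64-channel x_embedder input (16 latent channels x 2 x 2). A quick shape check with a toy tensor:

import torch
from einops import rearrange

x = torch.randn(1, 16, 64, 64)  # hypothetical latent
patches = rearrange(x, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
assert patches.shape == (1, 32 * 32, 16 * 2 * 2)  # tokens x 64 channels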
    def forward(
        self,
        hidden_states,
        controlnet_conditioning,
        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
        processor_id=None,
        tiled=False, tile_size=128, tile_stride=64,
        **kwargs
    ):
        if image_ids is None:
            image_ids = self.prepare_image_ids(hidden_states)

        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
        if self.guidance_embedder is not None:
            guidance = guidance * 1000
            conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
        prompt_emb = self.context_embedder(prompt_emb)
        if self.controlnet_mode_embedder is not None:  # Different from FluxDiT
            processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
            processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
            prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
            text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

        hidden_states = self.patchify(hidden_states)
        hidden_states = self.x_embedder(hidden_states)
        controlnet_conditioning = self.patchify(controlnet_conditioning)  # Different from FluxDiT
        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning)  # Different from FluxDiT

        controlnet_res_stack = []
        for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
            hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
            controlnet_res_stack.append(controlnet_block(hidden_states))

        controlnet_single_res_stack = []
        hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
        for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
            hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
            controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))

        controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
        controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])

        return controlnet_res_stack, controlnet_single_res_stack

    # @staticmethod
    # def state_dict_converter():
    #     return FluxControlNetStateDictConverter()

|
    def quantize(self):
        def cast_to(weight, dtype=None, device=None, copy=False):
            if device is None or weight.device == device:
                if not copy:
                    if dtype is None or weight.dtype == dtype:
                        return weight
                return weight.to(dtype=dtype, copy=copy)

            r = torch.empty_like(weight, dtype=dtype, device=device)
            r.copy_(weight)
            return r

        def cast_weight(s, input=None, dtype=None, device=None):
            if input is not None:
                if dtype is None:
                    dtype = input.dtype
                if device is None:
                    device = input.device
            weight = cast_to(s.weight, dtype, device)
            return weight

        def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
            if input is not None:
                if dtype is None:
                    dtype = input.dtype
                if bias_dtype is None:
                    bias_dtype = dtype
                if device is None:
                    device = input.device
            weight = cast_to(s.weight, dtype, device)
            bias = cast_to(s.bias, bias_dtype, device)
            return weight, bias

        class quantized_layer:
            class QLinear(torch.nn.Linear):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def forward(self, input, **kwargs):
                    weight, bias = cast_bias_weight(self, input)
                    return torch.nn.functional.linear(input, weight, bias)

            class QRMSNorm(torch.nn.Module):
                def __init__(self, module):
                    super().__init__()
                    self.module = module

                def forward(self, hidden_states, **kwargs):
                    weight = cast_weight(self.module, hidden_states)
                    input_dtype = hidden_states.dtype
                    variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
                    hidden_states = hidden_states.to(input_dtype) * weight
                    return hidden_states

            class QEmbedding(torch.nn.Embedding):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def forward(self, input, **kwargs):
                    weight = cast_weight(self, input)
                    return torch.nn.functional.embedding(
                        input, weight, self.padding_idx, self.max_norm,
                        self.norm_type, self.scale_grad_by_freq, self.sparse)

        def replace_layer(model):
            for name, module in model.named_children():
                if isinstance(module, quantized_layer.QRMSNorm):
                    continue
                if isinstance(module, torch.nn.Linear):
                    with init_weights_on_device():
                        new_layer = quantized_layer.QLinear(module.in_features, module.out_features)
                    new_layer.weight = module.weight
                    if module.bias is not None:
                        new_layer.bias = module.bias
                    setattr(model, name, new_layer)
                elif isinstance(module, RMSNorm):
                    if hasattr(module, "quantized"):
                        continue
                    module.quantized = True
                    new_layer = quantized_layer.QRMSNorm(module)
                    setattr(model, name, new_layer)
                elif isinstance(module, torch.nn.Embedding):
                    rows, cols = module.weight.shape
                    new_layer = quantized_layer.QEmbedding(
                        num_embeddings=rows,
                        embedding_dim=cols,
                        _weight=module.weight,
                        # _freeze=module.freeze,
                        padding_idx=module.padding_idx,
                        max_norm=module.max_norm,
                        norm_type=module.norm_type,
                        scale_grad_by_freq=module.scale_grad_by_freq,
                        sparse=module.sparse)
                    setattr(model, name, new_layer)
                else:
                    replace_layer(module)

        replace_layer(self)
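# A minimal sketch of the on-the-fly casting that quantize() enables
# (illustration, not part of the diff; float8 dtypes need a recent PyTorch).
# After storing weights in a compact dtype, each Q* layer casts its weight to
# the activation's dtype/device on every call:
#
#   model.to(torch.float8_e4m3fn)   # weights stored in fp8
#   model.quantize()                # swap in QLinear / QRMSNorm / QEmbedding
#   # a bfloat16 input now makes each QLinear compute in bfloat16:
#   #   weight, bias = cast_bias_weight(layer, input)   # fp8 -> bf16 copy
#   #   torch.nn.functional.linear(input, weight, bias)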
class FluxControlNetStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        hash_value = hash_state_dict_keys(state_dict)
        global_rename_dict = {
            "context_embedder": "context_embedder",
            "x_embedder": "x_embedder",
            "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
            "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
            "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
            "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
            "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
            "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
            "norm_out.linear": "final_norm_out.linear",
            "proj_out": "final_proj_out",
        }
        rename_dict = {
            "proj_out": "proj_out",
            "norm1.linear": "norm1_a.linear",
            "norm1_context.linear": "norm1_b.linear",
            "attn.to_q": "attn.a_to_q",
            "attn.to_k": "attn.a_to_k",
            "attn.to_v": "attn.a_to_v",
            "attn.to_out.0": "attn.a_to_out",
            "attn.add_q_proj": "attn.b_to_q",
            "attn.add_k_proj": "attn.b_to_k",
            "attn.add_v_proj": "attn.b_to_v",
            "attn.to_add_out": "attn.b_to_out",
            "ff.net.0.proj": "ff_a.0",
            "ff.net.2": "ff_a.2",
            "ff_context.net.0.proj": "ff_b.0",
            "ff_context.net.2": "ff_b.2",
            "attn.norm_q": "attn.norm_q_a",
            "attn.norm_k": "attn.norm_k_a",
            "attn.norm_added_q": "attn.norm_q_b",
            "attn.norm_added_k": "attn.norm_k_b",
        }
        rename_dict_single = {
            "attn.to_q": "a_to_q",
            "attn.to_k": "a_to_k",
            "attn.to_v": "a_to_v",
            "attn.norm_q": "norm_q_a",
            "attn.norm_k": "norm_k_a",
            "norm.linear": "norm.linear",
            "proj_mlp": "proj_in_besides_attn",
            "proj_out": "proj_out",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name.endswith(".weight") or name.endswith(".bias"):
                suffix = ".weight" if name.endswith(".weight") else ".bias"
                prefix = name[:-len(suffix)]
                if prefix in global_rename_dict:
                    state_dict_[global_rename_dict[prefix] + suffix] = param
                elif prefix.startswith("transformer_blocks."):
                    names = prefix.split(".")
                    names[0] = "blocks"
                    middle = ".".join(names[2:])
                    if middle in rename_dict:
                        name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
                        state_dict_[name_] = param
                elif prefix.startswith("single_transformer_blocks."):
                    names = prefix.split(".")
                    names[0] = "single_blocks"
                    middle = ".".join(names[2:])
                    if middle in rename_dict_single:
                        name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
                        state_dict_[name_] = param
                else:
                    state_dict_[name] = param
            else:
                state_dict_[name] = param
        for name in list(state_dict_.keys()):
            if ".proj_in_besides_attn." in name:
                name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
                param = torch.concat([
                    state_dict_[name.replace(".proj_in_besides_attn.", ".a_to_q.")],
                    state_dict_[name.replace(".proj_in_besides_attn.", ".a_to_k.")],
                    state_dict_[name.replace(".proj_in_besides_attn.", ".a_to_v.")],
                    state_dict_[name],
                ], dim=0)
                state_dict_[name_] = param
                state_dict_.pop(name.replace(".proj_in_besides_attn.", ".a_to_q."))
                state_dict_.pop(name.replace(".proj_in_besides_attn.", ".a_to_k."))
                state_dict_.pop(name.replace(".proj_in_besides_attn.", ".a_to_v."))
                state_dict_.pop(name)
        for name in list(state_dict_.keys()):
            for component in ["a", "b"]:
                if f".{component}_to_q." in name:
                    name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
                    param = torch.concat([
                        state_dict_[name],
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
                    ], dim=0)
                    state_dict_[name_] = param
                    state_dict_.pop(name)
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
        if hash_value == "78d18b9101345ff695f312e7e62538c0":
            extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
        elif hash_value == "b001c89139b5f053c715fe772362dd2a":
            extra_kwargs = {"num_single_blocks": 0}
        elif hash_value == "52357cb26250681367488a8954c271e8":
            extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
        elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
            extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
        elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
            extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
        elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52":
            extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0}
        else:
            extra_kwargs = {}
        return state_dict_, extra_kwargs

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
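# A minimal usage sketch (illustration; the loading helper and model
# constructor below are hypothetical -- only the converter is defined here).
# from_diffusers() renames diffusers-style keys, fuses separate q/k/v
# projections into the *_to_qkv layout, and infers architecture kwargs from
# the state-dict hash:
#
#   state_dict_, extra_kwargs = FluxControlNetStateDictConverter().from_diffusers(state_dict)
#   model = FluxControlNet(**extra_kwargs)   # hypothetical constructor
#   model.load_state_dict(state_dict_)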
|
||||||
395
diffsynth/models/flux_dit.py
Normal file
395
diffsynth/models/flux_dit.py
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
import torch
from .general_modules import TimestepEmbeddings, AdaLayerNorm, RMSNorm
from einops import rearrange


def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
    batch_size, num_tokens = hidden_states.shape[0:2]
    ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
    ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1)
    hidden_states = hidden_states + scale * ip_hidden_states
    return hidden_states


class RoPEEmbedding(torch.nn.Module):
    def __init__(self, dim, theta, axes_dim):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
        assert dim % 2 == 0, "The dimension must be even."

        scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
        omega = 1.0 / (theta**scale)

        batch_size, seq_length = pos.shape
        out = torch.einsum("...n,d->...nd", pos, omega)
        cos_out = torch.cos(out)
        sin_out = torch.sin(out)

        stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
        out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
        return out.float()

    def forward(self, ids):
        n_axes = ids.shape[-1]
        emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
        return emb.unsqueeze(1)
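# A quick shape check (illustration, not part of the diff). With FluxDiT's
# configuration axes_dim=[16, 56, 56] (summing to head_dim=128), position ids
# of length L yield a rotary table broadcastable over all heads:
#
#   rope = RoPEEmbedding(dim=3072, theta=10000, axes_dim=[16, 56, 56])
#   ids = torch.zeros(1, 512, 3)    # (batch, L, n_axes)
#   rope(ids).shape                 # torch.Size([1, 1, 512, 64, 2, 2])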
class FluxJointAttention(torch.nn.Module):
    def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.only_out_a = only_out_a

        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
        self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)

        self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
        self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
        self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
        self.norm_k_b = RMSNorm(head_dim, eps=1e-6)

        self.a_to_out = torch.nn.Linear(dim_a, dim_a)
        if not only_out_a:
            self.b_to_out = torch.nn.Linear(dim_b, dim_b)

    def apply_rope(self, xq, xk, freqs_cis):
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

    def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        batch_size = hidden_states_a.shape[0]

        # Part A
        qkv_a = self.a_to_qkv(hidden_states_a)
        qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
        q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)

        # Part B
        qkv_b = self.b_to_qkv(hidden_states_b)
        qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
        q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)

        q = torch.concat([q_b, q_a], dim=2)
        k = torch.concat([k_b, k_a], dim=2)
        v = torch.concat([v_b, v_a], dim=2)

        q, k = self.apply_rope(q, k, image_rotary_emb)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
        if ipadapter_kwargs_list is not None:
            hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
        hidden_states_a = self.a_to_out(hidden_states_a)
        if self.only_out_a:
            return hidden_states_a
        else:
            hidden_states_b = self.b_to_out(hidden_states_b)
            return hidden_states_a, hidden_states_b
class FluxJointTransformerBlock(torch.nn.Module):
    def __init__(self, dim, num_attention_heads):
        super().__init__()
        self.norm1_a = AdaLayerNorm(dim)
        self.norm1_b = AdaLayerNorm(dim)

        self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_a = torch.nn.Sequential(
            torch.nn.Linear(dim, dim * 4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim * 4, dim)
        )

        self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_b = torch.nn.Sequential(
            torch.nn.Linear(dim, dim * 4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim * 4, dim)
        )

    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
        norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)

        # Attention
        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)

        # Part A
        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
        norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)

        # Part B
        hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
        norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
        hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)

        return hidden_states_a, hidden_states_b
class FluxSingleAttention(torch.nn.Module):
    def __init__(self, dim_a, dim_b, num_heads, head_dim):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)

        self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
        self.norm_k_a = RMSNorm(head_dim, eps=1e-6)

    def apply_rope(self, xq, xk, freqs_cis):
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

    def forward(self, hidden_states, image_rotary_emb):
        batch_size = hidden_states.shape[0]

        qkv_a = self.a_to_qkv(hidden_states)
        qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q_a, k_a, v = qkv_a.chunk(3, dim=1)
        q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)

        q, k = self.apply_rope(q_a, k_a, image_rotary_emb)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        return hidden_states
class AdaLayerNormSingle(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.silu = torch.nn.SiLU()
        self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa
class FluxSingleTransformerBlock(torch.nn.Module):
    def __init__(self, dim, num_attention_heads):
        super().__init__()
        self.num_heads = num_attention_heads
        self.head_dim = dim // num_attention_heads
        self.dim = dim

        self.norm = AdaLayerNormSingle(dim)
        self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4))
        self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
        self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)

        self.proj_out = torch.nn.Linear(dim * 5, dim)

    def apply_rope(self, xq, xk, freqs_cis):
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

    def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        batch_size = hidden_states.shape[0]

        qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = qkv.chunk(3, dim=1)
        q, k = self.norm_q_a(q), self.norm_k_a(k)

        q, k = self.apply_rope(q, k, image_rotary_emb)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        if ipadapter_kwargs_list is not None:
            hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
        return hidden_states

    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        residual = hidden_states_a
        norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
        hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
        attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]

        attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
        mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")

        hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
        hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
        hidden_states_a = residual + hidden_states_a

        return hidden_states_a, hidden_states_b
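# A small arithmetic note on the fused projection above (illustration):
# to_qkv_mlp maps dim -> dim*(3+4), i.e. a single matmul yields 3*dim channels
# of packed q/k/v plus 4*dim of MLP hidden state; proj_out then maps the
# concatenation of the attention output (dim) and the GELU'd MLP hidden
# (4*dim) back down, hence its input width of dim*5:
#
#   block = FluxSingleTransformerBlock(3072, num_attention_heads=24)
#   x, temb = torch.randn(1, 16, 3072), torch.randn(1, 3072)
#   rot = torch.zeros(1, 1, 16, 64, 2, 2)   # rotary table from RoPEEmbedding
#   out, _ = block(x, None, temb, rot)      # out.shape == (1, 16, 3072)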
class AdaLayerNormContinuous(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.silu = torch.nn.SiLU()
        self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
        self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)

    def forward(self, x, conditioning):
        emb = self.linear(self.silu(conditioning))
        shift, scale = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
        return x
class FluxDiT(torch.nn.Module):
    def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
        super().__init__()
        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
        self.time_embedder = TimestepEmbeddings(256, 3072)
        self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
        self.context_embedder = torch.nn.Linear(4096, 3072)
        self.x_embedder = torch.nn.Linear(input_dim, 3072)

        self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
        self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])

        self.final_norm_out = AdaLayerNormContinuous(3072)
        self.final_proj_out = torch.nn.Linear(3072, 64)

        self.input_dim = input_dim

    def patchify(self, hidden_states):
        hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
        return hidden_states

    def unpatchify(self, hidden_states, height, width):
        hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
        return hidden_states
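    # A quick round-trip check (illustration): patchify folds each 2x2 latent
    # patch into the channels, so a (B, 16, H, W) latent becomes (B, H*W/4, 64)
    # tokens, and unpatchify(tokens, H, W) inverts it exactly:
    #
    #   latents = torch.randn(1, 16, 64, 64)
    #   tokens = dit.patchify(latents)             # (1, 1024, 64)
    #   assert torch.equal(dit.unpatchify(tokens, 64, 64), latents)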
    def prepare_image_ids(self, latents):
        batch_size, _, height, width = latents.shape
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
        latent_image_ids = latent_image_ids.reshape(
            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )
        latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)

        return latent_image_ids
    def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
        N = len(entity_masks)
        batch_size = entity_masks[0].shape[0]
        total_seq_len = N * prompt_seq_len + image_seq_len
        patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
        attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)

        image_start = N * prompt_seq_len
        image_end = N * prompt_seq_len + image_seq_len
        # prompt-image mask
        for i in range(N):
            prompt_start = i * prompt_seq_len
            prompt_end = (i + 1) * prompt_seq_len
            image_mask = torch.sum(patched_masks[i], dim=-1) > 0
            image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
            # prompt update with image
            attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
            # image update with prompt
            attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
        # prompt-prompt mask
        for i in range(N):
            for j in range(N):
                if i != j:
                    prompt_start_i = i * prompt_seq_len
                    prompt_end_i = (i + 1) * prompt_seq_len
                    prompt_start_j = j * prompt_seq_len
                    prompt_end_j = (j + 1) * prompt_seq_len
                    attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False

        attention_mask = attention_mask.float()
        attention_mask[attention_mask == 0] = float('-inf')
        attention_mask[attention_mask == 1] = 0
        return attention_mask
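    # A note on the mask convention above (illustration): the returned tensor
    # is an additive bias for scaled_dot_product_attention -- 0.0 where
    # attention is allowed and -inf where it is blocked, so each entity prompt
    # attends only to its own masked image region and never to other prompts:
    #
    #   allowed = torch.tensor([[True, False], [True, True]]).float()
    #   allowed[allowed == 0] = float("-inf"); allowed[allowed == 1] = 0
    #   # tensor([[0., -inf],
    #   #         [0.,   0.]])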
    def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
        max_masks = 0
        attention_mask = None
        prompt_embs = [prompt_emb]
        if entity_masks is not None:
            # entity_masks
            batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
            entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
            entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
            # global mask
            global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
            entity_masks = entity_masks + [global_mask]  # append global to last
            # attention mask
            attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
            attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
            attention_mask = attention_mask.unsqueeze(1)
            # embds: n_masks * b * seq * d
            local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
            prompt_embs = local_embs + prompt_embs  # append global to last
        prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
        prompt_emb = torch.cat(prompt_embs, dim=1)

        # positional embedding
        text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
        return prompt_emb, image_rotary_emb, attention_mask
    def forward(
        self,
        hidden_states,
        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
        tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
        use_gradient_checkpointing=False,
        **kwargs
    ):
        # (Deprecated) The real forward is in `pipelines.flux_image`.
        return None
diffsynth/models/flux_infiniteyou.py (new file, 129 lines)
@@ -0,0 +1,129 @@
import math
import torch
import torch.nn as nn


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # reshape keeps the head dimension: (bs, n_heads, length, dim_per_head)
    x = x.reshape(bs, heads, length, -1)
    return x
class PerceiverAttention(nn.Module):

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
class InfiniteYouImageProjector(nn.Module):

    def __init__(
        self,
        dim=1280,
        depth=4,
        dim_head=64,
        heads=20,
        num_queries=8,
        embedding_dim=512,
        output_dim=4096,
        ff_mult=4,
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]))

    def forward(self, x):
        latents = self.latents.repeat(x.size(0), 1, 1)
        latents = latents.to(dtype=x.dtype, device=x.device)

        x = self.proj_in(x)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)

    @staticmethod
    def state_dict_converter():
        return FluxInfiniteYouImageProjectorStateDictConverter()
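# A minimal usage sketch (illustration): the projector resamples a variable
# number of 512-d identity embeddings into num_queries tokens of width 4096,
# the same width FluxDiT's context_embedder consumes:
#
#   projector = InfiniteYouImageProjector()   # defaults: 8 queries, 512 -> 4096
#   id_embeds = torch.randn(1, 1, 512)        # (batch, n_embeddings, embedding_dim)
#   projector(id_embeds).shape                # torch.Size([1, 8, 4096])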
class FluxInfiniteYouImageProjectorStateDictConverter:

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict['image_proj']
diffsynth/models/flux_ipadapter.py (new file, 110 lines)
@@ -0,0 +1,110 @@
from .general_modules import RMSNorm
from transformers import SiglipVisionModel, SiglipVisionConfig
import torch


class SiglipVisionModelSO400M(SiglipVisionModel):
    def __init__(self):
        config = SiglipVisionConfig(
            hidden_size=1152,
            image_size=384,
            intermediate_size=4304,
            model_type="siglip_vision_model",
            num_attention_heads=16,
            num_hidden_layers=27,
            patch_size=14,
            architectures=["SiglipModel"],
            initializer_factor=1.0,
            torch_dtype="float32",
            transformers_version="4.37.0.dev0"
        )
        super().__init__(config)
class MLPProjModel(torch.nn.Module):
    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        x = self.proj(id_embeds)
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        x = self.norm(x)
        return x
class IpAdapterModule(torch.nn.Module):
    def __init__(self, num_attention_heads, attention_head_dim, input_dim):
        super().__init__()
        self.num_heads = num_attention_heads
        self.head_dim = attention_head_dim
        output_dim = num_attention_heads * attention_head_dim
        self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False)

    def forward(self, hidden_states):
        batch_size = hidden_states.shape[0]
        # ip_k
        ip_k = self.to_k_ip(hidden_states)
        ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        ip_k = self.norm_added_k(ip_k)
        # ip_v
        ip_v = self.to_v_ip(hidden_states)
        ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        return ip_k, ip_v
class FluxIpAdapter(torch.nn.Module):
    def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):
        super().__init__()
        self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)])
        self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens)
        self.set_adapter()

    def set_adapter(self):
        self.call_block_id = {i: i for i in range(len(self.ipadapter_modules))}

    def forward(self, hidden_states, scale=1.0):
        hidden_states = self.image_proj(hidden_states)
        hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
        ip_kv_dict = {}
        for block_id in self.call_block_id:
            ipadapter_id = self.call_block_id[block_id]
            ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
            ip_kv_dict[block_id] = {
                "ip_k": ip_k,
                "ip_v": ip_v,
                "scale": scale
            }
        return ip_kv_dict

    @staticmethod
    def state_dict_converter():
        return FluxIpAdapterStateDictConverter()
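# A minimal usage sketch (illustration; tiny config for brevity -- the shipped
# model uses 24 heads x 128 dims, 128 tokens, 57 blocks). A pooled 1152-d
# SigLIP feature is expanded into tokens, and every per-block module emits its
# own ip_k / ip_v pair for interact_with_ipadapter in the DiT blocks:
#
#   ip_adapter = FluxIpAdapter(num_attention_heads=2, attention_head_dim=8,
#                              cross_attention_dim=16, num_tokens=4, num_blocks=3)
#   ip_kv_dict = ip_adapter(torch.randn(1, 1152), scale=0.7)
#   sorted(ip_kv_dict.keys())        # [0, 1, 2]
#   ip_kv_dict[0]["ip_k"].shape      # torch.Size([1, 2, 4, 8])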
class FluxIpAdapterStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {}
        for name in state_dict["ip_adapter"]:
            name_ = 'ipadapter_modules.' + name
            state_dict_[name_] = state_dict["ip_adapter"][name]
        for name in state_dict["image_proj"]:
            name_ = "image_proj." + name
            state_dict_[name_] = state_dict["image_proj"][name]
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
diffsynth/models/flux_lora_encoder.py (new file, 521 lines)
@@ -0,0 +1,521 @@
import torch
from einops import rearrange


def low_version_attention(query, key, value, attn_bias=None):
    scale = 1 / query.shape[-1] ** 0.5
    query = query * scale
    attn = torch.matmul(query, key.transpose(-2, -1))
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = attn.softmax(-1)
    return attn @ value
class Attention(torch.nn.Module):

    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        super().__init__()
        dim_inner = head_dim * num_heads
        kv_dim = kv_dim if kv_dim is not None else q_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)

    def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
        batch_size = q.shape[0]
        ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
        hidden_states = hidden_states + scale * ip_hidden_states
        return hidden_states

    def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        batch_size = encoder_hidden_states.shape[0]

        q = self.to_q(hidden_states)
        k = self.to_k(encoder_hidden_states)
        v = self.to_v(encoder_hidden_states)

        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if qkv_preprocessor is not None:
            q, k, v = qkv_preprocessor(q, k, v)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        if ipadapter_kwargs is not None:
            hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)

        hidden_states = self.to_out(hidden_states)

        return hidden_states

    def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        q = self.to_q(hidden_states)
        k = self.to_k(encoder_hidden_states)
        v = self.to_v(encoder_hidden_states)

        q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
        k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
        v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)

        if attn_mask is not None:
            hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
        else:
            import xformers.ops as xops
            hidden_states = xops.memory_efficient_attention(q, k, v)
        hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)

        hidden_states = hidden_states.to(q.dtype)
        hidden_states = self.to_out(hidden_states)

        return hidden_states

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
        return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
class CLIPEncoderLayer(torch.nn.Module):
    def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
        super().__init__()
        self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
        self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
        self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
        self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
        self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)

        self.use_quick_gelu = use_quick_gelu

    def quickGELU(self, x):
        return x * torch.sigmoid(1.702 * x)

    def forward(self, hidden_states, attn_mask=None):
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.fc1(hidden_states)
        if self.use_quick_gelu:
            hidden_states = self.quickGELU(hidden_states)
        else:
            hidden_states = torch.nn.functional.gelu(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
class SDTextEncoder(torch.nn.Module):
    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
        super().__init__()

        # token_embedding
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # position_embeds (This is a fixed tensor)
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # encoders
        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])

        # attn_mask
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # final_layer_norm
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)

    def attention_mask(self, length):
        mask = torch.empty(length, length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, input_ids, clip_skip=1):
        embeds = self.token_embedding(input_ids) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
            if encoder_id + clip_skip == len(self.encoders):
                break
        embeds = self.final_layer_norm(embeds)
        return embeds

    @staticmethod
    def state_dict_converter():
        return SDTextEncoderStateDictConverter()
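# A small illustration of the causal mask built by attention_mask() above:
# fill_(-inf) then triu_(1) leaves 0 on and below the diagonal and -inf
# strictly above it, so token i attends only to tokens j <= i; clip_skip in
# forward() then stops encoding clip_skip layers before the end:
#
#   mask = torch.empty(4, 4); mask.fill_(float("-inf")); mask.triu_(1)
#   # tensor([[0., -inf, -inf, -inf],
#   #         [0., 0., -inf, -inf],
#   #         [0., 0., 0., -inf],
#   #         [0., 0., 0., 0.]])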
class SDTextEncoderStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "text_model.embeddings.position_embedding.weight": "position_embeds",
            "text_model.final_layer_norm.weight": "final_layer_norm.weight",
            "text_model.final_layer_norm.bias": "final_layer_norm.bias"
        }
        attn_rename_dict = {
            "self_attn.q_proj": "attn.to_q",
            "self_attn.k_proj": "attn.to_k",
            "self_attn.v_proj": "attn.to_v",
            "self_attn.out_proj": "attn.to_out",
            "layer_norm1": "layer_norm1",
            "layer_norm2": "layer_norm2",
            "mlp.fc1": "fc1",
            "mlp.fc2": "fc2",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == "text_model.embeddings.position_embedding.weight":
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
            elif name.startswith("text_model.encoder.layers."):
                param = state_dict[name]
                names = name.split(".")
                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
                state_dict_[name_] = param
        return state_dict_
    def from_civitai(self, state_dict):
        rename_dict = {
            "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name in rename_dict:
|
||||||
|
param = state_dict[name]
|
||||||
|
if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
|
||||||
|
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||||
|
state_dict_[rename_dict[name]] = param
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LoRALayerBlock(torch.nn.Module):
    def __init__(self, L, dim_in, dim_out):
        super().__init__()
        self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))
        self.layer_norm = torch.nn.LayerNorm(dim_out)

    def forward(self, lora_A, lora_B):
        x = self.x @ lora_A.T @ lora_B.T
        x = self.layer_norm(x)
        return x


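# Illustrative shape check (not part of the original file): LoRALayerBlock
# projects a learned probe x of shape (1, L, dim_in) through a LoRA pair,
# where lora_A ("down") is (rank, dim_in) and lora_B ("up") is (dim_out, rank),
# yielding one normalized summary row per probe token.
def _demo_lora_layer_block():
    block = LoRALayerBlock(L=4, dim_in=32, dim_out=16)
    lora_A = torch.randn(8, 32)  # rank 8, input dim 32
    lora_B = torch.randn(16, 8)  # output dim 16, rank 8
    out = block(lora_A, lora_B)
    assert out.shape == (1, 4, 16)

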
class LoRAEmbedder(torch.nn.Module):
    def __init__(self, lora_patterns=None, L=1, out_dim=2048):
        super().__init__()
        if lora_patterns is None:
            lora_patterns = self.default_lora_patterns()

        model_dict = {}
        for lora_pattern in lora_patterns:
            name, dim = lora_pattern["name"], lora_pattern["dim"]
            model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1])
        self.model_dict = torch.nn.ModuleDict(model_dict)

        proj_dict = {}
        for lora_pattern in lora_patterns:
            layer_type, dim = lora_pattern["type"], lora_pattern["dim"]
            # ModuleDict keys are stored with "." escaped, so check the escaped key.
            if layer_type.replace(".", "___") not in proj_dict:
                proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim[1], out_dim)
        self.proj_dict = torch.nn.ModuleDict(proj_dict)

        self.lora_patterns = lora_patterns

    def default_lora_patterns(self):
        lora_patterns = []
        lora_dict = {
            "attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432),
            "attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432),
        }
        for i in range(19):
            for suffix in lora_dict:
                lora_patterns.append({
                    "name": f"blocks.{i}.{suffix}",
                    "dim": lora_dict[suffix],
                    "type": suffix,
                })
        lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)}
        for i in range(38):
            for suffix in lora_dict:
                lora_patterns.append({
                    "name": f"single_blocks.{i}.{suffix}",
                    "dim": lora_dict[suffix],
                    "type": suffix,
                })
        return lora_patterns

    def forward(self, lora):
        lora_emb = []
        for lora_pattern in self.lora_patterns:
            name, layer_type = lora_pattern["name"], lora_pattern["type"]
            lora_A = lora[name + ".lora_A.weight"]
            lora_B = lora[name + ".lora_B.weight"]
            lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
            lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
            lora_emb.append(lora_out)
        lora_emb = torch.concat(lora_emb, dim=1)
        return lora_emb


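# Illustrative sketch (not part of the original file): the default patterns
# cover 19 double-stream blocks x 10 LoRA targets plus 38 single-stream blocks
# x 3 LoRA targets = 304 layers, so forward() returns (1, 304 * L, out_dim)
# after concatenating every per-layer embedding along dim=1.
def _demo_lora_embedder_patterns():
    embedder = LoRAEmbedder(L=1, out_dim=64)
    assert len(embedder.lora_patterns) == 19 * 10 + 38 * 3  # 304 patterns

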
class FluxLoRAEncoder(torch.nn.Module):
    def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1):
        super().__init__()
        self.num_embeds_per_lora = num_embeds_per_lora
        # embedder
        self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim)

        # encoders
        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)])

        # special embedding
        self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim))
        self.num_special_embeds = num_special_embeds

        # final layer
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
        self.final_linear = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, lora):
        lora_embeds = self.embedder(lora)
        special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device)
        embeds = torch.concat([special_embeds, lora_embeds], dim=1)
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds)
        embeds = embeds[:, :self.num_special_embeds]
        embeds = self.final_layer_norm(embeds)
        embeds = self.final_linear(embeds)
        return embeds

    @staticmethod
    def state_dict_converter():
        return FluxLoRAEncoderStateDictConverter()


class FluxLoRAEncoderStateDictConverter:
    def from_civitai(self, state_dict):
        return state_dict


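# Usage sketch (illustrative, not part of the original file; small sizes --
# note the attention inner dim stays at num_heads=32 x head_dim=128 regardless
# of embed_dim): the encoder consumes a LoRA state dict keyed like
# "blocks.0.attn.a_to_qkv.lora_A.weight" and returns one embedding per
# special token.
def _demo_flux_lora_encoder():
    encoder = FluxLoRAEncoder(embed_dim=128, encoder_intermediate_size=256,
                              num_encoder_layers=1, num_embeds_per_lora=1)
    rank = 4
    lora = {}
    for p in encoder.embedder.lora_patterns:
        lora[p["name"] + ".lora_A.weight"] = torch.randn(rank, p["dim"][0])
        lora[p["name"] + ".lora_B.weight"] = torch.randn(p["dim"][1], rank)
    embeds = encoder(lora)
    assert embeds.shape == (1, 1, 128)  # (batch, num_special_embeds, embed_dim)

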
306  diffsynth/models/flux_lora_patcher.py  Normal file
@@ -0,0 +1,306 @@
import torch, math
from ..core.loader import load_state_dict
from typing import Union


class GeneralLoRALoader:
    def __init__(self, device="cpu", torch_dtype=torch.float32):
        self.device = device
        self.torch_dtype = torch_dtype

    def get_name_dict(self, lora_state_dict):
        # Map each target module name to its (lora_B, lora_A) key pair.
        lora_name_dict = {}
        for key in lora_state_dict:
            if ".lora_B." not in key:
                continue
            keys = key.split(".")
            if len(keys) > keys.index("lora_B") + 2:
                keys.pop(keys.index("lora_B") + 1)  # drop the adapter name, e.g. "default"
            keys.pop(keys.index("lora_B"))
            if keys[0] == "diffusion_model":
                keys.pop(0)
            keys.pop(-1)  # drop the trailing "weight"
            target_name = ".".join(keys)
            lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
        return lora_name_dict

    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
        updated_num = 0
        lora_name_dict = self.get_name_dict(state_dict_lora)
        for name, module in model.named_modules():
            if name in lora_name_dict:
                weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
                weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
                if len(weight_up.shape) == 4:
                    # Conv LoRA: squeeze the 1x1 kernel dims before the matmul.
                    weight_up = weight_up.squeeze(3).squeeze(2)
                    weight_down = weight_down.squeeze(3).squeeze(2)
                    weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
                else:
                    weight_lora = alpha * torch.mm(weight_up, weight_down)
                state_dict = module.state_dict()
                state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
                module.load_state_dict(state_dict)
                updated_num += 1
        print(f"{updated_num} tensors are updated by LoRA.")


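# Worked sketch of the merge rule above (illustrative, hypothetical key names):
# for a Linear layer, load() adds alpha * lora_B @ lora_A to the frozen weight,
# where lora_A (the down-projection) is (rank, in_features) and lora_B (the
# up-projection) is (out_features, rank).
def _demo_general_lora_load():
    layer = torch.nn.Linear(8, 8, bias=False)
    model = torch.nn.ModuleDict({"linear": layer})
    lora = {
        "diffusion_model.linear.lora_A.default.weight": torch.randn(2, 8),
        "diffusion_model.linear.lora_B.default.weight": torch.randn(8, 2),
    }
    w0 = layer.weight.detach().clone()
    GeneralLoRALoader().load(model, lora, alpha=1.0)
    expected = w0 + lora["diffusion_model.linear.lora_B.default.weight"] @ lora["diffusion_model.linear.lora_A.default.weight"]
    assert torch.allclose(model["linear"].weight, expected)

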
class FluxLoRALoader(GeneralLoRALoader):
    def __init__(self, device="cpu", torch_dtype=torch.float32):
        super().__init__(device=device, torch_dtype=torch_dtype)

        self.diffusers_rename_dict = {
            "transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight": "single_blocks.blockid.a_to_k.lora_A.default.weight",
            "transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight": "single_blocks.blockid.a_to_k.lora_B.default.weight",
            "transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight": "single_blocks.blockid.a_to_q.lora_A.default.weight",
            "transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight": "single_blocks.blockid.a_to_q.lora_B.default.weight",
            "transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight": "single_blocks.blockid.a_to_v.lora_A.default.weight",
            "transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight": "single_blocks.blockid.a_to_v.lora_B.default.weight",
            "transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
            "transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
            "transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight": "single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight",
            "transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight": "single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight",
            "transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
            "transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
            "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight": "blocks.blockid.attn.b_to_k.lora_A.default.weight",
            "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight": "blocks.blockid.attn.b_to_k.lora_B.default.weight",
            "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight": "blocks.blockid.attn.b_to_q.lora_A.default.weight",
            "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight": "blocks.blockid.attn.b_to_q.lora_B.default.weight",
            "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight": "blocks.blockid.attn.b_to_v.lora_A.default.weight",
            "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight": "blocks.blockid.attn.b_to_v.lora_B.default.weight",
            "transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
            "transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
            "transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight": "blocks.blockid.attn.a_to_k.lora_A.default.weight",
            "transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight": "blocks.blockid.attn.a_to_k.lora_B.default.weight",
            "transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
            "transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
            "transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight": "blocks.blockid.attn.a_to_q.lora_A.default.weight",
            "transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight": "blocks.blockid.attn.a_to_q.lora_B.default.weight",
            "transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight": "blocks.blockid.attn.a_to_v.lora_A.default.weight",
            "transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight": "blocks.blockid.attn.a_to_v.lora_B.default.weight",
            "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
            "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
            "transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
            "transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
            "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
            "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
            "transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
            "transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
            "transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
            "transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
            "transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
            "transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
        }

        self.civitai_rename_dict = {
            "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
            "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
            "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
            "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
            "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
            "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
            "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
            "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
        }

    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
        super().load(model, state_dict_lora, alpha)

    def convert_state_dict(self, state_dict):

        def guess_block_id(name, model_resource):
            if model_resource == 'civitai':
                names = name.split("_")
                for i in names:
                    if i.isdigit():
                        return i, name.replace(f"_{i}_", "_blockid_")
            if model_resource == 'diffusers':
                names = name.split(".")
                for i in names:
                    if i.isdigit():
                        return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.")
            return None, None

        def guess_resource(state_dict):
            for k in state_dict:
                if "lora_unet_" in k:
                    return 'civitai'
                elif k.startswith("transformer."):
                    return 'diffusers'
            return None

        model_resource = guess_resource(state_dict)
        if model_resource is None:
            return state_dict

        rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict

        def guess_alpha(state_dict):
            # Fold the LoRA alpha into both A and B: scaling each factor by
            # sqrt(alpha / rank) is equivalent to scaling B @ A by alpha / rank.
            for name, param in state_dict.items():
                if ".alpha" in name:
                    for suffix in [".lora_down.weight", ".lora_A.weight"]:
                        name_ = name.replace(".alpha", suffix)
                        if name_ in state_dict:
                            lora_alpha = param.item() / state_dict[name_].shape[0]
                            lora_alpha = math.sqrt(lora_alpha)
                            return lora_alpha
            return 1

        alpha = guess_alpha(state_dict)

        state_dict_ = {}
        for name, param in state_dict.items():
            block_id, source_name = guess_block_id(name, model_resource)
            if alpha != 1:
                param *= alpha
            if source_name in rename_dict:
                target_name = rename_dict[source_name]
                target_name = target_name.replace(".blockid.", f".{block_id}.")
                state_dict_[target_name] = param
            else:
                state_dict_[name] = param

        if model_resource == 'diffusers':
            # Diffusers stores q, k, v and the MLP input separately; fuse them
            # into the single to_qkv_mlp projection of the single-stream blocks.
            for name in list(state_dict_.keys()):
                if "single_blocks." in name and ".a_to_q." in name:
                    mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
                    if mlp is None:
                        dim = 4
                        if 'lora_A' in name:
                            dim = 1
                        mlp = torch.zeros(dim * state_dict_[name].shape[0],
                                          *state_dict_[name].shape[1:],
                                          dtype=state_dict_[name].dtype)
                    else:
                        state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
                    if 'lora_A' in name:
                        param = torch.concat([
                            state_dict_.pop(name),
                            state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
                            state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
                            mlp,
                        ], dim=0)
                    elif 'lora_B' in name:
                        # lora_B factors go on a block diagonal so each fused
                        # output slice only sees its own rank dimensions.
                        d, r = state_dict_[name].shape
                        param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
                        param[:d, :r] = state_dict_.pop(name)
                        param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
                        param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
                        param[3*d:, 3*r:] = mlp
                    else:
                        param = torch.concat([
                            state_dict_.pop(name),
                            state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
                            state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
                            mlp,
                        ], dim=0)
                    name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
                    state_dict_[name_] = param
            for name in list(state_dict_.keys()):
                for component in ["a", "b"]:
                    if f".{component}_to_q." in name:
                        name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
                        if 'lora_A' in name:
                            param = torch.concat([
                                state_dict_[name],
                                state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
                                state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
                            ], dim=0)
                        elif 'lora_B' in name:
                            origin = state_dict_[name]
                            d, r = origin.shape
                            param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
                            param[:d, :r] = state_dict_[name]
                            param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
                            param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
                        else:
                            param = torch.concat([
                                state_dict_[name],
                                state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
                                state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
                            ], dim=0)
                        state_dict_[name_] = param
                        state_dict_.pop(name)
                        state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
                        state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
        return state_dict_


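# Worked example (illustrative, not part of the original file): a Kohya/civitai
# key is converted in three steps -- the block index is replaced by the
# "blockid" placeholder, the placeholder name is looked up in
# civitai_rename_dict, and the index is substituted back into the target name:
#   "lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight"
#     -> "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight"
#     -> "blocks.blockid.attn.a_to_qkv.lora_A.default.weight"
#     -> "blocks.0.attn.a_to_qkv.lora_A.default.weight"
def _demo_convert_civitai_key():
    loader = FluxLoRALoader()
    sd = {"lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight": torch.randn(4, 3072)}
    converted = loader.convert_state_dict(sd)
    assert "blocks.0.attn.a_to_qkv.lora_A.default.weight" in converted

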
class LoraMerger(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
        self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
        self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
        self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
        self.bias = torch.nn.Parameter(torch.randn((dim,)))
        self.activation = torch.nn.Sigmoid()
        self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
        self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)

    def forward(self, base_output, lora_outputs):
        norm_base_output = self.norm_base(base_output)
        norm_lora_outputs = self.norm_lora(lora_outputs)
        gate = self.activation(
            norm_base_output * self.weight_base
            + norm_lora_outputs * self.weight_lora
            + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
        )
        output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
        return output


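# Gate sketch (illustrative, not part of the original file): LoraMerger blends
# a stack of n LoRA branch outputs into the base output with an elementwise
# sigmoid gate,
#   gate_i = sigmoid(w_base * LN(base) + w_lora * LN(lora_i)
#                    + w_cross * LN(base) * LN(lora_i) + bias)
#   output = base + sum_i(w_out * gate_i * lora_i)
# where the sum runs over the stacked LoRA dimension (dim=0).
def _demo_lora_merger():
    merger = LoraMerger(dim=16)
    base = torch.randn(1, 4, 16)   # broadcasts against the LoRA stack
    loras = torch.randn(3, 4, 16)  # 3 LoRA branch outputs
    out = merger(base, loras)
    assert out.shape == (1, 4, 16)

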
class FluxLoraPatcher(torch.nn.Module):
    def __init__(self, lora_patterns=None):
        super().__init__()
        if lora_patterns is None:
            lora_patterns = self.default_lora_patterns()
        model_dict = {}
        for lora_pattern in lora_patterns:
            name, dim = lora_pattern["name"], lora_pattern["dim"]
            model_dict[name.replace(".", "___")] = LoraMerger(dim)
        self.model_dict = torch.nn.ModuleDict(model_dict)

    def default_lora_patterns(self):
        lora_patterns = []
        lora_dict = {
            "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
            "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
        }
        for i in range(19):
            for suffix in lora_dict:
                lora_patterns.append({
                    "name": f"blocks.{i}.{suffix}",
                    "dim": lora_dict[suffix]
                })
        lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
        for i in range(38):
            for suffix in lora_dict:
                lora_patterns.append({
                    "name": f"single_blocks.{i}.{suffix}",
                    "dim": lora_dict[suffix]
                })
        return lora_patterns

    def forward(self, base_output, lora_outputs, name):
        return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)
112  diffsynth/models/flux_text_encoder_clip.py  Normal file
@@ -0,0 +1,112 @@
import torch


class Attention(torch.nn.Module):

    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        super().__init__()
        dim_inner = head_dim * num_heads
        kv_dim = kv_dim if kv_dim is not None else q_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        batch_size = encoder_hidden_states.shape[0]

        q = self.to_q(hidden_states)
        k = self.to_k(encoder_hidden_states)
        v = self.to_v(encoder_hidden_states)

        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)

        hidden_states = self.to_out(hidden_states)

        return hidden_states


class CLIPEncoderLayer(torch.nn.Module):
    def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
        super().__init__()
        self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
        self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
        self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
        self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
        self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)

        self.use_quick_gelu = use_quick_gelu

    def quickGELU(self, x):
        return x * torch.sigmoid(1.702 * x)

    def forward(self, hidden_states, attn_mask=None):
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.fc1(hidden_states)
        if self.use_quick_gelu:
            hidden_states = self.quickGELU(hidden_states)
        else:
            hidden_states = torch.nn.functional.gelu(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


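# QuickGELU, used by the original CLIP checkpoints, approximates GELU with
# x * sigmoid(1.702 * x); over typical activation ranges the two differ by
# only a few hundredths. A quick numerical check (illustrative, not part of
# the original file):
def _demo_quick_gelu():
    x = torch.linspace(-4, 4, steps=101)
    quick = x * torch.sigmoid(1.702 * x)
    exact = torch.nn.functional.gelu(x)
    assert (quick - exact).abs().max() < 0.05

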
class FluxTextEncoderClip(torch.nn.Module):
    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
        super().__init__()

        # token_embedding
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # position_embeds (This is a fixed tensor)
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # encoders
        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])

        # attn_mask
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # final_layer_norm
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)

    def attention_mask(self, length):
        mask = torch.empty(length, length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, input_ids, clip_skip=2, extra_mask=None):
        embeds = self.token_embedding(input_ids)
        embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        if extra_mask is not None:
            attn_mask[:, extra_mask[0] == 0] = float("-inf")
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
            if encoder_id + clip_skip == len(self.encoders):
                hidden_states = embeds
        embeds = self.final_layer_norm(embeds)
        pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
        return pooled_embeds, hidden_states


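# The attention mask built above is the standard causal mask: 0 on and below
# the diagonal, -inf strictly above it, so token i can only attend to j <= i.
# Minimal check (illustrative, mirrors attention_mask()):
def _demo_causal_mask():
    mask = torch.empty(3, 3).fill_(float("-inf")).triu_(1)
    assert mask[0, 1].item() == float("-inf") and mask[1, 0].item() == 0.0

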
43  diffsynth/models/flux_text_encoder_t5.py  Normal file
@@ -0,0 +1,43 @@
import torch
from transformers import T5EncoderModel, T5Config


class FluxTextEncoderT5(T5EncoderModel):
    def __init__(self):
        config = T5Config(**{
            "architectures": [
                "T5EncoderModel"
            ],
            "classifier_dropout": 0.0,
            "d_ff": 10240,
            "d_kv": 64,
            "d_model": 4096,
            "decoder_start_token_id": 0,
            "dense_act_fn": "gelu_new",
            "dropout_rate": 0.1,
            "dtype": "bfloat16",
            "eos_token_id": 1,
            "feed_forward_proj": "gated-gelu",
            "initializer_factor": 1.0,
            "is_encoder_decoder": True,
            "is_gated_act": True,
            "layer_norm_epsilon": 1e-06,
            "model_type": "t5",
            "num_decoder_layers": 24,
            "num_heads": 64,
            "num_layers": 24,
            "output_past": True,
            "pad_token_id": 0,
            "relative_attention_max_distance": 128,
            "relative_attention_num_buckets": 32,
            "tie_word_embeddings": False,
            "transformers_version": "4.57.1",
            "use_cache": True,
            "vocab_size": 32128
        })
        super().__init__(config)

    def forward(self, input_ids):
        outputs = super().forward(input_ids=input_ids)
        prompt_emb = outputs.last_hidden_state
        return prompt_emb
451  diffsynth/models/flux_vae.py  Normal file
@@ -0,0 +1,451 @@
import torch
from einops import rearrange, repeat


class TileWorker:
    def __init__(self):
        pass

    def mask(self, height, width, border_width):
        # Create a mask with shape (height, width).
        # The centre area is filled with 1, and the border line is filled with values in range (0, 1].
        x = torch.arange(height).repeat(width, 1).T
        y = torch.arange(width).repeat(height, 1)
        mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
        mask = (mask / border_width).clip(0, 1)
        return mask

    def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
        # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
        batch_size, channel, _, _ = model_input.shape
        model_input = model_input.to(device=tile_device, dtype=tile_dtype)
        unfold_operator = torch.nn.Unfold(
            kernel_size=(tile_size, tile_size),
            stride=(tile_stride, tile_stride)
        )
        model_input = unfold_operator(model_input)
        model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))

        return model_input

    def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
        # Call y=forward_fn(x) for each tile
        tile_num = model_input.shape[-1]
        model_output_stack = []

        for tile_id in range(0, tile_num, tile_batch_size):

            # process input
            tile_id_ = min(tile_id + tile_batch_size, tile_num)
            x = model_input[:, :, :, :, tile_id: tile_id_]
            x = x.to(device=inference_device, dtype=inference_dtype)
            x = rearrange(x, "b c h w n -> (n b) c h w")

            # process output
            y = forward_fn(x)
            y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
            y = y.to(device=tile_device, dtype=tile_dtype)
            model_output_stack.append(y)

        model_output = torch.concat(model_output_stack, dim=-1)
        return model_output

    def io_scale(self, model_output, tile_size):
        # Determine the size modification happened in forward_fn
        # We only consider the same scale on height and width.
        io_scale = model_output.shape[2] / tile_size
        return io_scale

    def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
        # The reversed function of tile
        mask = self.mask(tile_size, tile_size, border_width)
        mask = mask.to(device=tile_device, dtype=tile_dtype)
        mask = rearrange(mask, "h w -> 1 1 h w 1")
        model_output = model_output * mask

        fold_operator = torch.nn.Fold(
            output_size=(height, width),
            kernel_size=(tile_size, tile_size),
            stride=(tile_stride, tile_stride)
        )
        mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
        model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
        model_output = fold_operator(model_output) / fold_operator(mask)

        return model_output

    def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
        # Prepare
        inference_device, inference_dtype = model_input.device, model_input.dtype
        height, width = model_input.shape[2], model_input.shape[3]
        border_width = int(tile_stride*0.5) if border_width is None else border_width

        # tile
        model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)

        # inference
        model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)

        # resize
        io_scale = self.io_scale(model_output, tile_size)
        height, width = int(height*io_scale), int(width*io_scale)
        tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
        border_width = int(border_width*io_scale)

        # untile
        model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)

        # Done!
        model_output = model_output.to(device=inference_device, dtype=inference_dtype)
        return model_output


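# Sanity sketch (illustrative, not part of the original file): with an
# identity forward_fn and tiles that exactly cover the input (i.e.
# (H - tile_size) % tile_stride == 0), tiling followed by mask-weighted
# untiling reconstructs the input, since every output pixel is a mask-weighted
# average of identical tile values.
def _demo_tile_worker_identity():
    worker = TileWorker()
    x = torch.randn(1, 4, 16, 16)
    y = worker.tiled_forward(lambda t: t, x, tile_size=8, tile_stride=4)
    assert torch.allclose(x, y, atol=1e-5)

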
class ConvAttention(torch.nn.Module):

    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        super().__init__()
        dim_inner = head_dim * num_heads
        kv_dim = kv_dim if kv_dim is not None else q_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q)
        self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
        self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
        self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out)

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        batch_size = encoder_hidden_states.shape[0]

        conv_input = rearrange(hidden_states, "B L C -> B C L 1")
        q = self.to_q(conv_input)
        q = rearrange(q[:, :, :, 0], "B C L -> B L C")
        conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1")
        k = self.to_k(conv_input)
        v = self.to_v(conv_input)
        k = rearrange(k[:, :, :, 0], "B C L -> B L C")
        v = rearrange(v[:, :, :, 0], "B C L -> B L C")

        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)

        conv_input = rearrange(hidden_states, "B L C -> B C L 1")
        hidden_states = self.to_out(conv_input)
        hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C")

        return hidden_states


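# A 1x1 Conv2d applied to a (B, C, L, 1) view is the same linear map over the
# channel dimension as a torch.nn.Linear; the conv variant only changes the
# parameter layout. Equivalence check (illustrative, not part of the original
# file):
def _demo_conv_attention_is_linear():
    conv = torch.nn.Conv2d(8, 16, kernel_size=(1, 1), bias=True)
    linear = torch.nn.Linear(8, 16)
    linear.weight.data = conv.weight.data[:, :, 0, 0]
    linear.bias.data = conv.bias.data
    x = torch.randn(2, 5, 8)  # (B, L, C)
    y = rearrange(conv(rearrange(x, "B L C -> B C L 1"))[:, :, :, 0], "B C L -> B L C")
    assert torch.allclose(y, linear(x), atol=1e-5)

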
class Attention(torch.nn.Module):

    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        super().__init__()
        dim_inner = head_dim * num_heads
        kv_dim = kv_dim if kv_dim is not None else q_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        batch_size = encoder_hidden_states.shape[0]

        q = self.to_q(hidden_states)
        k = self.to_k(encoder_hidden_states)
        v = self.to_v(encoder_hidden_states)

        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)

        hidden_states = self.to_out(hidden_states)

        return hidden_states


class VAEAttentionBlock(torch.nn.Module):

    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, use_conv_attention=True):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)

        if use_conv_attention:
            self.transformer_blocks = torch.nn.ModuleList([
                ConvAttention(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    bias_q=True,
                    bias_kv=True,
                    bias_out=True
                )
                for d in range(num_layers)
            ])
        else:
            self.transformer_blocks = torch.nn.ModuleList([
                Attention(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    bias_q=True,
                    bias_kv=True,
                    bias_out=True
                )
                for d in range(num_layers)
            ])

    def forward(self, hidden_states, time_emb, text_emb, res_stack):
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states)

        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
        hidden_states = hidden_states + residual

        return hidden_states, time_emb, text_emb, res_stack


class ResnetBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5):
        super().__init__()
        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels is not None:
            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.nonlinearity = torch.nn.SiLU()
        self.conv_shortcut = None
        if in_channels != out_channels:
            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
        x = hidden_states
        x = self.norm1(x)
        x = self.nonlinearity(x)
        x = self.conv1(x)
        if time_emb is not None:
            emb = self.nonlinearity(time_emb)
            emb = self.time_emb_proj(emb)[:, :, None, None]
            x = x + emb
        x = self.norm2(x)
        x = self.nonlinearity(x)
        x = self.conv2(x)
        if self.conv_shortcut is not None:
            hidden_states = self.conv_shortcut(hidden_states)
        hidden_states = hidden_states + x
        return hidden_states, time_emb, text_emb, res_stack


class UpSampler(torch.nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
        hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        hidden_states = self.conv(hidden_states)
        return hidden_states, time_emb, text_emb, res_stack

class DownSampler(torch.nn.Module):
    def __init__(self, channels, padding=1, extra_padding=False):
        super().__init__()
        self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding)
        self.extra_padding = extra_padding

    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
        if self.extra_padding:
            hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
        hidden_states = self.conv(hidden_states)
        return hidden_states, time_emb, text_emb, res_stack

class FluxVAEDecoder(torch.nn.Module):
    def __init__(self, use_conv_attention=True):
        super().__init__()
        self.scaling_factor = 0.3611
        self.shift_factor = 0.1159
        self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1)  # Different from SD 1.x

        self.blocks = torch.nn.ModuleList([
            # UNetMidBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
            ResnetBlock(512, 512, eps=1e-6),
            # UpDecoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            UpSampler(512),
            # UpDecoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            UpSampler(512),
            # UpDecoderBlock2D
            ResnetBlock(512, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            UpSampler(256),
            # UpDecoderBlock2D
            ResnetBlock(256, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
        ])

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x),
            sample,
            tile_size,
            tile_stride,
            tile_device=sample.device,
            tile_dtype=sample.dtype
        )
        return hidden_states

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        # For the VAE decoder, we do not need to apply the tiler on each layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process
        hidden_states = sample / self.scaling_factor + self.shift_factor
        hidden_states = self.conv_in(hidden_states)
        time_emb = None
        text_emb = None
        res_stack = None

        # 2. blocks
        for i, block in enumerate(self.blocks):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states

class FluxVAEEncoder(torch.nn.Module):
    def __init__(self, use_conv_attention=True):
        super().__init__()
        self.scaling_factor = 0.3611
        self.shift_factor = 0.1159
        self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)

        self.blocks = torch.nn.ModuleList([
            # DownEncoderBlock2D
            ResnetBlock(128, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
            DownSampler(128, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(128, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            DownSampler(256, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(256, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            DownSampler(512, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            # UNetMidBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
            ResnetBlock(512, 512, eps=1e-6),
        ])

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x),
            sample,
            tile_size,
            tile_stride,
            tile_device=sample.device,
            tile_dtype=sample.dtype
        )
        return hidden_states

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        # For the VAE encoder, we do not need to apply the tiler on each layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process
        hidden_states = self.conv_in(sample)
        time_emb = None
        text_emb = None
        res_stack = None

        # 2. blocks
        for i, block in enumerate(self.blocks):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        # Keep only the mean of the latent distribution (drop the logvar half of the channels)
        hidden_states = hidden_states[:, :16]
        hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor

        return hidden_states

    def encode_video(self, sample, batch_size=8):
        B = sample.shape[0]
        hidden_states = []

        for i in range(0, sample.shape[2], batch_size):

            j = min(i + batch_size, sample.shape[2])
            sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")

            hidden_states_batch = self(sample_batch)
            hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)

            hidden_states.append(hidden_states_batch)

        hidden_states = torch.concat(hidden_states, dim=2)
        return hidden_states

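A minimal roundtrip sketch for the encoder/decoder pair above, assuming FLUX-style 3-channel images with values roughly in [-1, 1]; the shapes follow from the three stride-2 downsamplers (8x spatial reduction) and the 16-channel latent space:

import torch

encoder, decoder = FluxVAEEncoder(), FluxVAEDecoder()
image = torch.randn(1, 3, 512, 512)          # (B, C, H, W)
latents = encoder(image)                     # (1, 16, 64, 64): 8x downsampled, 16 latent channels
recon = decoder(latents)                     # back to (1, 3, 512, 512)
recon_tiled = decoder(latents, tiled=True)   # tiled decoding via TileWorker for large inputs
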
diffsynth/models/flux_value_control.py (new file, 56 lines)
@@ -0,0 +1,56 @@
import torch
from .general_modules import TemporalTimesteps


class MultiValueEncoder(torch.nn.Module):
    def __init__(self, encoders=()):
        super().__init__()
        # Accept a single encoder as well as a list/tuple of encoders.
        if not isinstance(encoders, (list, tuple)):
            encoders = [encoders]
        self.encoders = torch.nn.ModuleList(encoders)

    def __call__(self, values, dtype):
        emb = []
        for encoder, value in zip(self.encoders, values):
            if value is not None:
                value = value.unsqueeze(0)
                emb.append(encoder(value, dtype))
        emb = torch.concat(emb, dim=0)
        return emb


class SingleValueEncoder(torch.nn.Module):
    def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
        super().__init__()
        self.prefer_len = prefer_len
        self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
        self.prefer_value_embedder = torch.nn.Sequential(
            torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
        )
        self.positional_embedding = torch.nn.Parameter(
            torch.randn(self.prefer_len, dim_out)
        )

    def forward(self, value, dtype):
        value = value * 1000
        emb = self.prefer_proj(value).to(dtype)
        emb = self.prefer_value_embedder(emb).squeeze(0)
        base_embeddings = emb.expand(self.prefer_len, -1)
        positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device)
        learned_embeddings = base_embeddings + positional_embedding
        return learned_embeddings

    @staticmethod
    def state_dict_converter():
        return SingleValueEncoderStateDictConverter()


class SingleValueEncoderStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict

    def from_civitai(self, state_dict):
        return state_dict
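A usage sketch for SingleValueEncoder; the `value * 1000` scaling suggests scalar control values on the order of [0, 1], but that range is an assumption, not documented here:

import torch

encoder = SingleValueEncoder(dim_in=256, dim_out=4096, prefer_len=32)
value = torch.tensor([0.75])   # hypothetical control strength
tokens = encoder(value, torch.float32)
print(tokens.shape)            # torch.Size([32, 4096]): one shared embedding plus per-position offsets
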
diffsynth/models/general_modules.py (new file, 146 lines)
@@ -0,0 +1,146 @@
import torch, math


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
    computation_device = None,
    align_dtype_to_timestep = False,
):
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    if align_dtype_to_timestep:
        emb = emb.to(timesteps.dtype)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb

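A small worked example of the sinusoidal embedding above with the default downscale_freq_shift=1 (values rounded; purely illustrative):

import torch

t = torch.tensor([1000.0])
emb = get_timestep_embedding(t, embedding_dim=4)
# half_dim = 2, frequencies = exp(-ln(10000) * [0, 1] / (2 - 1)) = [1.0, 1e-4]
# angles = 1000 * [1.0, 1e-4] = [1000.0, 0.1]
# emb = [sin(1000), sin(0.1), cos(1000), cos(0.1)] ≈ [0.83, 0.10, 0.56, 1.00]
print(emb)
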
class TemporalTimesteps(torch.nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.computation_device = computation_device
        self.scale = scale
        self.align_dtype_to_timestep = align_dtype_to_timestep

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            computation_device=self.computation_device,
            scale=self.scale,
            align_dtype_to_timestep=self.align_dtype_to_timestep,
        )
        return t_emb


class DiffusersCompatibleTimestepProj(torch.nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.linear_1 = torch.nn.Linear(dim_in, dim_out)
        self.act = torch.nn.SiLU()
        self.linear_2 = torch.nn.Linear(dim_out, dim_out)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.act(x)
        x = self.linear_2(x)
        return x


class TimestepEmbeddings(torch.nn.Module):
    def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):
        super().__init__()
        self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
        if diffusers_compatible_format:
            self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out)
        else:
            self.timestep_embedder = torch.nn.Sequential(
                torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
            )
        self.use_additional_t_cond = use_additional_t_cond
        if use_additional_t_cond:
            self.addition_t_embedding = torch.nn.Embedding(2, dim_out)

    def forward(self, timestep, dtype, addition_t_cond=None):
        time_emb = self.time_proj(timestep).to(dtype)
        time_emb = self.timestep_embedder(time_emb)
        if addition_t_cond is not None:
            addition_t_emb = self.addition_t_embedding(addition_t_cond)
            addition_t_emb = addition_t_emb.to(dtype=dtype)
            time_emb = time_emb + addition_t_emb
        return time_emb

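A sketch of TimestepEmbeddings with the optional binary condition path enabled; use_additional_t_cond adds a learned 2-entry embedding on top of the sinusoidal projection:

import torch

embedder = TimestepEmbeddings(dim_in=256, dim_out=1024, use_additional_t_cond=True)
timestep = torch.tensor([981.0, 500.0])   # one diffusion timestep per sample
cond = torch.tensor([0, 1])               # binary flag, embedded and added to the timestep embedding
emb = embedder(timestep, dtype=torch.float32, addition_t_cond=cond)
print(emb.shape)                          # torch.Size([2, 1024])
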
class RMSNorm(torch.nn.Module):
    def __init__(self, dim, eps, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = torch.nn.Parameter(torch.ones((dim,)))
        else:
            self.weight = None

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        hidden_states = hidden_states.to(input_dtype)
        if self.weight is not None:
            hidden_states = hidden_states * self.weight
        return hidden_states

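RMSNorm divides each feature vector by its root mean square (no mean subtraction, unlike LayerNorm); a quick numeric check:

import torch

norm = RMSNorm(dim=4, eps=1e-6)
x = torch.tensor([[3.0, -3.0, 3.0, -3.0]])
print(norm(x))   # ≈ [[1., -1., 1., -1.]]: the row's RMS is 3, so every entry is divided by 3
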
class AdaLayerNorm(torch.nn.Module):
    def __init__(self, dim, single=False, dual=False):
        super().__init__()
        self.single = single
        self.dual = dual
        # Number of modulation chunks: 6 by default, 2 in single mode, 9 in dual mode.
        self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(torch.nn.functional.silu(emb))
        if self.single:
            scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
            x = self.norm(x) * (1 + scale) + shift
            return x
        elif self.dual:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
            norm_x = self.norm(x)
            x = norm_x * (1 + scale_msa) + shift_msa
            norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
            x = self.norm(x) * (1 + scale_msa) + shift_msa
            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
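A sketch of the single=True path of AdaLayerNorm, where a per-sample conditioning vector produces one scale/shift pair applied after the parameter-free LayerNorm:

import torch

ada = AdaLayerNorm(dim=64, single=True)
x = torch.randn(2, 100, 64)   # (batch, tokens, channels)
emb = torch.randn(2, 64)      # conditioning vector, e.g. a timestep embedding
y = ada(x, emb)               # norm(x) * (1 + scale) + shift
print(y.shape)                # torch.Size([2, 100, 64])
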
@@ -1,451 +0,0 @@
from .attention import Attention
from einops import repeat, rearrange
import math
import torch


class HunyuanDiTRotaryEmbedding(torch.nn.Module):

    def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
        super().__init__()
        self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
        self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
        self.rotary_emb_on_k = rotary_emb_on_k
        self.k_cache, self.v_cache = [], []

    def reshape_for_broadcast(self, freqs_cis, x):
        ndim = x.ndim
        shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)

    def rotate_half(self, x):
        x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
        return torch.stack([-x_imag, x_real], dim=-1).flatten(3)

    def apply_rotary_emb(self, xq, xk, freqs_cis):
        xk_out = None
        cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
        if xk is not None:
            xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
        return xq_out, xk_out

    def forward(self, q, k, v, freqs_cis_img, to_cache=False):
        # norm
        q = self.q_norm(q)
        k = self.k_norm(k)

        # RoPE
        if self.rotary_emb_on_k:
            q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
        else:
            q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)

        if to_cache:
            self.k_cache.append(k)
            self.v_cache.append(v)
        elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
            k = torch.concat([k] + self.k_cache, dim=2)
            v = torch.concat([v] + self.v_cache, dim=2)
            self.k_cache, self.v_cache = [], []
        return q, k, v


class FP32_Layernorm(torch.nn.LayerNorm):
    def forward(self, inputs):
        origin_dtype = inputs.dtype
        return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)


class FP32_SiLU(torch.nn.SiLU):
    def forward(self, inputs):
        origin_dtype = inputs.dtype
        return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)


class HunyuanDiTFinalLayer(torch.nn.Module):
    def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
        super().__init__()
        self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = torch.nn.Sequential(
            FP32_SiLU(),
            torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
        )

    def modulate(self, x, shift, scale):
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def forward(self, hidden_states, condition_emb):
        shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
        hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
        hidden_states = self.linear(hidden_states)
        return hidden_states


class HunyuanDiTBlock(torch.nn.Module):

    def __init__(
        self,
        hidden_dim=1408,
        condition_dim=1408,
        num_heads=16,
        mlp_ratio=4.3637,
        text_dim=1024,
        skip_connection=False
    ):
        super().__init__()
        self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
        self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
        self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
        self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
        self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
        self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
        self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
        self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
        )
        if skip_connection:
            self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
            self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
        else:
            self.skip_norm, self.skip_linear = None, None

    def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
        # Long Skip Connection
        if self.skip_norm is not None and self.skip_linear is not None:
            hidden_states = torch.cat([hidden_states, residual], dim=-1)
            hidden_states = self.skip_norm(hidden_states)
            hidden_states = self.skip_linear(hidden_states)

        # Self-Attention
        shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
        attn_input = self.norm1(hidden_states) + shift_msa
        hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))

        # Cross-Attention
        attn_input = self.norm3(hidden_states)
        hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))

        # FFN Layer
        mlp_input = self.norm2(hidden_states)
        hidden_states = hidden_states + self.mlp(mlp_input)
        return hidden_states


class AttentionPool(torch.nn.Module):
    def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
        super().__init__()
        self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.permute(1, 0, 2)  # NLC -> LNC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
        x, _ = torch.nn.functional.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x.squeeze(0)


class PatchEmbed(torch.nn.Module):
    def __init__(
        self,
        patch_size=(2, 2),
        in_chans=4,
        embed_dim=1408,
        bias=True,
    ):
        super().__init__()
        self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return x


def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)  # size: [dim/2], an exponentially decaying curve
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
    else:
        embedding = repeat(t, "b -> b d", d=dim)
    return embedding


class TimestepEmbedder(torch.nn.Module):
    def __init__(self, hidden_size=1408, frequency_embedding_size=256):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    def forward(self, t):
        t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class HunyuanDiT(torch.nn.Module):
    def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
        super().__init__()

        # Embedders
        self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
        self.t5_embedder = torch.nn.Sequential(
            torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
            FP32_SiLU(),
            torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
        )
        self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
        self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
        self.patch_embedder = PatchEmbed(in_chans=in_channels)
        self.timestep_embedder = TimestepEmbedder()
        self.extra_embedder = torch.nn.Sequential(
            torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
            FP32_SiLU(),
            torch.nn.Linear(hidden_dim * 4, hidden_dim),
        )

        # Transformer blocks
        self.num_layers_down = num_layers_down
        self.num_layers_up = num_layers_up
        self.blocks = torch.nn.ModuleList(
            [HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
            [HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
        )

        # Output layers
        self.final_layer = HunyuanDiTFinalLayer()
        self.out_channels = out_channels

    def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
        text_emb_mask = text_emb_mask.bool()
        text_emb_mask_t5 = text_emb_mask_t5.bool()
        text_emb_t5 = self.t5_embedder(text_emb_t5)
        text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
        text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
        text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
        return text_emb

    def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
        # Text embedding
        pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)

        # Timestep embedding
        timestep_emb = self.timestep_embedder(timestep)

        # Size embedding
        size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
        size_emb = size_emb.view(-1, 6 * 256)

        # Style embedding
        style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)

        # Concatenate all extra vectors
        extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
        condition_emb = timestep_emb + self.extra_embedder(extra_emb)

        return condition_emb

    def unpatchify(self, x, h, w):
        return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)

    def build_mask(self, data, is_bound):
        _, _, H, W = data.shape
        h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
        w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
        border_width = (H + W) // 4
        pad = torch.ones_like(h) * border_width
        mask = torch.stack([
            pad if is_bound[0] else h + 1,
            pad if is_bound[1] else H - h,
            pad if is_bound[2] else w + 1,
            pad if is_bound[3] else W - w
        ]).min(dim=0).values
        mask = mask.clip(1, border_width)
        mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
        mask = rearrange(mask, "H W -> 1 H W")
        return mask

    def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
        B, C, H, W = hidden_states.shape

        weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
        values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)

        # Split tasks
        tasks = []
        for h in range(0, H, tile_stride):
            for w in range(0, W, tile_stride):
                if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
                    continue
                h_, w_ = h + tile_size, w + tile_size
                if h_ > H: h, h_ = H - tile_size, H
                if w_ > W: w, w_ = W - tile_size, W
                tasks.append((h, h_, w, w_))

        # Run
        for hl, hr, wl, wr in tasks:
            hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
            hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
            if residual is not None:
                residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
                residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
            else:
                residual_batch = None

            # Forward
            hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
            hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)

            mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
            values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
            weight[:, :, hl:hr, wl:wr] += mask
        values /= weight
        return values

    def forward(
        self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
        tiled=False, tile_size=64, tile_stride=32,
        to_cache=False,
        use_gradient_checkpointing=False,
    ):
        # Embeddings
        text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
        condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])

        # Input
        height, width = hidden_states.shape[-2], hidden_states.shape[-1]
        hidden_states = self.patch_embedder(hidden_states)

        # Blocks
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward
        if tiled:
            hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
            residuals = []
            for block_id, block in enumerate(self.blocks):
                residual = residuals.pop() if block_id >= self.num_layers_down else None
                hidden_states = self.tiled_block_forward(
                    block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
                    torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
                    tile_size=tile_size, tile_stride=tile_stride
                )
                if block_id < self.num_layers_down - 2:
                    residuals.append(hidden_states)
            hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
        else:
            residuals = []
            for block_id, block in enumerate(self.blocks):
                residual = residuals.pop() if block_id >= self.num_layers_down else None
                if self.training and use_gradient_checkpointing:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states, condition_emb, text_emb, freq_cis_img, residual,
                        use_reentrant=False,
                    )
                else:
                    hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
                if block_id < self.num_layers_down - 2:
                    residuals.append(hidden_states)

        # Output
        hidden_states = self.final_layer(hidden_states, condition_emb)
        hidden_states = self.unpatchify(hidden_states, height//2, width//2)
        hidden_states, _ = hidden_states.chunk(2, dim=1)
        return hidden_states

    @staticmethod
    def state_dict_converter():
        return HunyuanDiTStateDictConverter()


class HunyuanDiTStateDictConverter():
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {}
        for name, param in state_dict.items():
            name_ = name
            name_ = name_.replace(".default_modulation.", ".modulation.")
            name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
            name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
            name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
            name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
            name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
            name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
            name_ = name_.replace(".q_proj.", ".to_q.")
            name_ = name_.replace(".out_proj.", ".to_out.")
            name_ = name_.replace("text_embedding_padding", "text_emb_padding")
            name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
            name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
            name_ = name_.replace("pooler.", "t5_pooler.")
            name_ = name_.replace("x_embedder.", "patch_embedder.")
            name_ = name_.replace("t_embedder.", "timestep_embedder.")
            name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
            name_ = name_.replace("style_embedder.weight", "style_embedder")
            if ".kv_proj." in name_:
                param_k = param[:param.shape[0]//2]
                param_v = param[param.shape[0]//2:]
                state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
                state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
            elif ".Wqkv." in name_:
                param_q = param[:param.shape[0]//3]
                param_k = param[param.shape[0]//3:param.shape[0]//3*2]
                param_v = param[param.shape[0]//3*2:]
                state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
                state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
                state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
            elif "style_embedder" in name_:
                state_dict_[name_] = param.squeeze()
            else:
                state_dict_[name_] = param
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
@@ -1,163 +0,0 @@
from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
import torch


class HunyuanDiTCLIPTextEncoder(BertModel):
    def __init__(self):
        config = BertConfig(
            _name_or_path = "",
            architectures = ["BertModel"],
            attention_probs_dropout_prob = 0.1,
            bos_token_id = 0,
            classifier_dropout = None,
            directionality = "bidi",
            eos_token_id = 2,
            hidden_act = "gelu",
            hidden_dropout_prob = 0.1,
            hidden_size = 1024,
            initializer_range = 0.02,
            intermediate_size = 4096,
            layer_norm_eps = 1e-12,
            max_position_embeddings = 512,
            model_type = "bert",
            num_attention_heads = 16,
            num_hidden_layers = 24,
            output_past = True,
            pad_token_id = 0,
            pooler_fc_size = 768,
            pooler_num_attention_heads = 12,
            pooler_num_fc_layers = 3,
            pooler_size_per_head = 128,
            pooler_type = "first_token_transform",
            position_embedding_type = "absolute",
            torch_dtype = "float32",
            transformers_version = "4.37.2",
            type_vocab_size = 2,
            use_cache = True,
            vocab_size = 47020
        )
        super().__init__(config, add_pooling_layer=False)
        self.eval()

    def forward(self, input_ids, attention_mask, clip_skip=1):
        input_shape = input_ids.size()

        batch_size, seq_length = input_shape
        device = input_ids.device

        past_key_values_length = 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=None,
            token_type_ids=None,
            inputs_embeds=None,
            past_key_values_length=0,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_values=None,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )
        all_hidden_states = encoder_outputs.hidden_states
        prompt_emb = all_hidden_states[-clip_skip]
        if clip_skip > 1:
            mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
            prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
        return prompt_emb

    @staticmethod
    def state_dict_converter():
        return HunyuanDiTCLIPTextEncoderStateDictConverter()


class HunyuanDiTT5TextEncoder(T5EncoderModel):
    def __init__(self):
        config = T5Config(
            _name_or_path = "../HunyuanDiT/t2i/mt5",
            architectures = ["MT5ForConditionalGeneration"],
            classifier_dropout = 0.0,
            d_ff = 5120,
            d_kv = 64,
            d_model = 2048,
            decoder_start_token_id = 0,
            dense_act_fn = "gelu_new",
            dropout_rate = 0.1,
            eos_token_id = 1,
            feed_forward_proj = "gated-gelu",
            initializer_factor = 1.0,
            is_encoder_decoder = True,
            is_gated_act = True,
            layer_norm_epsilon = 1e-06,
            model_type = "t5",
            num_decoder_layers = 24,
            num_heads = 32,
            num_layers = 24,
            output_past = True,
            pad_token_id = 0,
            relative_attention_max_distance = 128,
            relative_attention_num_buckets = 32,
            tie_word_embeddings = False,
            tokenizer_class = "T5Tokenizer",
            transformers_version = "4.37.2",
            use_cache = True,
            vocab_size = 250112
        )
        super().__init__(config)
        self.eval()

    def forward(self, input_ids, attention_mask, clip_skip=1):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        prompt_emb = outputs.hidden_states[-clip_skip]
        if clip_skip > 1:
            mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
            prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
        return prompt_emb

    @staticmethod
    def state_dict_converter():
        return HunyuanDiTT5TextEncoderStateDictConverter()


class HunyuanDiTCLIPTextEncoderStateDictConverter():
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)


class HunyuanDiTT5TextEncoderStateDictConverter():
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
        state_dict_["shared.weight"] = state_dict["shared.weight"]
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
File diff suppressed because it is too large

diffsynth/models/longcat_video_dit.py (new file, 902 lines)
@@ -0,0 +1,902 @@
from typing import List, Optional, Tuple

import math
import torch
import torch.nn as nn
import torch.amp as amp

import numpy as np
import torch.nn.functional as F
from einops import rearrange, repeat
from .wan_video_dit import flash_attention
from ..core.device.npu_compatible_device import get_device_type
from ..core.gradient import gradient_checkpoint_forward


class RMSNorm_FP32(torch.nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def broadcat(tensors, dim=-1):
    num_tensors = len(tensors)
    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
    assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
    shape_len = list(shape_lens)[0]
    dim = (dim + shape_len) if dim < 0 else dim
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))
    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all(
        [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
    ), "invalid dimensions for broadcastable concatenation"
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim=dim)


def rotate_half(x):
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


class RotaryPositionalEmbedding(nn.Module):

    def __init__(self,
                 head_dim,
                 cp_split_hw=None
                 ):
        """Rotary positional embedding for 3D
        Reference : https://blog.eleuther.ai/rotary-embeddings/
        Paper: https://arxiv.org/pdf/2104.09864.pdf
        Args:
            head_dim: Dimension of each attention head
            cp_split_hw: Optional context-parallel split of the (H, W) grid
        """
        super().__init__()
        self.head_dim = head_dim
        assert self.head_dim % 8 == 0, 'Dim must be a multiple of 8 for 3D RoPE.'
        self.cp_split_hw = cp_split_hw
        # We take the assumption that the longest side of the grid will not be larger than 512, i.e. 512 * 8 = 4096 input pixels
        self.base = 10000
        self.freqs_dict = {}

    def register_grid_size(self, grid_size):
        if grid_size not in self.freqs_dict:
            self.freqs_dict.update({
                grid_size: self.precompute_freqs_cis_3d(grid_size)
            })

    def precompute_freqs_cis_3d(self, grid_size):
        num_frames, height, width = grid_size
        # Split head_dim across the time / height / width axes
        dim_t = self.head_dim - 4 * (self.head_dim // 6)
        dim_h = 2 * (self.head_dim // 6)
        dim_w = 2 * (self.head_dim // 6)
        freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
        freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
        freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
        grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32)
        grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32)
        grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32)
        grid_t = torch.from_numpy(grid_t).float()
        grid_h = torch.from_numpy(grid_h).float()
        grid_w = torch.from_numpy(grid_w).float()
        freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
        freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
        freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
        freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
        freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
        freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
        freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
        # (T H W D)
        freqs = rearrange(freqs, "T H W D -> (T H W) D")
        # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
        #     with torch.no_grad():
        #         freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width)
        #         freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw)
        #         freqs = rearrange(freqs, "T H W D -> (T H W) D")

        return freqs

    def forward(self, q, k, grid_size):
        """3D RoPE.

        Args:
            q: [B, head, seq, head_dim]
            k: [B, head, seq, head_dim]
        Returns:
            query and key with the same shape as input.
        """

        if grid_size not in self.freqs_dict:
            self.register_grid_size(grid_size)

        freqs_cis = self.freqs_dict[grid_size].to(q.device)
        q_, k_ = q.float(), k.float()
        freqs_cis = freqs_cis.float().to(q.device)
        cos, sin = freqs_cis.cos(), freqs_cis.sin()
        cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
        q_ = (q_ * cos) + (rotate_half(q_) * sin)
        k_ = (k_ * cos) + (rotate_half(k_) * sin)

        return q_.type_as(q), k_.type_as(k)

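A sketch of applying the 3D RoPE above, assuming head_dim=64 (a multiple of 8, split 24/20/20 across time/height/width) and a small latent grid whose flattened length matches the sequence axis:

import torch

rope = RotaryPositionalEmbedding(head_dim=64)
T, H, W = 4, 8, 8                       # latent grid; sequence length = T*H*W = 256
q = torch.randn(1, 12, T * H * W, 64)   # [B, heads, seq, head_dim]
k = torch.randn(1, 12, T * H * W, 64)
q_rot, k_rot = rope(q, k, (T, H, W))    # the frequency table is cached per grid size
assert q_rot.shape == q.shape
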
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
enable_flashattn3: bool = False,
|
||||||
|
enable_flashattn2: bool = False,
|
||||||
|
enable_xformers: bool = False,
|
||||||
|
enable_bsa: bool = False,
|
||||||
|
bsa_params: dict = None,
|
||||||
|
cp_split_hw: Optional[List[int]] = None
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.scale = self.head_dim**-0.5
|
||||||
|
self.enable_flashattn3 = enable_flashattn3
|
||||||
|
self.enable_flashattn2 = enable_flashattn2
|
||||||
|
self.enable_xformers = enable_xformers
|
||||||
|
self.enable_bsa = enable_bsa
|
||||||
|
self.bsa_params = bsa_params
|
||||||
|
self.cp_split_hw = cp_split_hw
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
||||||
|
        self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.proj = nn.Linear(dim, dim)

        self.rope_3d = RotaryPositionalEmbedding(
            self.head_dim,
            cp_split_hw=cp_split_hw
        )

    def _process_attn(self, q, k, v, shape):
        q = rearrange(q, "B H S D -> B S (H D)")
        k = rearrange(k, "B H S D -> B S (H D)")
        v = rearrange(v, "B H S D -> B S (H D)")
        x = flash_attention(q, k, v, num_heads=self.num_heads)
        x = rearrange(x, "B S (H D) -> B H S D", H=self.num_heads)
        return x

    def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
        """
        x: [B, N, C]; shape: (T, H, W) latent grid of a single item.
        If return_kv is True, also returns the pre-RoPE (k, v) tensors so they can be reused as a KV cache.
        """
        B, N, C = x.shape
        qkv = self.qkv(x)

        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4))  # [3, B, H, N, D]
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if return_kv:
            k_cache, v_cache = k.clone(), v.clone()

        q, k = self.rope_3d(q, k, shape)

        # cond mode
        if num_cond_latents is not None and num_cond_latents > 0:
            num_cond_latents_thw = num_cond_latents * (N // shape[0])
            # process the condition tokens: they attend only to themselves
            q_cond = q[:, :, :num_cond_latents_thw].contiguous()
            k_cond = k[:, :, :num_cond_latents_thw].contiguous()
            v_cond = v[:, :, :num_cond_latents_thw].contiguous()
            x_cond = self._process_attn(q_cond, k_cond, v_cond, shape)
            # process the noise tokens: they attend to all tokens
            q_noise = q[:, :, num_cond_latents_thw:].contiguous()
            x_noise = self._process_attn(q_noise, k, v, shape)
            # merge x_cond and x_noise
            x = torch.cat([x_cond, x_noise], dim=2).contiguous()
        else:
            x = self._process_attn(q, k, v, shape)

        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)  # [B, H, N, D] --> [B, N, H, D]
        x = x.reshape(x_output_shape)  # [B, N, H, D] --> [B, N, C]
        x = self.proj(x)

        if return_kv:
            return x, (k_cache, v_cache)
        else:
            return x

    def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
        """
        Same as forward, but reuses cached pre-RoPE (k, v) of the condition latents.
        This path assumes num_cond_latents > 0.
        """
        B, N, C = x.shape
        qkv = self.qkv(x)

        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4))  # [3, B, H, N, D]
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        T, H, W = shape
        k_cache, v_cache = kv_cache
        assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B]
        if k_cache.shape[0] == 1:
            k_cache = k_cache.repeat(B, 1, 1, 1)
            v_cache = v_cache.repeat(B, 1, 1, 1)

        if num_cond_latents is not None and num_cond_latents > 0:
            k_full = torch.cat([k_cache, k], dim=2).contiguous()
            v_full = torch.cat([v_cache, v], dim=2).contiguous()
            q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous()
            q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
            q = q_padding[:, :, -N:].contiguous()

        x = self._process_attn(q, k_full, v_full, shape)

        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)  # [B, H, N, D] --> [B, N, H, D]
        x = x.reshape(x_output_shape)  # [B, N, H, D] --> [B, N, C]
        x = self.proj(x)

        return x
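
# Illustrative sketch (not part of the module): two-phase use of the KV cache
# exposed by Attention above. Tensor sizes and the `Attention(...)` construction
# are hypothetical (the constructor is assumed to default its attention flags);
# only the forward/forward_with_kv_cache contract is taken from the code.
def _example_attention_kv_cache():
    B, T, H, W, C = 1, 8, 4, 4, 64
    attn = Attention(dim=C, num_heads=4)   # assumed constructor signature
    cond = torch.randn(B, 2 * H * W, C)    # 2 condition frames, flattened
    # Phase 1: run the condition tokens once and keep their pre-RoPE K/V.
    _, kv_cache = attn(cond, shape=(2, H, W), return_kv=True)
    # Phase 2: process only the noise tokens, attending to the cached K/V.
    noise = torch.randn(B, T * H * W, C)
    out = attn.forward_with_kv_cache(noise, shape=(T, H, W), num_cond_latents=2, kv_cache=kv_cache)
    return out.shape  # torch.Size([B, T*H*W, C])
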
class MultiHeadCrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        enable_flashattn3=False,
        enable_flashattn2=False,
        enable_xformers=False,
    ):
        super(MultiHeadCrossAttention, self).__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q_linear = nn.Linear(dim, dim)
        self.kv_linear = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)

        self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)

        self.enable_flashattn3 = enable_flashattn3
        self.enable_flashattn2 = enable_flashattn2
        self.enable_xformers = enable_xformers

    def _process_cross_attn(self, x, cond, kv_seqlen):
        B, N, C = x.shape
        assert C == self.dim and cond.shape[2] == self.dim

        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)

        q, k = self.q_norm(q), self.k_norm(k)

        q = rearrange(q, "B S H D -> B S (H D)")
        k = rearrange(k, "B S H D -> B S (H D)")
        v = rearrange(v, "B S H D -> B S (H D)")
        x = flash_attention(q, k, v, num_heads=self.num_heads)

        x = x.view(B, -1, C)
        x = self.proj(x)
        return x

    def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
        """
        x: [B, N, C]
        cond: [B, M, C]
        """
        if num_cond_latents is None or num_cond_latents == 0:
            return self._process_cross_attn(x, cond, kv_seqlen)
        else:
            B, N, C = x.shape
            if num_cond_latents > 0:
                assert shape is not None, "shape must be provided when num_cond_latents > 0"
                num_cond_latents_thw = num_cond_latents * (N // shape[0])
                # Condition tokens skip cross attention; their output slots are zero-filled.
                x_noise = x[:, num_cond_latents_thw:]  # [B, N_noise, C]
                output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen)  # [B, N_noise, C]
                output = torch.cat([
                    torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device),
                    output_noise
                ], dim=1).contiguous()
            else:
                raise NotImplementedError

            return output
class LayerNorm_FP32(nn.LayerNorm):
    def __init__(self, dim, eps, elementwise_affine):
        super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        out = F.layer_norm(
            inputs.float(),
            self.normalized_shape,
            None if self.weight is None else self.weight.float(),
            None if self.bias is None else self.bias.float(),
            self.eps
        ).to(origin_dtype)
        return out
def modulate_fp32(norm_func, x, shift, scale):
    # Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D).
    # Ensure the modulation params are fp32. (The original assert used a comma,
    # which made the second comparison the assert message and never checked scale.)
    assert shift.dtype == torch.float32 and scale.dtype == torch.float32
    dtype = x.dtype
    x = norm_func(x.to(torch.float32))
    x = x * (scale + 1) + shift
    x = x.to(dtype)
    return x
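
# Illustrative sketch (not part of the module): how modulate_fp32 broadcasts a
# per-frame shift/scale over that frame's tokens. The sizes below are hypothetical.
def _example_modulate_fp32():
    B, T, tokens_per_frame, D = 2, 4, 16, 32
    x = torch.randn(B, T, tokens_per_frame, D, dtype=torch.bfloat16)
    # One (shift, scale) pair per frame, broadcast over the token axis.
    shift = torch.randn(B, T, 1, D, dtype=torch.float32)
    scale = torch.randn(B, T, 1, D, dtype=torch.float32)
    norm = LayerNorm_FP32(D, eps=1e-6, elementwise_affine=False)
    out = modulate_fp32(norm, x, shift, scale)
    return out.dtype  # bfloat16: the result is cast back to x's dtype
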
class FinalLayer_FP32(nn.Module):
    """
    The final layer of DiT.
    """

    def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_patch = num_patch
        self.out_channels = out_channels
        self.adaln_tembed_dim = adaln_tembed_dim

        self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True))

    def forward(self, x, t, latent_shape):
        # timestep shape: [B, T, C]
        assert t.dtype == torch.float32
        B, N, C = x.shape
        T, _, _ = latent_shape

        with amp.autocast(get_device_type(), dtype=torch.float32):
            shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1)  # [B, T, 1, C]
            x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
        x = self.linear(x)
        return x
class FeedForwardSwiGLU(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.dim = dim
        self.hidden_dim = hidden_dim
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
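
# Worked example (illustrative, not part of the module): the SwiGLU hidden
# width for the default LongCat config (hidden_size=4096, mlp_ratio=4):
#   requested hidden_dim = 4096 * 4              = 16384
#   2/3 scaling          = int(2 * 16384 / 3)    = 10922
#   round up to 256      = 256 * ceil(10922/256) = 11008
def _example_swiglu_width():
    ffn = FeedForwardSwiGLU(dim=4096, hidden_dim=4096 * 4)
    assert ffn.hidden_dim == 11008
    return ffn.hidden_dim
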
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, t_embed_dim, frequency_embedding_size=256):
        super().__init__()
        self.t_embed_dim = t_embed_dim
        self.frequency_embedding_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, t_embed_dim, bias=True),
            nn.SiLU(),
            nn.Linear(t_embed_dim, t_embed_dim, bias=True),
        )

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
        freqs = freqs.to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, dtype):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        if t_freq.dtype != dtype:
            t_freq = t_freq.to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
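
# Illustrative sketch (not part of the module): the sinusoidal frequencies are
# f_i = max_period^(-i/half), so the embedding is [cos(t*f_i)..., sin(t*f_i)...].
def _example_timestep_embedding():
    t = torch.tensor([0.0, 500.0, 999.0])
    emb = TimestepEmbedder.timestep_embedding(t, dim=256)
    assert emb.shape == (3, 256)
    # t=0 gives cos(0)=1 in the first half and sin(0)=0 in the second half.
    assert torch.allclose(emb[0, :128], torch.ones(128))
    assert torch.allclose(emb[0, 128:], torch.zeros(128))
    return emb
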
class CaptionEmbedder(nn.Module):
    """
    Projects caption (text-encoder) embeddings into the model's hidden size.
    """

    def __init__(self, in_channels, hidden_size):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_size = hidden_size
        self.y_proj = nn.Sequential(
            nn.Linear(in_channels, hidden_size, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    def forward(self, caption):
        B, _, N, C = caption.shape  # expects [B, 1, N_token, C_in]
        caption = self.y_proj(caption)
        return caption
class PatchEmbed3D(nn.Module):
    """Video to Patch Embedding.

    Args:
        patch_size (tuple[int]): Patch token size. Default: (2, 4, 4).
        in_chans (int): Number of input video channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(
        self,
        patch_size=(2, 4, 4),
        in_chans=3,
        embed_dim=96,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.flatten = flatten

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # pad so that each dimension is divisible by the patch size
        _, _, D, H, W = x.size()
        if W % self.patch_size[2] != 0:
            x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
        if H % self.patch_size[1] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
        if D % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))

        B, C, T, H, W = x.shape
        x = self.proj(x)  # (B, C, T, H, W), with T/H/W downscaled by the patch size
        if self.norm is not None:
            D, Wh, Ww = x.size(2), x.size(3), x.size(4)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
        return x
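
# Worked example (illustrative, not part of the module): token count produced
# by PatchEmbed3D for a 16-channel latent video with patch_size=(1, 2, 2).
def _example_patch_embed_tokens():
    embed = PatchEmbed3D(patch_size=(1, 2, 2), in_chans=16, embed_dim=64)
    x = torch.randn(1, 16, 8, 32, 32)  # (B, C, T, H, W)
    tokens = embed(x)
    # N = (8/1) * (32/2) * (32/2) = 8 * 16 * 16 = 2048 tokens
    assert tokens.shape == (1, 2048, 64)
    return tokens.shape
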
class LongCatSingleStreamBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: int,
        adaln_tembed_dim: int,
        enable_flashattn3: bool = False,
        enable_flashattn2: bool = False,
        enable_xformers: bool = False,
        enable_bsa: bool = False,
        bsa_params=None,
        cp_split_hw=None
    ):
        super().__init__()

        self.hidden_size = hidden_size

        # scale and gate modulation
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True)
        )

        self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
        self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
        self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True)

        self.attn = Attention(
            dim=hidden_size,
            num_heads=num_heads,
            enable_flashattn3=enable_flashattn3,
            enable_flashattn2=enable_flashattn2,
            enable_xformers=enable_xformers,
            enable_bsa=enable_bsa,
            bsa_params=bsa_params,
            cp_split_hw=cp_split_hw
        )
        self.cross_attn = MultiHeadCrossAttention(
            dim=hidden_size,
            num_heads=num_heads,
            enable_flashattn3=enable_flashattn3,
            enable_flashattn2=enable_flashattn2,
            enable_xformers=enable_xformers,
        )
        self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio))

    def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False):
        """
        x: [B, N, C]
        y: [1, N_valid_tokens, C]
        t: [B, T, C_t]
        y_seqlen: [B]; a list of per-item text token counts
        latent_shape: latent shape of a single item
        """
        x_dtype = x.dtype

        B, N, C = x.shape
        T, _, _ = latent_shape  # N != T*H*W in case of CP split on H*W.

        # compute modulation params in fp32
        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
            shift_msa, scale_msa, gate_msa, \
                shift_mlp, scale_mlp, gate_mlp = \
                self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1)  # [B, T, 1, C]

        # self attn with modulation
        x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C)

        if kv_cache is not None:
            kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device))
            attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache)
        else:
            attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv)

        if return_kv:
            x_s, kv_cache = attn_outputs
        else:
            x_s = attn_outputs

        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
            x = x + (gate_msa * x_s.view(B, -1, N // T, C)).view(B, -1, C)  # [B, N, C]
        x = x.to(x_dtype)

        # cross attn
        if not skip_crs_attn:
            if kv_cache is not None:
                num_cond_latents = None
            x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape)

        # ffn with modulation
        x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N // T, C), shift_mlp, scale_mlp).view(B, -1, C)
        x_s = self.ffn(x_m)
        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
            x = x + (gate_mlp * x_s.view(B, -1, N // T, C)).view(B, -1, C)  # [B, N, C]
        x = x.to(x_dtype)

        if return_kv:
            return x, kv_cache
        else:
            return x
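
# Illustrative sketch (not part of the module): why the block reshapes tokens
# to (B, T, N//T, C) before gating. Each frame t gets its own (shift, scale,
# gate), broadcast over that frame's spatial tokens. Sizes are hypothetical.
def _example_per_frame_gating():
    B, T, tokens_per_frame, C = 1, 4, 6, 8
    x_s = torch.randn(B, T * tokens_per_frame, C)
    gate = torch.randn(B, T, 1, C)  # one gate per frame, as chunked from adaLN
    gated = (gate * x_s.view(B, T, tokens_per_frame, C)).view(B, -1, C)
    assert gated.shape == (B, T * tokens_per_frame, C)
    return gated
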
class LongCatVideoTransformer3DModel(torch.nn.Module):
    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 16,
        hidden_size: int = 4096,
        depth: int = 48,
        num_heads: int = 32,
        caption_channels: int = 4096,
        mlp_ratio: int = 4,
        adaln_tembed_dim: int = 512,
        frequency_embedding_size: int = 256,
        # default params
        patch_size: Tuple[int] = (1, 2, 2),
        # attention config
        enable_flashattn3: bool = False,
        enable_flashattn2: bool = True,
        enable_xformers: bool = False,
        enable_bsa: bool = False,
        bsa_params: dict = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]},
        cp_split_hw: Optional[List[int]] = [1, 1],
        text_tokens_zero_pad: bool = True,
    ) -> None:
        super().__init__()

        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cp_split_hw = cp_split_hw

        self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
        self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size)
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels,
            hidden_size=hidden_size,
        )

        self.blocks = nn.ModuleList(
            [
                LongCatSingleStreamBlock(
                    hidden_size=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    adaln_tembed_dim=adaln_tembed_dim,
                    enable_flashattn3=enable_flashattn3,
                    enable_flashattn2=enable_flashattn2,
                    enable_xformers=enable_xformers,
                    enable_bsa=enable_bsa,
                    bsa_params=bsa_params,
                    cp_split_hw=cp_split_hw
                )
                for _ in range(depth)
            ]
        )

        self.final_layer = FinalLayer_FP32(
            hidden_size,
            np.prod(self.patch_size),
            out_channels,
            adaln_tembed_dim,
        )

        self.gradient_checkpointing = False
        self.text_tokens_zero_pad = text_tokens_zero_pad

        self.lora_dict = {}
        self.active_loras = []
    def enable_loras(self, lora_key_list=()):
        self.disable_all_loras()

        module_loras = {}  # {module_name: [lora1, lora2, ...]}
        model_device = next(self.parameters()).device
        model_dtype = next(self.parameters()).dtype

        for lora_key in lora_key_list:
            if lora_key in self.lora_dict:
                for lora in self.lora_dict[lora_key].loras:
                    lora.to(model_device, dtype=model_dtype, non_blocking=True)
                    module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".")
                    if module_name not in module_loras:
                        module_loras[module_name] = []
                    module_loras[module_name].append(lora)
                self.active_loras.append(lora_key)

        for module_name, loras in module_loras.items():
            module = self._get_module_by_name(module_name)
            if not hasattr(module, 'org_forward'):
                module.org_forward = module.forward
            module.forward = self._create_multi_lora_forward(module, loras)

    def _create_multi_lora_forward(self, module, loras):
        def multi_lora_forward(x, *args, **kwargs):
            weight_dtype = x.dtype
            org_output = module.org_forward(x, *args, **kwargs)

            total_lora_output = 0
            for lora in loras:
                if lora.use_lora:
                    lx = lora.lora_down(x.to(lora.lora_down.weight.dtype))
                    lx = lora.lora_up(lx)
                    lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale
                    total_lora_output += lora_output

            return org_output + total_lora_output

        return multi_lora_forward

    def _get_module_by_name(self, module_name):
        try:
            module = self
            for part in module_name.split('.'):
                module = getattr(module, part)
            return module
        except AttributeError as e:
            raise ValueError(f"Cannot find module: {module_name}, error: {e}")

    def disable_all_loras(self):
        for name, module in self.named_modules():
            if hasattr(module, 'org_forward'):
                module.forward = module.org_forward
                delattr(module, 'org_forward')

        for lora_key, lora_network in self.lora_dict.items():
            for lora in lora_network.loras:
                lora.to("cpu")

        self.active_loras.clear()

    def enable_bsa(self):
        for block in self.blocks:
            block.attn.enable_bsa = True

    def disable_bsa(self):
        for block in self.blocks:
            block.attn.enable_bsa = False
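
    # Illustrative sketch (not part of the module): what the patched forward
    # computes. `DummyLoRA` below is a hypothetical stand-in exposing only the
    # attributes the hook reads (use_lora, lora_down, lora_up, multiplier,
    # alpha_scale); after patching, the module output becomes
    #   y = org_forward(x) + sum_i up_i(down_i(x)) * multiplier_i * alpha_scale_i
    @staticmethod
    def _example_multi_lora_forward(model, linear_name="blocks.0.attn.proj", dim=4096):
        class DummyLoRA(torch.nn.Module):
            def __init__(self, dim, rank=4):
                super().__init__()
                self.use_lora = True
                self.multiplier = 1.0
                self.alpha_scale = 1.0
                self.lora_down = torch.nn.Linear(dim, rank, bias=False)
                self.lora_up = torch.nn.Linear(rank, dim, bias=False)

        module = model._get_module_by_name(linear_name)  # hypothetical module path
        module.org_forward = module.forward
        module.forward = model._create_multi_lora_forward(module, [DummyLoRA(dim)])
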
    def forward(
        self,
        hidden_states,
        timestep,
        encoder_hidden_states,
        encoder_attention_mask=None,
        num_cond_latents=0,
        return_kv=False,
        kv_cache_dict=None,
        skip_crs_attn=False,
        offload_kv_cache=False,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
        if kv_cache_dict is None:
            kv_cache_dict = {}

        B, _, T, H, W = hidden_states.shape

        N_t = T // self.patch_size[0]
        N_h = H // self.patch_size[1]
        N_w = W // self.patch_size[2]

        assert self.patch_size[0] == 1, "Currently, 3D x_embedder should not compress the temporal dimension."

        # expand the shape of timestep from [B] to [B, T]
        if len(timestep.shape) == 1:
            timestep = timestep.unsqueeze(1).expand(-1, N_t).clone()  # [B, T]
            timestep[:, :num_cond_latents] = 0

        dtype = hidden_states.dtype
        hidden_states = hidden_states.to(dtype)
        timestep = timestep.to(dtype)
        encoder_hidden_states = encoder_hidden_states.to(dtype)

        hidden_states = self.x_embedder(hidden_states)  # [B, N, C]

        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
            t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1)  # [B, T, C_t]

        encoder_hidden_states = self.y_embedder(encoder_hidden_states)  # [B, 1, N_token, C]

        if self.text_tokens_zero_pad and encoder_attention_mask is not None:
            # zero out padded text tokens, then mark every position as valid
            encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None]
            encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype)

        if encoder_attention_mask is not None:
            encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1)
            # pack the valid tokens of all items into one sequence: [1, N_valid_tokens, C]
            encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1])
            y_seqlens = encoder_attention_mask.sum(dim=1).tolist()  # [B]
        else:
            y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0]
            encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1])

        # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
        #     hidden_states = rearrange(hidden_states, "B (T H W) C -> B T H W C", T=N_t, H=N_h, W=N_w)
        #     hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw)
        #     hidden_states = rearrange(hidden_states, "B T H W C -> B (T H W) C")

        # blocks
        kv_cache_dict_ret = {}
        for i, block in enumerate(self.blocks):
            block_outputs = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                x=hidden_states,
                y=encoder_hidden_states,
                t=t,
                y_seqlen=y_seqlens,
                latent_shape=(N_t, N_h, N_w),
                num_cond_latents=num_cond_latents,
                return_kv=return_kv,
                kv_cache=kv_cache_dict.get(i, None),
                skip_crs_attn=skip_crs_attn,
            )

            if return_kv:
                hidden_states, kv_cache = block_outputs
                if offload_kv_cache:
                    kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu())
                else:
                    kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous())
            else:
                hidden_states = block_outputs

        hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w))  # [B, N, C=T_p*H_p*W_p*C_out]

        # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
        #     hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw)

        hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w)  # [B, C_out, T, H, W]

        # cast to float32 for better accuracy
        hidden_states = hidden_states.to(torch.float32)

        if return_kv:
            return hidden_states, kv_cache_dict_ret
        else:
            return hidden_states
    def unpatchify(self, x, N_t, N_h, N_w):
        """
        Args:
            x (torch.Tensor): of shape [B, N, C]

        Return:
            x (torch.Tensor): of shape [B, C_out, T, H, W]
        """
        T_p, H_p, W_p = self.patch_size
        x = rearrange(
            x,
            "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
            N_t=N_t,
            N_h=N_h,
            N_w=N_w,
            T_p=T_p,
            H_p=H_p,
            W_p=W_p,
            C_out=self.out_channels,
        )
        return x

    @staticmethod
    def state_dict_converter():
        return LongCatVideoTransformer3DModelDictConverter()
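
# Illustrative sketch (not part of the module): unpatchify inverts the patch
# flattening. With patch_size=(1, 2, 2) and out_channels=16, each token carries
# 1*2*2*16 = 64 channel values for one 1x2x2 pixel block.
def _example_unpatchify_shapes():
    c_out = 16
    N_t, N_h, N_w = 8, 16, 16
    tokens = torch.randn(1, N_t * N_h * N_w, 1 * 2 * 2 * c_out)
    video = rearrange(
        tokens,
        "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
        N_t=N_t, N_h=N_h, N_w=N_w, T_p=1, H_p=2, W_p=2, C_out=c_out,
    )
    assert video.shape == (1, 16, 8, 32, 32)
    return video.shape
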
class LongCatVideoTransformer3DModelDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict

    def from_civitai(self, state_dict):
        return state_dict
@@ -1,195 +0,0 @@
import torch
from .sd_unet import SDUNet
from .sdxl_unet import SDXLUNet
from .sd_text_encoder import SDTextEncoder
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sd3_dit import SD3DiT
from .hunyuan_dit import HunyuanDiT


class LoRAFromCivitai:
    def __init__(self):
        self.supported_model_classes = []
        self.lora_prefix = []
        self.renamed_lora_prefix = {}
        self.special_keys = {}

    def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
        renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
        state_dict_ = {}
        for key in state_dict:
            if ".lora_up" not in key:
                continue
            if not key.startswith(lora_prefix):
                continue
            weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
            weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
                weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
            for special_key in self.special_keys:
                target_name = target_name.replace(special_key, self.special_keys[special_key])
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_

    def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
        state_dict_model = model.state_dict()
        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
        if model_resource == "diffusers":
            state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
        elif model_resource == "civitai":
            state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
        if len(state_dict_lora) > 0:
            print(f" {len(state_dict_lora)} tensors are updated.")
        for name in state_dict_lora:
            state_dict_model[name] += state_dict_lora[name].to(
                dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
        model.load_state_dict(state_dict_model)

    def match(self, model, state_dict_lora):
        for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
            if not isinstance(model, model_class):
                continue
            state_dict_model = model.state_dict()
            for model_resource in ["diffusers", "civitai"]:
                try:
                    state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
                    converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
                        else model.__class__.state_dict_converter().from_civitai
                    state_dict_lora_ = converter_fn(state_dict_lora_)
                    if len(state_dict_lora_) == 0:
                        continue
                    for name in state_dict_lora_:
                        if name not in state_dict_model:
                            break
                    else:
                        return lora_prefix, model_resource
                except Exception:
                    pass
        return None


class SDLoRAFromCivitai(LoRAFromCivitai):
    def __init__(self):
        super().__init__()
        self.supported_model_classes = [SDUNet, SDTextEncoder]
        self.lora_prefix = ["lora_unet_", "lora_te_"]
        self.special_keys = {
            "down.blocks": "down_blocks",
            "up.blocks": "up_blocks",
            "mid.block": "mid_block",
            "proj.in": "proj_in",
            "proj.out": "proj_out",
            "transformer.blocks": "transformer_blocks",
            "to.q": "to_q",
            "to.k": "to_k",
            "to.v": "to_v",
            "to.out": "to_out",
            "text.model": "text_model",
            "self.attn.q.proj": "self_attn.q_proj",
            "self.attn.k.proj": "self_attn.k_proj",
            "self.attn.v.proj": "self_attn.v_proj",
            "self.attn.out.proj": "self_attn.out_proj",
            "input.blocks": "model.diffusion_model.input_blocks",
            "middle.block": "model.diffusion_model.middle_block",
            "output.blocks": "model.diffusion_model.output_blocks",
        }


class SDXLLoRAFromCivitai(LoRAFromCivitai):
    def __init__(self):
        super().__init__()
        self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
        self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
        self.renamed_lora_prefix = {"lora_te2_": "2"}
        self.special_keys = {
            "down.blocks": "down_blocks",
            "up.blocks": "up_blocks",
            "mid.block": "mid_block",
            "proj.in": "proj_in",
            "proj.out": "proj_out",
            "transformer.blocks": "transformer_blocks",
            "to.q": "to_q",
            "to.k": "to_k",
            "to.v": "to_v",
            "to.out": "to_out",
            "text.model": "conditioner.embedders.0.transformer.text_model",
            "self.attn.q.proj": "self_attn.q_proj",
            "self.attn.k.proj": "self_attn.k_proj",
            "self.attn.v.proj": "self_attn.v_proj",
            "self.attn.out.proj": "self_attn.out_proj",
            "input.blocks": "model.diffusion_model.input_blocks",
            "middle.block": "model.diffusion_model.middle_block",
            "output.blocks": "model.diffusion_model.output_blocks",
            "2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
        }


class GeneralLoRAFromPeft:
    def __init__(self):
        self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT]

    def convert_state_dict(self, state_dict, alpha=1.0, device="cuda", torch_dtype=torch.float16):
        state_dict_ = {}
        for key in state_dict:
            if ".lora_B." not in key:
                continue
            weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
            weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze(3).squeeze(2)
                weight_down = weight_down.squeeze(3).squeeze(2)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            keys = key.split(".")
            keys.pop(keys.index("lora_B") + 1)
            keys.pop(keys.index("lora_B"))
            target_name = ".".join(keys)
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_

    def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
        state_dict_model = model.state_dict()
        for name, param in state_dict_model.items():
            torch_dtype = param.dtype
            device = param.device
            break
        state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, device=device, torch_dtype=torch_dtype)
        if len(state_dict_lora) > 0:
            print(f" {len(state_dict_lora)} tensors are updated.")
        for name in state_dict_lora:
            state_dict_model[name] += state_dict_lora[name].to(
                dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
        model.load_state_dict(state_dict_model)

    def match(self, model, state_dict_lora):
        for model_class in self.supported_model_classes:
            if not isinstance(model, model_class):
                continue
            state_dict_model = model.state_dict()
            try:
                state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0)
                if len(state_dict_lora_) == 0:
                    continue
                for name in state_dict_lora_:
                    if name not in state_dict_model:
                        break
                else:
                    return "", ""
            except Exception:
                pass
        return None
diffsynth/models/ltx2_audio_vae.py: new file, 1408 lines (diff suppressed because it is too large)
diffsynth/models/ltx2_common.py: new file, 371 lines
@@ -0,0 +1,371 @@
from dataclasses import dataclass
from typing import NamedTuple, Protocol, Tuple
import torch
from torch import nn
from enum import Enum


class VideoPixelShape(NamedTuple):
    """
    Shape of the tensor representing the video pixel array. Assumes BGR channel format.
    """

    batch: int
    frames: int
    height: int
    width: int
    fps: float


class SpatioTemporalScaleFactors(NamedTuple):
    """
    Describes the spatiotemporal downscaling between decoded video space and
    the corresponding VAE latent grid.
    """

    time: int
    width: int
    height: int

    @classmethod
    def default(cls) -> "SpatioTemporalScaleFactors":
        return cls(time=8, width=32, height=32)


VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()


class VideoLatentShape(NamedTuple):
    """
    Shape of the tensor representing video in VAE latent space.
    The latent representation is a 5D tensor with dimensions ordered as
    (batch, channels, frames, height, width). Spatial and temporal dimensions
    are downscaled relative to pixel space according to the VAE's scale factors.
    """

    batch: int
    channels: int
    frames: int
    height: int
    width: int

    def to_torch_shape(self) -> torch.Size:
        return torch.Size([self.batch, self.channels, self.frames, self.height, self.width])

    @staticmethod
    def from_torch_shape(shape: torch.Size) -> "VideoLatentShape":
        return VideoLatentShape(
            batch=shape[0],
            channels=shape[1],
            frames=shape[2],
            height=shape[3],
            width=shape[4],
        )

    def mask_shape(self) -> "VideoLatentShape":
        return self._replace(channels=1)

    @staticmethod
    def from_pixel_shape(
        shape: VideoPixelShape,
        latent_channels: int = 128,
        scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS,
    ) -> "VideoLatentShape":
        # Use named fields: the tuple order is (time, width, height), so indexing
        # [1] for height and [2] for width would swap the factors.
        frames = (shape.frames - 1) // scale_factors.time + 1
        height = shape.height // scale_factors.height
        width = shape.width // scale_factors.width

        return VideoLatentShape(
            batch=shape.batch,
            channels=latent_channels,
            frames=frames,
            height=height,
            width=width,
        )

    def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape":
        return self._replace(
            channels=3,
            frames=(self.frames - 1) * scale_factors.time + 1,
            height=self.height * scale_factors.height,
            width=self.width * scale_factors.width,
        )
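
# Worked example (illustrative, not part of the module): pixel -> latent ->
# pixel round trip with the default (time=8, width=32, height=32) factors.
def _example_video_latent_shapes():
    pixels = VideoPixelShape(batch=1, frames=121, height=512, width=768, fps=24.0)
    latents = VideoLatentShape.from_pixel_shape(pixels)
    # frames: (121 - 1) // 8 + 1 = 16; height: 512 // 32 = 16; width: 768 // 32 = 24
    assert latents == VideoLatentShape(batch=1, channels=128, frames=16, height=16, width=24)
    # upscale() inverts the causal frame formula: (16 - 1) * 8 + 1 = 121
    assert latents.upscale().frames == 121
    return latents
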
class AudioLatentShape(NamedTuple):
    """
    Shape of audio in VAE latent space: (batch, channels, frames, mel_bins).
    mel_bins is the number of frequency bins from the mel-spectrogram encoding.
    """

    batch: int
    channels: int
    frames: int
    mel_bins: int

    def to_torch_shape(self) -> torch.Size:
        return torch.Size([self.batch, self.channels, self.frames, self.mel_bins])

    def mask_shape(self) -> "AudioLatentShape":
        return self._replace(channels=1, mel_bins=1)

    @staticmethod
    def from_torch_shape(shape: torch.Size) -> "AudioLatentShape":
        return AudioLatentShape(
            batch=shape[0],
            channels=shape[1],
            frames=shape[2],
            mel_bins=shape[3],
        )

    @staticmethod
    def from_duration(
        batch: int,
        duration: float,
        channels: int = 8,
        mel_bins: int = 16,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
    ) -> "AudioLatentShape":
        latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor)

        return AudioLatentShape(
            batch=batch,
            channels=channels,
            frames=round(duration * latents_per_second),
            mel_bins=mel_bins,
        )

    @staticmethod
    def from_video_pixel_shape(
        shape: VideoPixelShape,
        channels: int = 8,
        mel_bins: int = 16,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
    ) -> "AudioLatentShape":
        return AudioLatentShape.from_duration(
            batch=shape.batch,
            duration=float(shape.frames) / float(shape.fps),
            channels=channels,
            mel_bins=mel_bins,
            sample_rate=sample_rate,
            hop_length=hop_length,
            audio_latent_downsample_factor=audio_latent_downsample_factor,
        )
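
# Worked example (illustrative, not part of the module): with the defaults,
# latents_per_second = 16000 / 160 / 4 = 25, so audio matching a 120-frame
# 24 fps video (5 seconds) yields 125 audio latent frames.
def _example_audio_latent_shape():
    video = VideoPixelShape(batch=1, frames=120, height=512, width=768, fps=24.0)
    audio = AudioLatentShape.from_video_pixel_shape(video)
    assert audio == AudioLatentShape(batch=1, channels=8, frames=125, mel_bins=16)
    return audio
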
@dataclass(frozen=True)
class LatentState:
    """
    State of latents during the diffusion denoising process.
    Attributes:
        latent: The current noisy latent tensor being denoised.
        denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising).
        positions: Positional indices for each latent element, used for positional embeddings.
        clean_latent: Initial state of the latent before denoising; may include conditioning latents.
    """

    latent: torch.Tensor
    denoise_mask: torch.Tensor
    positions: torch.Tensor
    clean_latent: torch.Tensor

    def clone(self) -> "LatentState":
        return LatentState(
            latent=self.latent.clone(),
            denoise_mask=self.denoise_mask.clone(),
            positions=self.positions.clone(),
            clean_latent=self.clean_latent.clone(),
        )


class NormType(Enum):
    """Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""

    GROUP = "group"
    PIXEL = "pixel"


class PixelNorm(nn.Module):
    """
    Per-pixel (per-location) RMS normalization layer.
    For each element along the chosen dimension, this layer normalizes the tensor
    by the root-mean-square of its values across that dimension:
        y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
    """

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        """
        Args:
            dim: Dimension along which to compute the RMS (typically channels).
            eps: Small constant added for numerical stability.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply RMS normalization along the configured dimension.
        """
        # Compute mean of squared values along `dim`, keep dimensions for broadcasting.
        mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
        # Normalize by the root-mean-square (RMS).
        rms = torch.sqrt(mean_sq + self.eps)
        return x / rms
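
# Illustrative sketch (not part of the module): after PixelNorm, every spatial
# location has (approximately) unit RMS across the channel dimension.
def _example_pixel_norm():
    norm = PixelNorm(dim=1)
    x = torch.randn(2, 128, 4, 8, 8) * 5.0  # (B, C, T, H, W)
    y = norm(x)
    rms = torch.sqrt(torch.mean(y**2, dim=1))
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)
    return y
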
def build_normalization_layer(
    in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
) -> nn.Module:
    """
    Create a normalization layer based on the normalization type.
    Args:
        in_channels: Number of input channels
        num_groups: Number of groups for group normalization
        normtype: Type of normalization: "group" or "pixel"
    Returns:
        A normalization layer
    """
    if normtype == NormType.GROUP:
        return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
    if normtype == NormType.PIXEL:
        return PixelNorm(dim=1, eps=1e-6)
    raise ValueError(f"Invalid normalization type: {normtype}")


def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
    """Root-mean-square (RMS) normalize `x` over its last dimension.
    Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
    shape and forwards `weight` and `eps`.
    """
    return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)
@dataclass(frozen=True)
class Modality:
    """
    Input data for a single modality (video or audio) in the transformer.
    Bundles the latent tokens, timestep embeddings, positional information,
    and text conditioning context for processing by the diffusion transformer.
    """

    latent: torch.Tensor  # (B, T, D): batch size, number of tokens, input dimension
    timesteps: torch.Tensor  # (B, T): one timestep per token
    positions: torch.Tensor  # (B, 3, T) for video: one 3D coordinate per token
    context: torch.Tensor
    enabled: bool = True
    context_mask: torch.Tensor | None = None


def to_denoised(
    sample: torch.Tensor,
    velocity: torch.Tensor,
    sigma: float | torch.Tensor,
    calc_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Convert the sample and its denoising velocity to a denoised sample.
    Returns:
        Denoised sample
    """
    if isinstance(sigma, torch.Tensor):
        sigma = sigma.to(calc_dtype)
    return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)
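
# Illustrative sketch (not part of the module): in this velocity
# parameterization, x_sigma = x_0 + sigma * velocity, so to_denoised recovers
# x_0 exactly when given the true velocity.
def _example_to_denoised():
    x0 = torch.randn(1, 16, 64)
    velocity = torch.randn_like(x0)
    sigma = 0.7
    x_sigma = x0 + sigma * velocity
    recovered = to_denoised(x_sigma, velocity, sigma)
    assert torch.allclose(recovered, x0, atol=1e-5)
    return recovered
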
class Patchifier(Protocol):
    """
    Protocol for patchifiers that convert latent tensors into patches and assemble them back.
    """

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Convert latent tensors into flattened patch tokens.
        Args:
            latents: Latent tensor to patchify.
        Returns:
            Flattened patch tokens tensor.
        """
        ...

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: AudioLatentShape | VideoLatentShape,
    ) -> torch.Tensor:
        """
        Converts latent tensors between spatio-temporal formats and flattened sequence representations.
        Args:
            latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
            output_shape: Shape of the output tensor; either an AudioLatentShape or a VideoLatentShape.
        Returns:
            Dense latent tensor restored from the flattened representation.
        """
        ...

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """
        Returns the patch size as a tuple of (temporal, height, width) dimensions.
        """
        ...

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """
        Compute metadata describing where each latent patch resides within the
        grid specified by `output_shape`.
        Args:
            output_shape: Target grid layout for the patches.
            device: Target device for the returned tensor.
        Returns:
            Tensor containing patch coordinate metadata such as spatial or temporal intervals.
        """
        ...


def get_pixel_coords(
    latent_coords: torch.Tensor,
    scale_factors: SpatioTemporalScaleFactors,
    causal_fix: bool = False,
) -> torch.Tensor:
    """
    Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
    each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
    Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
    Args:
        latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
        scale_factors: SpatioTemporalScaleFactors tuple `(time, width, height)` with integer
            scale factors applied per axis.
        causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
            that treat frame zero differently still yield non-negative timestamps.
    """
    # Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
    broadcast_shape = [1] * latent_coords.ndim
    broadcast_shape[1] = -1  # axis dimension corresponds to (frame/time, height, width)
    scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)

    # Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
    pixel_coords = latent_coords * scale_tensor

    if causal_fix:
        # VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
        # Shift and clamp to keep the first-frame timestamps causal and non-negative.
        pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)

    return pixel_coords
diffsynth/models/ltx2_dit.py: new file, 1451 lines (diff suppressed because it is too large)
diffsynth/models/ltx2_text_encoder.py: new file, 366 lines
@@ -0,0 +1,366 @@
|
|||||||
|
import torch
|
||||||
|
from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer
|
||||||
|
from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention,
|
||||||
|
FeedForward)
|
||||||
|
from .ltx2_common import rms_norm
|
||||||
|
|
||||||
|
|
||||||
|
class LTX2TextEncoder(Gemma3ForConditionalGeneration):
|
||||||
|
def __init__(self):
|
||||||
|
config = Gemma3Config(
|
||||||
|
**{
|
||||||
|
"architectures": ["Gemma3ForConditionalGeneration"],
|
||||||
|
"boi_token_index": 255999,
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"eoi_token_index": 256000,
|
||||||
|
"eos_token_id": [1, 106],
|
||||||
|
"image_token_index": 262144,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"mm_tokens_per_image": 256,
|
||||||
|
"model_type": "gemma3",
|
||||||
|
"text_config": {
|
||||||
|
"_sliding_window_pattern": 6,
|
||||||
|
"attention_bias": False,
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"attn_logit_softcapping": None,
|
||||||
|
"cache_implementation": "hybrid",
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"final_logit_softcapping": None,
|
||||||
|
"head_dim": 256,
|
||||||
|
"hidden_activation": "gelu_pytorch_tanh",
|
||||||
|
"hidden_size": 3840,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 15360,
|
||||||
|
"layer_types": [
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention"
|
||||||
|
],
|
||||||
|
"max_position_embeddings": 131072,
|
||||||
|
"model_type": "gemma3_text",
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"num_hidden_layers": 48,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"query_pre_attn_scalar": 256,
|
||||||
|
"rms_norm_eps": 1e-06,
|
||||||
|
"rope_local_base_freq": 10000,
|
||||||
|
"rope_scaling": {
|
||||||
|
"factor": 8.0,
|
||||||
|
"rope_type": "linear"
|
||||||
|
},
|
||||||
|
"rope_theta": 1000000,
|
||||||
|
"sliding_window": 1024,
|
||||||
|
"sliding_window_pattern": 6,
|
||||||
|
"use_bidirectional_attention": False,
|
||||||
|
"use_cache": True,
|
||||||
|
"vocab_size": 262208
|
||||||
|
},
|
||||||
|
"transformers_version": "4.57.3",
|
||||||
|
"vision_config": {
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"hidden_act": "gelu_pytorch_tanh",
|
||||||
|
"hidden_size": 1152,
|
||||||
|
"image_size": 896,
|
||||||
|
"intermediate_size": 4304,
|
||||||
|
"layer_norm_eps": 1e-06,
|
||||||
|
"model_type": "siglip_vision_model",
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"num_channels": 3,
|
||||||
|
"num_hidden_layers": 27,
|
||||||
|
"patch_size": 14,
|
||||||
|
"vision_use_head": False
|
||||||
|
}
|
||||||
|
})
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVGemmaTokenizer:
|
||||||
|
"""
|
||||||
|
Tokenizer wrapper for Gemma models compatible with LTXV processes.
|
||||||
|
This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders,
|
||||||
|
ensuring correct settings and output formatting for downstream consumption.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tokenizer_path: str, max_length: int = 1024):
|
||||||
|
"""
|
||||||
|
Initialize the tokenizer.
|
||||||
|
Args:
|
||||||
|
tokenizer_path (str): Path to the pretrained tokenizer files or model directory.
|
||||||
|
max_length (int, optional): Max sequence length for encoding. Defaults to 256.
|
||||||
|
"""
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
tokenizer_path, local_files_only=True, model_max_length=max_length
|
||||||
|
)
|
||||||
|
# Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much.
|
||||||
|
self.tokenizer.padding_side = "left"
|
||||||
|
if self.tokenizer.pad_token is None:
|
||||||
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||||
|
|
||||||
|
self.max_length = max_length
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]:
|
||||||
|
"""
|
||||||
|
Tokenize the given text and return token IDs and attention weights.
|
||||||
|
Args:
|
||||||
|
text (str): The input string to tokenize.
|
||||||
|
return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples.
|
||||||
|
If False (default), omits the indices.
|
||||||
|
Returns:
|
||||||
|
dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]:
|
||||||
|
A dictionary with a "gemma" key mapping to:
|
||||||
|
- a list of (token_id, attention_mask) tuples if return_word_ids is False;
|
||||||
|
- a list of (token_id, attention_mask, index) tuples if return_word_ids is True.
|
||||||
|
Example:
|
||||||
|
>>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8)
|
||||||
|
>>> tokenizer.tokenize_with_weights("hello world")
|
||||||
|
{'gemma': [(1234, 1), (5678, 1), (2, 0), ...]}
|
||||||
|
"""
|
||||||
|
text = text.strip()
|
||||||
|
encoded = self.tokenizer(
|
||||||
|
text,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=self.max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
input_ids = encoded.input_ids
|
||||||
|
attention_mask = encoded.attention_mask
|
||||||
|
tuples = [
|
||||||
|
(token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0], strict=True))
|
||||||
|
]
|
||||||
|
out = {"gemma": tuples}
|
||||||
|
|
||||||
|
if not return_word_ids:
|
||||||
|
# Return only (token_id, attention_mask) pairs, omitting token position
|
||||||
|
out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()}
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
    """
    Feature extractor module for Gemma models.
    This module applies a single linear projection to the input tensor.
    It expects a flattened feature tensor of shape (batch_size, 3840*49).
    The linear layer maps this to a (batch_size, 3840) embedding.
    Attributes:
        aggregate_embed (torch.nn.Linear): Linear projection layer.
    """

    def __init__(self) -> None:
        """
        Initialize the GemmaFeaturesExtractorProjLinear module.
        The input dimension is expected to be 3840 * 49, and the output is 3840.
        """
        super().__init__()
        self.aggregate_embed = torch.nn.Linear(3840 * 49, 3840, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the feature extractor.
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 3840 * 49).
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, 3840).
        """
        return self.aggregate_embed(x)

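# Shape sketch (illustrative only): the projection collapses 49 flattened
# Gemma feature vectors of width 3840 into one 3840-dim embedding per sample.
#
#   proj = GemmaFeaturesExtractorProjLinear()
#   feats = torch.randn(2, 3840 * 49)   # (batch_size, 188160)
#   emb = proj(feats)                   # (batch_size, 3840)
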
class _BasicTransformerBlock1D(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
    ):
        super().__init__()

        self.attn1 = Attention(
            query_dim=dim,
            heads=heads,
            dim_head=dim_head,
            rope_type=rope_type,
        )

        self.ff = FeedForward(
            dim,
            dim_out=dim,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        pe: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Notice that normalization is always applied before the real computation in the following blocks.

        # 1. Normalization before self-attention
        norm_hidden_states = rms_norm(hidden_states)

        norm_hidden_states = norm_hidden_states.squeeze(1)

        # 2. Self-attention
        attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 3. Normalization before feed-forward
        norm_hidden_states = rms_norm(hidden_states)

        # 4. Feed-forward
        ff_output = self.ff(norm_hidden_states)

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states


class Embeddings1DConnector(torch.nn.Module):
    """
    Embeddings1DConnector applies 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
    other modalities). It supports rotary positional encoding (RoPE), optional causal temporal positioning, and can
    substitute padded positions with learnable registers. The module is highly configurable for head size, number of
    layers, and register usage.
    Args:
        attention_head_dim (int): Dimension of each attention head (default=128).
        num_attention_heads (int): Number of attention heads (default=30).
        num_layers (int): Number of transformer layers (default=2).
        positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0).
        positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[4096]).
        causal_temporal_positioning (bool): If True, uses causal attention (default=False).
        num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables
            register replacement. (default=128)
        rope_type (LTXRopeType): The RoPE variant to use (default=LTXRopeType.SPLIT).
        double_precision_rope (bool): Use double precision rope calculation (default=True).
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        attention_head_dim: int = 128,
        num_attention_heads: int = 30,
        num_layers: int = 2,
        positional_embedding_theta: float = 10000.0,
        positional_embedding_max_pos: list[int] | None = [4096],
        causal_temporal_positioning: bool = False,
        num_learnable_registers: int | None = 128,
        rope_type: LTXRopeType = LTXRopeType.SPLIT,
        double_precision_rope: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.inner_dim = num_attention_heads * attention_head_dim
        self.causal_temporal_positioning = causal_temporal_positioning
        self.positional_embedding_theta = positional_embedding_theta
        self.positional_embedding_max_pos = (
            positional_embedding_max_pos if positional_embedding_max_pos is not None else [1]
        )
        self.rope_type = rope_type
        self.double_precision_rope = double_precision_rope
        self.transformer_1d_blocks = torch.nn.ModuleList(
            [
                _BasicTransformerBlock1D(
                    dim=self.inner_dim,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    rope_type=rope_type,
                )
                for _ in range(num_layers)
            ]
        )

        self.num_learnable_registers = num_learnable_registers
        if self.num_learnable_registers:
            self.learnable_registers = torch.nn.Parameter(
                torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
            )

    def _replace_padded_with_learnable_registers(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        assert hidden_states.shape[1] % self.num_learnable_registers == 0, (
            f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers "
            f"{self.num_learnable_registers}."
        )

        num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers
        learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1))
        attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int()

        non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
        non_zero_nums = non_zero_hidden_states.shape[1]
        pad_length = hidden_states.shape[1] - non_zero_nums
        adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
        flipped_mask = torch.flip(attention_mask_binary, dims=[1])
        hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers

        attention_mask = torch.full_like(
            attention_mask,
            0.0,
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )

        return hidden_states, attention_mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of Embeddings1DConnector.
        Args:
            hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]).
            attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states).
        Returns:
            tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask.
        """
        if self.num_learnable_registers:
            hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)

        indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device)
        indices_grid = indices_grid[None, None, :]
        freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
        freqs_cis = precompute_freqs_cis(
            indices_grid=indices_grid,
            dim=self.inner_dim,
            out_dtype=hidden_states.dtype,
            theta=self.positional_embedding_theta,
            max_pos=self.positional_embedding_max_pos,
            num_attention_heads=self.num_attention_heads,
            rope_type=self.rope_type,
            freq_grid_generator=freq_grid_generator,
        )

        for block in self.transformer_1d_blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis)

        hidden_states = rms_norm(hidden_states)

        return hidden_states, attention_mask

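# Usage sketch (illustrative; mask semantics inferred from the -9000.0 check
# above): the connector takes (batch, seq_len, inner_dim) features whose
# length is a multiple of num_learnable_registers, plus an additive attention
# mask of shape (batch, 1, 1, seq_len) where valid positions are 0 and padded
# positions hold a large negative value.
#
#   connector = Embeddings1DConnector()                  # inner_dim = 30 * 128 = 3840
#   x = torch.randn(1, 256, 3840, dtype=torch.bfloat16)  # 256 = 2 * 128 registers
#   mask = torch.zeros(1, 1, 1, 256)                     # all positions valid
#   x, mask = connector(x, mask)
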
class LTX2TextEncoderPostModules(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
        self.embeddings_connector = Embeddings1DConnector()
        self.audio_embeddings_connector = Embeddings1DConnector()
313 diffsynth/models/ltx2_upsampler.py Normal file
@@ -0,0 +1,313 @@
import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange

from .ltx2_video_vae import LTX2VideoEncoder


class PixelShuffleND(torch.nn.Module):
    """
    N-dimensional pixel shuffle operation for upsampling tensors.
    Args:
        dims (int): Number of dimensions to apply pixel shuffle to.
            - 1: Temporal (e.g., frames)
            - 2: Spatial (e.g., height and width)
            - 3: Spatiotemporal (e.g., depth, height, width)
        upscale_factors (tuple[int, int, int], optional): Upscaling factors for each dimension.
            For dims=1, only the first value is used.
            For dims=2, the first two values are used.
            For dims=3, all three values are used.
    The input tensor is rearranged so that the channel dimension is split into
    smaller channels and upscaling factors, and the upscaling factors are moved
    into the corresponding spatial/temporal dimensions.
    Note:
        This operation is equivalent to the patchifier operation used by the models. Consider
        using this class instead.
    """

    def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)):
        super().__init__()
        assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
        self.dims = dims
        self.upscale_factors = upscale_factors

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.dims == 3:
            return rearrange(
                x,
                "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
                p1=self.upscale_factors[0],
                p2=self.upscale_factors[1],
                p3=self.upscale_factors[2],
            )
        elif self.dims == 2:
            return rearrange(
                x,
                "b (c p1 p2) h w -> b c (h p1) (w p2)",
                p1=self.upscale_factors[0],
                p2=self.upscale_factors[1],
            )
        elif self.dims == 1:
            return rearrange(
                x,
                "b (c p1) f h w -> b c (f p1) h w",
                p1=self.upscale_factors[0],
            )
        else:
            raise ValueError(f"Unsupported dims: {self.dims}")

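# Shape sketch (illustrative): with dims=3 and factors (2, 2, 2), channels are
# divided by 8 and each spatiotemporal axis is doubled.
#
#   shuffle = PixelShuffleND(dims=3, upscale_factors=(2, 2, 2))
#   x = torch.randn(1, 32, 3, 5, 5)   # 32 = 4 channels * 2 * 2 * 2
#   y = shuffle(x)                    # -> (1, 4, 6, 10, 10)
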
class ResBlock(torch.nn.Module):
    """
    Residual block with two convolutional layers, group normalization, and SiLU activation.
    Args:
        channels (int): Number of input and output channels.
        mid_channels (Optional[int]): Number of channels in the intermediate convolution layer. Defaults to `channels`
            if not specified.
        dims (int): Dimensionality of the convolution (2 for Conv2d, 3 for Conv3d). Defaults to 3.
    """

    def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
        super().__init__()
        if mid_channels is None:
            mid_channels = channels

        conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d

        self.conv1 = conv(channels, mid_channels, kernel_size=3, padding=1)
        self.norm1 = torch.nn.GroupNorm(32, mid_channels)
        self.conv2 = conv(mid_channels, channels, kernel_size=3, padding=1)
        self.norm2 = torch.nn.GroupNorm(32, channels)
        self.activation = torch.nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.activation(x + residual)
        return x


class BlurDownsample(torch.nn.Module):
    """
    Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
    Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
    """

    def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
        super().__init__()
        assert dims in (2, 3)
        assert isinstance(stride, int)
        assert stride >= 1
        assert kernel_size >= 3
        assert kernel_size % 2 == 1
        self.dims = dims
        self.stride = stride
        self.kernel_size = kernel_size

        # Separable binomial kernel built from the (kernel_size - 1)-th row of
        # Pascal's triangle ([1, 4, 6, 4, 1] for the default kernel_size=5).
        # It is used for anti-aliasing and provides a smooth approximation of a
        # Gaussian filter (often called a "binomial filter").
        # The 2D kernel is constructed as the outer product and normalized.
        k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
        k2d = k[:, None] @ k[None, :]
        k2d = (k2d / k2d.sum()).float()  # shape (kernel_size, kernel_size)
        self.register_buffer("kernel", k2d[None, None, :, :])  # (1, 1, kernel_size, kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.stride == 1:
            return x

        if self.dims == 2:
            return self._apply_2d(x)
        else:
            # dims == 3: apply per-frame on H,W
            b, _, f, _, _ = x.shape
            x = rearrange(x, "b c f h w -> (b f) c h w")
            x = self._apply_2d(x)
            h2, w2 = x.shape[-2:]
            x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
            return x

    def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor:
        c = x2d.shape[1]
        weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size)  # depthwise
        x2d = F.conv2d(x2d, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
        return x2d

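# Worked example of the kernel construction above (default kernel_size=5):
#   k   = [C(4,0), C(4,1), C(4,2), C(4,3), C(4,4)] = [1, 4, 6, 4, 1]
#   k2d = outer(k, k), whose entries sum to 16 * 16 = 256, so the normalized
#   center weight is 6*6/256 = 36/256 and each corner weight is 1/256.
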
def _rational_for_scale(scale: float) -> Tuple[int, int]:
    mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
    if float(scale) not in mapping:
        raise ValueError(f"Unsupported scale {scale}. Choose from {list(mapping.keys())}")
    return mapping[float(scale)]

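# Worked example (illustrative): scale=1.5 maps to (num, den) = (3, 2), i.e.
# pixel-shuffle up by 3x then blur + stride-2 downsample, so a 64x64 frame
# becomes 192x192 and then 96x96 (net 1.5x). scale=0.75 maps to (3, 4):
# 64 -> 192 -> 48 (net 0.75x).
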
class SpatialRationalResampler(torch.nn.Module):
    """
    Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
    downsample by 'den' using fixed blur + stride. Operates on H,W only.
    For dims==3, works per-frame for spatial scaling (temporal axis untouched).
    Args:
        mid_channels (`int`): Number of intermediate channels for the convolution layer
        scale (`float`): Spatial scaling factor. Supported values are:
            - 0.75: Downsample by 3/4 (reduce spatial size)
            - 1.5: Upsample by 3/2 (increase spatial size)
            - 2.0: Upsample by 2x (double spatial size)
            - 4.0: Upsample by 4x (quadruple spatial size)
            Any other value will raise a ValueError.
    """

    def __init__(self, mid_channels: int, scale: float):
        super().__init__()
        self.scale = float(scale)
        self.num, self.den = _rational_for_scale(self.scale)
        self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
        self.blur_down = BlurDownsample(dims=2, stride=self.den)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, _, f, _, _ = x.shape
        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.blur_down(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
        return x

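# Shape sketch (illustrative): channel and frame counts are preserved; only H
# and W change by the rational factor.
#
#   resampler = SpatialRationalResampler(mid_channels=64, scale=2.0)
#   x = torch.randn(1, 64, 8, 32, 32)  # (B, C, F, H, W)
#   y = resampler(x)                   # -> (1, 64, 8, 64, 64)
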
class LTX2LatentUpsampler(torch.nn.Module):
    """
    Model to upsample VAE latents spatially and/or temporally.
    Args:
        in_channels (`int`): Number of channels in the input latent
        mid_channels (`int`): Number of channels in the middle layers
        num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
        dims (`int`): Number of dimensions for convolutions (2 or 3)
        spatial_upsample (`bool`): Whether to spatially upsample the latent
        temporal_upsample (`bool`): Whether to temporally upsample the latent
        spatial_scale (`float`): Scale factor for spatial upsampling
        rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling
    """

    def __init__(
        self,
        in_channels: int = 128,
        mid_channels: int = 1024,
        num_blocks_per_stage: int = 4,
        dims: int = 3,
        spatial_upsample: bool = True,
        temporal_upsample: bool = False,
        spatial_scale: float = 2.0,
        rational_resampler: bool = True,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.num_blocks_per_stage = num_blocks_per_stage
        self.dims = dims
        self.spatial_upsample = spatial_upsample
        self.temporal_upsample = temporal_upsample
        self.spatial_scale = float(spatial_scale)
        self.rational_resampler = rational_resampler

        conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d

        self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1)
        self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
        self.initial_activation = torch.nn.SiLU()

        self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])

        if spatial_upsample and temporal_upsample:
            self.upsampler = torch.nn.Sequential(
                torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(3),
            )
        elif spatial_upsample:
            if rational_resampler:
                self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale)
            else:
                self.upsampler = torch.nn.Sequential(
                    torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
                    PixelShuffleND(2),
                )
        elif temporal_upsample:
            self.upsampler = torch.nn.Sequential(
                torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(1),
            )
        else:
            raise ValueError("Either spatial_upsample or temporal_upsample must be True")

        self.post_upsample_res_blocks = torch.nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )

        self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        b, _, f, _, _ = latent.shape

        if self.dims == 2:
            x = rearrange(latent, "b c f h w -> (b f) c h w")
            x = self.initial_conv(x)
            x = self.initial_norm(x)
            x = self.initial_activation(x)

            for block in self.res_blocks:
                x = block(x)

            x = self.upsampler(x)

            for block in self.post_upsample_res_blocks:
                x = block(x)

            x = self.final_conv(x)
            x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
        else:
            x = self.initial_conv(latent)
            x = self.initial_norm(x)
            x = self.initial_activation(x)

            for block in self.res_blocks:
                x = block(x)

            if self.temporal_upsample:
                x = self.upsampler(x)
                # Remove the first frame after upsampling.
                # This is done because the first latent frame encodes only a single pixel frame.
                x = x[:, :, 1:, :, :]
            elif isinstance(self.upsampler, SpatialRationalResampler):
                x = self.upsampler(x)
            else:
                x = rearrange(x, "b c f h w -> (b f) c h w")
                x = self.upsampler(x)
                x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)

            for block in self.post_upsample_res_blocks:
                x = block(x)

            x = self.final_conv(x)

        return x


def upsample_video(latent: torch.Tensor, video_encoder: LTX2VideoEncoder, upsampler: "LTX2LatentUpsampler") -> torch.Tensor:
    """
    Apply upsampling to the latent representation using the provided upsampler,
    with normalization and un-normalization based on the video encoder's per-channel statistics.
    Args:
        latent: Input latent tensor of shape [B, C, F, H, W].
        video_encoder: VideoEncoder with per_channel_statistics for normalization.
        upsampler: LTX2LatentUpsampler module to perform upsampling.
    Returns:
        torch.Tensor: Upsampled and re-normalized latent tensor.
    """
    latent = video_encoder.per_channel_statistics.un_normalize(latent)
    latent = upsampler(latent)
    latent = video_encoder.per_channel_statistics.normalize(latent)
    return latent

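# Usage sketch (illustrative; weight loading is assumed to happen elsewhere):
# a 2x spatial upsample of an LTX2 video latent via upsample_video.
#
#   upsampler = LTX2LatentUpsampler(spatial_upsample=True, spatial_scale=2.0)
#   latent_hr = upsample_video(latent, video_encoder, upsampler)
#   # latent: (B, 128, F, H, W) -> latent_hr: (B, 128, F, 2*H, 2*W)
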
2317 diffsynth/models/ltx2_video_vae.py Normal file
File diff suppressed because it is too large
112 diffsynth/models/model_loader.py Normal file
@@ -0,0 +1,112 @@
from ..core.loader import load_model, hash_model_file
from ..core.vram import AutoWrappedModule
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS
import importlib, json, torch


class ModelPool:
    def __init__(self):
        self.model = []
        self.model_name = []
        self.model_path = []

    def import_model_class(self, model_class):
        split = model_class.rfind(".")
        model_resource, model_class = model_class[:split], model_class[split+1:]
        model_class = importlib.import_module(model_resource).__getattribute__(model_class)
        return model_class

    def need_to_enable_vram_management(self, vram_config):
        return vram_config["offload_dtype"] is not None and vram_config["offload_device"] is not None

    def fetch_module_map(self, model_class, vram_config):
        if self.need_to_enable_vram_management(vram_config):
            if model_class in VRAM_MANAGEMENT_MODULE_MAPS:
                module_map = {
                    self.import_model_class(source): self.import_model_class(target)
                    for source, target in VRAM_MANAGEMENT_MODULE_MAPS[model_class].items()
                }
            else:
                module_map = {self.import_model_class(model_class): AutoWrappedModule}
        else:
            module_map = None
        return module_map

    def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
        model_class = self.import_model_class(config["model_class"])
        model_config = config.get("extra_kwargs", {})
        if "state_dict_converter" in config:
            state_dict_converter = self.import_model_class(config["state_dict_converter"])
        else:
            state_dict_converter = None
        module_map = self.fetch_module_map(config["model_class"], vram_config)
        model = load_model(
            model_class, path, model_config,
            vram_config["computation_dtype"], vram_config["computation_device"],
            state_dict_converter,
            use_disk_map=True,
            vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
            state_dict=state_dict,
        )
        return model

    def default_vram_config(self):
        vram_config = {
            "offload_dtype": None,
            "offload_device": None,
            "onload_dtype": torch.bfloat16,
            "onload_device": "cpu",
            "preparing_dtype": torch.bfloat16,
            "preparing_device": "cpu",
            "computation_dtype": torch.bfloat16,
            "computation_device": "cpu",
        }
        return vram_config

    def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
        print(f"Loading models from: {json.dumps(path, indent=4)}")
        if vram_config is None:
            vram_config = self.default_vram_config()
        model_hash = hash_model_file(path)
        loaded = False
        for config in MODEL_CONFIGS:
            if config["model_hash"] == model_hash:
                model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
                if clear_parameters: self.clear_parameters(model)
                self.model.append(model)
                model_name = config["model_name"]
                self.model_name.append(model_name)
                self.model_path.append(path)
                model_info = {"model_name": model_name, "model_class": config["model_class"], "extra_kwargs": config.get("extra_kwargs")}
                print(f"Loaded model: {json.dumps(model_info, indent=4)}")
                loaded = True
        if not loaded:
            raise ValueError(f"Cannot detect the model type. File: {path}. Model hash: {model_hash}")

    def fetch_model(self, model_name, index=None):
        fetched_models = []
        fetched_model_paths = []
        for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
            if model_name == model_name_:
                fetched_models.append(model)
                fetched_model_paths.append(model_path)
        if len(fetched_models) == 0:
            print(f"No {model_name} models available. This is not an error.")
            model = None
        elif len(fetched_models) == 1:
            print(f"Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.")
            model = fetched_models[0]
        else:
            if index is None:
                model = fetched_models[0]
                print(f"More than one {model_name} model is loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.")
            elif isinstance(index, int):
                model = fetched_models[:index]
                print(f"More than one {model_name} model is loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[:index], indent=4)}.")
            else:
                model = fetched_models
                print(f"More than one {model_name} model is loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths, indent=4)}.")
        return model

    def clear_parameters(self, model: torch.nn.Module):
        for _, module in model.named_children():
            self.clear_parameters(module)
        for name, _ in model.named_parameters(recurse=False):
            setattr(model, name, None)

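# Usage sketch (illustrative; the checkpoint path and model name below are
# placeholders): models are matched by file hash against MODEL_CONFIGS and
# then retrieved by name.
#
#   pool = ModelPool()
#   pool.auto_load_model("models/some_checkpoint.safetensors")
#   model = pool.fetch_model("some_model_name")
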
@@ -1,536 +0,0 @@
import os, torch, hashlib, json, importlib
from safetensors import safe_open
from torch import Tensor
from typing_extensions import Literal, TypeAlias
from typing import List

from .downloader import download_models, Preset_model_id, Preset_model_website

from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder
from .lora import SDLoRAFromCivitai, SDXLLoRAFromCivitai, GeneralLoRAFromPeft

from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet
from .sdxl_vae_decoder import SDXLVAEDecoder
from .sdxl_vae_encoder import SDXLVAEEncoder

from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from .sd3_dit import SD3DiT
from .sd3_vae_decoder import SD3VAEDecoder
from .sd3_vae_encoder import SD3VAEEncoder

from .sd_controlnet import SDControlNet

from .sd_motion import SDMotionModel
from .sdxl_motion import SDXLMotionModel

from .svd_image_encoder import SVDImageEncoder
from .svd_unet import SVDUNet
from .svd_vae_decoder import SVDVAEDecoder
from .svd_vae_encoder import SVDVAEEncoder

from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder

from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT

from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs


def load_state_dict(file_path, torch_dtype=None):
    if file_path.endswith(".safetensors"):
        return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
    else:
        return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)


def load_state_dict_from_safetensors(file_path, torch_dtype=None):
    state_dict = {}
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)
            if torch_dtype is not None:
                state_dict[k] = state_dict[k].to(torch_dtype)
    return state_dict


def load_state_dict_from_bin(file_path, torch_dtype=None):
    state_dict = torch.load(file_path, map_location="cpu")
    if torch_dtype is not None:
        for i in state_dict:
            if isinstance(state_dict[i], torch.Tensor):
                state_dict[i] = state_dict[i].to(torch_dtype)
    return state_dict


def search_for_embeddings(state_dict):
    embeddings = []
    for k in state_dict:
        if isinstance(state_dict[k], torch.Tensor):
            embeddings.append(state_dict[k])
        elif isinstance(state_dict[k], dict):
            embeddings += search_for_embeddings(state_dict[k])
    return embeddings


def search_parameter(param, state_dict):
    for name, param_ in state_dict.items():
        if param.numel() == param_.numel():
            if param.shape == param_.shape:
                if torch.dist(param, param_) < 1e-6:
                    return name
            else:
                if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
                    return name
    return None


def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
    matched_keys = set()
    with torch.no_grad():
        for name in source_state_dict:
            rename = search_parameter(source_state_dict[name], target_state_dict)
            if rename is not None:
                print(f'"{name}": "{rename}",')
                matched_keys.add(rename)
            elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
                length = source_state_dict[name].shape[0] // 3
                rename = []
                for i in range(3):
                    rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
                if None not in rename:
                    print(f'"{name}": {rename},')
                    for rename_ in rename:
                        matched_keys.add(rename_)
    for name in target_state_dict:
        if name not in matched_keys:
            print("Cannot find", name, target_state_dict[name].shape)


def search_for_files(folder, extensions):
    files = []
    if os.path.isdir(folder):
        for file in sorted(os.listdir(folder)):
            files += search_for_files(os.path.join(folder, file), extensions)
    elif os.path.isfile(folder):
        for extension in extensions:
            if folder.endswith(extension):
                files.append(folder)
                break
    return files


def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    keys = []
    for key, value in state_dict.items():
        if isinstance(key, str):
            if isinstance(value, Tensor):
                if with_shape:
                    shape = "_".join(map(str, list(value.shape)))
                    keys.append(key + ":" + shape)
                keys.append(key)
            elif isinstance(value, dict):
                keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
    keys.sort()
    keys_str = ",".join(keys)
    return keys_str


def split_state_dict_with_prefix(state_dict):
    keys = sorted([key for key in state_dict if isinstance(key, str)])
    prefix_dict = {}
    for key in keys:
        prefix = key if "." not in key else key.split(".")[0]
        if prefix not in prefix_dict:
            prefix_dict[prefix] = []
        prefix_dict[prefix].append(key)
    state_dicts = []
    for prefix, keys in prefix_dict.items():
        sub_state_dict = {key: state_dict[key] for key in keys}
        state_dicts.append(sub_state_dict)
    return state_dicts


def hash_state_dict_keys(state_dict, with_shape=True):
    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    keys_str = keys_str.encode(encoding="UTF-8")
    return hashlib.md5(keys_str).hexdigest()


def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        print(f"    model_name: {model_name} model_class: {model_class.__name__}")
        state_dict_converter = model_class.state_dict_converter()
        if model_resource == "civitai":
            state_dict_results = state_dict_converter.from_civitai(state_dict)
        elif model_resource == "diffusers":
            state_dict_results = state_dict_converter.from_diffusers(state_dict)
        if isinstance(state_dict_results, tuple):
            model_state_dict, extra_kwargs = state_dict_results
            print(f"        This model is initialized with extra kwargs: {extra_kwargs}")
        else:
            model_state_dict, extra_kwargs = state_dict_results, {}
        torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
        model = model_class(**extra_kwargs).to(dtype=torch_dtype, device=device)
        model.load_state_dict(model_state_dict)
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models


def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
        if torch_dtype == torch.float16 and hasattr(model, "half"):
            model = model.half()
        model = model.to(device=device)
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models


def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
    print(f"    model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
    base_state_dict = base_model.state_dict()
    base_model.to("cpu")
    del base_model
    model = model_class(**extra_kwargs)
    model.load_state_dict(base_state_dict, strict=False)
    model.load_state_dict(state_dict, strict=False)
    model.to(dtype=torch_dtype, device=device)
    return model


def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        while True:
            for model_id in range(len(model_manager.model)):
                base_model_name = model_manager.model_name[model_id]
                if base_model_name == model_name:
                    base_model_path = model_manager.model_path[model_id]
                    base_model = model_manager.model[model_id]
                    print(f"    Adding patch model to {base_model_name} ({base_model_path})")
                    patched_model = load_single_patch_model_from_single_file(
                        state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
                    loaded_model_names.append(base_model_name)
                    loaded_models.append(patched_model)
                    model_manager.model.pop(model_id)
                    model_manager.model_path.pop(model_id)
                    model_manager.model_name.pop(model_id)
                    break
            else:
                break
    return loaded_model_names, loaded_models


class ModelDetectorTemplate:
    def __init__(self):
        pass

    def match(self, file_path="", state_dict={}):
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        return [], []


class ModelDetectorFromSingleFile:
    def __init__(self, model_loader_configs=[]):
        self.keys_hash_with_shape_dict = {}
        self.keys_hash_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
        if keys_hash is not None:
            self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)

    def match(self, file_path="", state_dict={}):
        if os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return True
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        if keys_hash in self.keys_hash_dict:
            return True
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)

        # Load models with strict matching
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
            loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
            return loaded_model_names, loaded_models

        # Load models without strict matching
        # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        if keys_hash in self.keys_hash_dict:
            model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
            loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
            return loaded_model_names, loaded_models

        return loaded_model_names, loaded_models


class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
    def __init__(self, model_loader_configs=[]):
        super().__init__(model_loader_configs)

    def match(self, file_path="", state_dict={}):
        if os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        splited_state_dict = split_state_dict_with_prefix(state_dict)
        for sub_state_dict in splited_state_dict:
            if super().match(file_path, sub_state_dict):
                return True
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        # Split the state_dict and load from each component
        splited_state_dict = split_state_dict_with_prefix(state_dict)
        valid_state_dict = {}
        for sub_state_dict in splited_state_dict:
            if super().match(file_path, sub_state_dict):
                valid_state_dict.update(sub_state_dict)
        if super().match(file_path, valid_state_dict):
            loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
        else:
            loaded_model_names, loaded_models = [], []
            for sub_state_dict in splited_state_dict:
                if super().match(file_path, sub_state_dict):
                    loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
                    loaded_model_names += loaded_model_names_
                    loaded_models += loaded_models_
        return loaded_model_names, loaded_models


class ModelDetectorFromHuggingfaceFolder:
    def __init__(self, model_loader_configs=[]):
        self.architecture_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, architecture, huggingface_lib, model_name):
        self.architecture_dict[architecture] = (huggingface_lib, model_name)

    def match(self, file_path="", state_dict={}):
        if os.path.isfile(file_path):
            return False
        file_list = os.listdir(file_path)
        if "config.json" not in file_list:
            return False
        with open(os.path.join(file_path, "config.json"), "r") as f:
            config = json.load(f)
        if "architectures" not in config:
            return False
        return True

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        with open(os.path.join(file_path, "config.json"), "r") as f:
            config = json.load(f)
        loaded_model_names, loaded_models = [], []
        for architecture in config["architectures"]:
            huggingface_lib, model_name = self.architecture_dict[architecture]
            model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
            loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models


class ModelDetectorFromPatchedSingleFile:
    def __init__(self, model_loader_configs=[]):
        self.keys_hash_with_shape_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)

    def match(self, file_path="", state_dict={}):
        if os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return True
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)

        # Load models with strict matching
        loaded_model_names, loaded_models = [], []
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
            loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
                state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models


class ModelManager:
    def __init__(
        self,
        torch_dtype=torch.float16,
        device="cuda",
        model_id_list: List[Preset_model_id] = [],
        downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
        file_path_list: List[str] = [],
    ):
        self.torch_dtype = torch_dtype
        self.device = device
        self.model = []
        self.model_path = []
        self.model_name = []
        downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
        self.model_detector = [
            ModelDetectorFromSingleFile(model_loader_configs),
            ModelDetectorFromSplitedSingleFile(model_loader_configs),
            ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
            ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
        ]
        self.load_models(downloaded_files + file_path_list)

    def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
        print(f"Loading models from file: {file_path}")
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)
        print(f"    The following models are loaded: {model_names}.")

    def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
        print(f"Loading models from folder: {file_path}")
        model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)
        print(f"    The following models are loaded: {model_names}.")

    def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
        print(f"Loading patch models from file: {file_path}")
        model_names, models = load_patch_model_from_single_file(
            state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)
        print(f"    The following patched models are loaded: {model_names}.")

    def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
        print(f"Loading LoRA models from file: {file_path}")
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
            for lora in [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft()]:
                match_results = lora.match(model, state_dict)
                if match_results is not None:
                    print(f"    Adding LoRA to {model_name} ({model_path}).")
                    lora_prefix, model_resource = match_results
                    lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
                    break

    def load_model(self, file_path, model_names=None):
        print(f"Loading models from: {file_path}")
        if os.path.isfile(file_path):
            state_dict = load_state_dict(file_path)
        else:
            state_dict = None
        for model_detector in self.model_detector:
            if model_detector.match(file_path, state_dict):
                model_names, models = model_detector.load(
                    file_path, state_dict,
                    device=self.device, torch_dtype=self.torch_dtype,
                    allowed_model_names=model_names, model_manager=self
                )
                for model_name, model in zip(model_names, models):
                    self.model.append(model)
                    self.model_path.append(file_path)
                    self.model_name.append(model_name)
                print(f"    The following models are loaded: {model_names}.")
                break
        else:
            print(f"    We cannot detect the model type. No models are loaded.")

    def load_models(self, file_path_list, model_names=None):
        for file_path in file_path_list:
            self.load_model(file_path, model_names)

    def fetch_model(self, model_name, file_path=None, require_model_path=False):
        fetched_models = []
        fetched_model_paths = []
        for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
            if file_path is not None and file_path != model_path:
                continue
            if model_name == model_name_:
                fetched_models.append(model)
                fetched_model_paths.append(model_path)
        if len(fetched_models) == 0:
            print(f"No {model_name} models available.")
            return None
        if len(fetched_models) == 1:
            print(f"Using {model_name} from {fetched_model_paths[0]}.")
        else:
            print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
        if require_model_path:
            return fetched_models[0], fetched_model_paths[0]
        else:
            return fetched_models[0]

    def to(self, device):
        for model in self.model:
            model.to(device)
161 diffsynth/models/nexus_gen.py Normal file
@@ -0,0 +1,161 @@
import torch
from PIL import Image


class NexusGenAutoregressiveModel(torch.nn.Module):
    def __init__(self, max_length=1024, max_pixels=262640):
        super(NexusGenAutoregressiveModel, self).__init__()
        from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration
        from transformers import Qwen2_5_VLConfig
        self.max_length = max_length
        self.max_pixels = max_pixels
        model_config = Qwen2_5_VLConfig(**{
            "_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
            "architectures": [
                "Qwen2_5_VLForConditionalGeneration"
            ],
            "attention_dropout": 0.0,
            "auto_map": {
                "AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
                "AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
                "AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration"
            },
            "bos_token_id": 151643,
            "eos_token_id": 151645,
            "hidden_act": "silu",
            "hidden_size": 3584,
            "image_token_id": 151655,
            "initializer_range": 0.02,
            "intermediate_size": 18944,
            "max_position_embeddings": 128000,
            "max_window_layers": 28,
            "model_type": "qwen2_5_vl",
            "num_attention_heads": 28,
            "num_hidden_layers": 28,
            "num_key_value_heads": 4,
            "pad_token_id": 151643,
            "rms_norm_eps": 1e-06,
            "rope_scaling": {
                "mrope_section": [
                    16,
                    24,
                    24
                ],
                "rope_type": "default",
                "type": "default"
            },
            "rope_theta": 1000000.0,
            "sliding_window": 32768,
            "tie_word_embeddings": False,
            "torch_dtype": "bfloat16",
            "transformers_version": "4.49.0",
            "use_cache": False,
            "use_sliding_window": False,
            "video_token_id": 151656,
            "vision_config": {
                "hidden_size": 1280,
                "in_chans": 3,
                "model_type": "qwen2_5_vl",
                "spatial_patch_size": 14,
                "tokens_per_second": 2,
                "torch_dtype": "bfloat16"
            },
            "vision_end_token_id": 151653,
            "vision_start_token_id": 151652,
            "vision_token_id": 151654,
            "vocab_size": 152064
        })
        self.model = Qwen2_5_VLForConditionalGeneration(model_config)
        self.processor = None

    def load_processor(self, path):
        from .nexus_gen_ar_model import Qwen2_5_VLProcessor
        self.processor = Qwen2_5_VLProcessor.from_pretrained(path)

    @staticmethod
    def state_dict_converter():
        return NexusGenAutoregressiveModelStateDictConverter()

    def bound_image(self, image, max_pixels=262640):
        from qwen_vl_utils import smart_resize
        resized_height, resized_width = smart_resize(
            image.height,
            image.width,
            max_pixels=max_pixels,
        )
        return image.resize((resized_width, resized_height))

    def get_editing_msg(self, instruction):
        if '<image>' not in instruction:
            instruction = '<image> ' + instruction
        messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is the image: <image>"}]
        return messages

    def get_generation_msg(self, instruction):
        instruction = "Generate an image according to the following description: {}".format(instruction)
        messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is an image based on the description: <image>"}]
        return messages

    def forward(self, instruction, ref_image=None, num_img_tokens=81):
        """
        Generate target embeddings for the given instruction and reference image.
        """
        if ref_image is not None:
            messages = self.get_editing_msg(instruction)
            images = [self.bound_image(ref_image)] + [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
            output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
        else:
            messages = self.get_generation_msg(instruction)
            images = [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
            output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)

        return output_image_embeddings

    def get_target_embeddings(self, images, messages, processor, model, num_img_tokens=81):
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        text = text.replace('<image>', '<|vision_start|><|image_pad|><|vision_end|>')
        inputs = processor(
            text=[text],
            images=images,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(model.device)

        input_embeds = model.model.embed_tokens(inputs['input_ids'])
        image_embeds = model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
        ground_truth_image_embeds = image_embeds[-num_img_tokens:]
        input_image_embeds = image_embeds[:-num_img_tokens]

        image_mask = inputs['input_ids'] == model.config.image_token_id
|
||||||
|
indices = image_mask.cumsum(dim=1)
|
||||||
|
input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
|
||||||
|
gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
|
||||||
|
input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
|
||||||
|
input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
|
||||||
|
|
||||||
|
image_prefill_embeds = model.image_prefill_embeds(
|
||||||
|
torch.arange(81, device=model.device).long()
|
||||||
|
)
|
||||||
|
input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds)
|
||||||
|
|
||||||
|
position_ids, _ = model.get_rope_index(
|
||||||
|
inputs['input_ids'],
|
||||||
|
inputs['image_grid_thw'],
|
||||||
|
attention_mask=inputs['attention_mask'])
|
||||||
|
position_ids = position_ids.contiguous()
|
||||||
|
outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
|
||||||
|
output_image_embeddings = outputs.image_embeddings[:, :-1, :]
|
||||||
|
output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
|
||||||
|
return output_image_embeddings, input_image_embeds, inputs['image_grid_thw']
|
||||||
|
|
||||||
|
|
||||||
|
class NexusGenAutoregressiveModelStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
state_dict = {"model." + key: value for key, value in state_dict.items()}
|
||||||
|
return state_dict
|
||||||
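A minimal usage sketch for NexusGenAutoregressiveModel (editor's addition; the processor path and image file are placeholders, and weights are assumed to be loaded separately via the converter above):

    from PIL import Image
    import torch

    model = NexusGenAutoregressiveModel().eval()
    model.load_processor("path/to/Nexus-GenV2")  # placeholder path
    with torch.no_grad():
        # text-to-image: prompt-only conditioning via get_generation_msg()
        target_emb, _, grid = model("a cat sitting on a windowsill")
        # editing: get_editing_msg() prepends an <image> slot for the reference
        ref = Image.open("input.png").convert("RGB")  # placeholder file
        target_emb, ref_emb, grid = model("make the sky pink", ref_image=ref)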
diffsynth/models/nexus_gen_ar_model.py (new file, 1143 lines)
File diff suppressed because it is too large
diffsynth/models/nexus_gen_projector.py (new file, 417 lines)
@@ -0,0 +1,417 @@
import math
import torch
import torch.nn as nn
from typing import Optional, Tuple


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    mrope_section = mrope_section * 2
    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )
    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class Qwen2_5_VLRotaryEmbedding(nn.Module):
    def __init__(self, config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_kwargs = {}  # kept empty here; referenced by _dynamic_frequency_update below
        from transformers.modeling_rope_utils import _compute_default_rope_parameters
        self.rope_init_fn = _compute_default_rope_parameters

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids
        # So we expand the inv_freq to shape (3, ...)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
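# Editor's note (shape sketch): repeat_kv() implements grouped-query attention.
# With the Nexus-Gen config used in this file (28 query heads, 4 KV heads), each
# KV head is shared by 28 // 4 = 7 query heads:
#     kv = torch.randn(1, 4, 10, 128)   # (batch, num_key_value_heads, seq, head_dim)
#     assert repeat_kv(kv, 7).shape == (1, 28, 10, 128)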
class Qwen2_5_VLAttention(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.is_causal = True
        self.attention_dropout = config.attention_dropout
        self.rope_scaling = config.rope_scaling

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        # Fix precision issues in Qwen2-VL float16 inference
        # Replace inf values with zeros in attention weights to prevent NaN propagation
        if query_states.dtype == torch.float16:
            attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output


class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        from transformers.activations import ACT2FN
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class Qwen2_5_VLDecoderLayer(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Qwen2_5_VLAttention(config, layer_idx)

        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class NexusGenImageEmbeddingMerger(nn.Module):
    def __init__(self, num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'):
        super().__init__()
        from transformers import Qwen2_5_VLConfig
        from transformers.activations import ACT2FN
        config = Qwen2_5_VLConfig(**{
            "_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
            "architectures": ["Qwen2_5_VLForConditionalGeneration"],
            "attention_dropout": 0.0,
            "auto_map": {
                "AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
                "AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
                "AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration"
            },
            "bos_token_id": 151643,
            "eos_token_id": 151645,
            "hidden_act": "silu",
            "hidden_size": 3584,
            "image_token_id": 151655,
            "initializer_range": 0.02,
            "intermediate_size": 18944,
            "max_position_embeddings": 128000,
            "max_window_layers": 28,
            "model_type": "qwen2_5_vl",
            "num_attention_heads": 28,
            "num_hidden_layers": 28,
            "num_key_value_heads": 4,
            "pad_token_id": 151643,
            "rms_norm_eps": 1e-06,
            "rope_scaling": {
                "mrope_section": [16, 24, 24],
                "rope_type": "default",
                "type": "default"
            },
            "rope_theta": 1000000.0,
            "sliding_window": 32768,
            "tie_word_embeddings": False,
            "torch_dtype": "bfloat16",
            "transformers_version": "4.49.0",
            "use_cache": False,
            "use_sliding_window": False,
            "video_token_id": 151656,
            "vision_config": {
                "hidden_size": 1280,
                "in_chans": 3,
                "model_type": "qwen2_5_vl",
                "spatial_patch_size": 14,
                "tokens_per_second": 2,
                "torch_dtype": "bfloat16"
            },
            "vision_end_token_id": 151653,
            "vision_start_token_id": 151652,
            "vision_token_id": 151654,
            "vocab_size": 152064
        })
        self.config = config
        self.num_layers = num_layers
        self.layers = nn.ModuleList([Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(num_layers)])
        self.projector = nn.Sequential(
            Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps),
            nn.Linear(config.hidden_size, out_channel * expand_ratio),
            Qwen2RMSNorm(out_channel * expand_ratio, eps=config.rms_norm_eps),
            ACT2FN[config.hidden_act],
            nn.Linear(out_channel * expand_ratio, out_channel),
            Qwen2RMSNorm(out_channel, eps=config.rms_norm_eps),
        )
        self.base_grid = torch.tensor([[1, 72, 72]], device=device)
        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config, device=device)

    def get_position_ids(self, image_grid_thw):
        """
        Generates position ids for the input embeddings grid.
        Modified from the qwen2_vl mrope.
        """
        batch_size = image_grid_thw.shape[0]
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        t, h, w = (
            image_grid_thw[0][0],
            image_grid_thw[0][1],
            image_grid_thw[0][2],
        )
        llm_grid_t, llm_grid_h, llm_grid_w = (
            t.item(),
            h.item() // spatial_merge_size,
            w.item() // spatial_merge_size,
        )
        scale_h = self.base_grid[0][1].item() / h.item()
        scale_w = self.base_grid[0][2].item() / w.item()

        range_tensor = torch.arange(llm_grid_t).view(-1, 1)
        expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
        time_tensor = expanded_range * self.config.vision_config.tokens_per_second
        t_index = time_tensor.long().flatten().to(image_grid_thw.device)
        h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten().to(image_grid_thw.device) * scale_h
        w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten().to(image_grid_thw.device) * scale_w
        # 3, B, L
        position_ids = torch.stack([t_index, h_index, w_index]).unsqueeze(0).repeat(batch_size, 1, 1).permute(1, 0, 2)
        return position_ids

    def forward(self, embeds, embeds_grid, ref_embeds=None, ref_embeds_grid=None):
        position_ids = self.get_position_ids(embeds_grid)
        hidden_states = embeds
        if ref_embeds is not None:
            position_ids_ref_embeds = self.get_position_ids(ref_embeds_grid)
            position_ids = torch.cat((position_ids, position_ids_ref_embeds), dim=-1)
            hidden_states = torch.cat((embeds, ref_embeds), dim=1)

        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_embeddings)

        hidden_states = self.projector(hidden_states)
        return hidden_states

    @staticmethod
    def state_dict_converter():
        return NexusGenMergerStateDictConverter()


class NexusGenMergerStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict

    def from_civitai(self, state_dict):
        merger_state_dict = {key.replace("embedding_merger.", ""): value for key, value in state_dict.items() if key.startswith('embedding_merger.')}
        return merger_state_dict


class NexusGenAdapter(nn.Module):
    """
    Adapter for Nexus-Gen generation decoder.
    """
    def __init__(self, input_dim=3584, output_dim=4096):
        super(NexusGenAdapter, self).__init__()
        self.adapter = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.LayerNorm(output_dim), nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.LayerNorm(output_dim),
        )

    def forward(self, x):
        return self.adapter(x)

    @staticmethod
    def state_dict_converter():
        return NexusGenAdapterStateDictConverter()


class NexusGenAdapterStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict

    def from_civitai(self, state_dict):
        adapter_state_dict = {key: value for key, value in state_dict.items() if key.startswith('adapter.')}
        return adapter_state_dict
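A shape-level sketch for the merger and adapter above (editor's addition; batch size and grid are illustrative — 81 tokens correspond to an 18x18 patch grid after the 2x2 spatial merge):

    import torch

    merger = NexusGenImageEmbeddingMerger()      # 1 decoder layer + projector to 4096
    adapter = NexusGenAdapter()                  # plain MLP, 3584 -> 4096

    embeds = torch.randn(1, 81, 3584)            # predicted target-image embeddings
    grid = torch.tensor([[1, 18, 18]])           # (t, h, w) grid before the spatial merge
    merged = merger(embeds, grid)                # -> (1, 81, 4096)
    adapted = adapter(embeds)                    # -> (1, 81, 4096)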
diffsynth/models/qwen_image_controlnet.py (new file, 56 lines)
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn
from .general_modules import RMSNorm


class BlockWiseControlBlock(torch.nn.Module):
    # [linear, gelu, linear]
    def __init__(self, dim: int = 3072):
        super().__init__()
        self.x_rms = RMSNorm(dim, eps=1e-6)
        self.y_rms = RMSNorm(dim, eps=1e-6)
        self.input_proj = nn.Linear(dim, dim)
        self.act = nn.GELU()
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, x, y):
        x, y = self.x_rms(x), self.y_rms(y)
        x = self.input_proj(x + y)
        x = self.act(x)
        x = self.output_proj(x)
        return x

    def init_weights(self):
        # zero initialize output_proj
        nn.init.zeros_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)


class QwenImageBlockWiseControlNet(torch.nn.Module):
    def __init__(
        self,
        num_layers: int = 60,
        in_dim: int = 64,
        additional_in_dim: int = 0,
        dim: int = 3072,
    ):
        super().__init__()
        self.img_in = nn.Linear(in_dim + additional_in_dim, dim)
        self.controlnet_blocks = nn.ModuleList(
            [BlockWiseControlBlock(dim) for _ in range(num_layers)]
        )

    def init_weight(self):
        nn.init.zeros_(self.img_in.weight)
        nn.init.zeros_(self.img_in.bias)
        for block in self.controlnet_blocks:
            block.init_weights()

    def process_controlnet_conditioning(self, controlnet_conditioning):
        return self.img_in(controlnet_conditioning)

    def blockwise_forward(self, img, controlnet_conditioning, block_id):
        return self.controlnet_blocks[block_id](img, controlnet_conditioning)
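A sketch of how the block-wise ControlNet above is meant to be threaded through a DiT (editor's addition; the surrounding DiT loop is schematic — only the ControlNet calls are real, and the zero init means the control signal starts as a no-op):

    import torch

    controlnet = QwenImageBlockWiseControlNet(num_layers=60)
    controlnet.init_weight()   # zero init: control contributes nothing until trained

    cond = torch.randn(1, 4096, 64)                              # packed control latents
    cond = controlnet.process_controlnet_conditioning(cond)      # -> (1, 4096, 3072)

    img = torch.randn(1, 4096, 3072)                             # image stream inside the DiT
    for block_id in range(60):
        # ... run DiT transformer block `block_id` on img here ...
        img = img + controlnet.blockwise_forward(img, cond, block_id)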
diffsynth/models/qwen_image_dit.py (new file, 685 lines)
@@ -0,0 +1,685 @@
import torch, math, functools
import torch.nn as nn
from typing import Tuple, Optional, Union, List
from einops import rearrange
from .general_modules import TimestepEmbeddings, RMSNorm, AdaLayerNorm

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False


def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask=None, enable_fp8_attention: bool = False):
    if FLASH_ATTN_3_AVAILABLE and attention_mask is None:
        if not enable_fp8_attention:
            q = rearrange(q, "b n s d -> b s n d", n=num_heads)
            k = rearrange(k, "b n s d -> b s n d", n=num_heads)
            v = rearrange(v, "b n s d -> b s n d", n=num_heads)
            x = flash_attn_interface.flash_attn_func(q, k, v)
            if isinstance(x, tuple):
                x = x[0]
            x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
        else:
            origin_dtype = q.dtype
            q_std, k_std, v_std = q.std(), k.std(), v.std()
            q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)
            q = rearrange(q, "b n s d -> b s n d", n=num_heads)
            k = rearrange(k, "b n s d -> b s n d", n=num_heads)
            v = rearrange(v, "b n s d -> b s n d", n=num_heads)
            x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))
            if isinstance(x, tuple):
                x = x[0]
            x = x.to(origin_dtype) * v_std
            x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    else:
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    return x
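# Editor's note on the FP8 branch above: q and k are divided by their standard
# deviations before the float8_e4m3fn cast so values fit FP8's narrow dynamic
# range; the division is undone by passing softmax_scale = q_std * k_std / sqrt(d)
# instead of the usual 1 / sqrt(d), and the v_std factor is multiplied back onto
# the output. Roughly:
#     softmax((q/q_std)(k/k_std)^T * q_std*k_std/sqrt(d)) @ (v/v_std) * v_std
#     == softmax(q k^T / sqrt(d)) @ v     (up to FP8 rounding error)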
class ApproximateGELU(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)


def apply_rotary_emb_qwen(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
):
    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
    return x_out.type_as(x)


class QwenEmbedRope(nn.Module):
    def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        pos_index = torch.arange(4096)
        neg_index = torch.arange(4096).flip(0) * -1 - 1
        self.pos_freqs = torch.cat([
            self.rope_params(pos_index, self.axes_dim[0], self.theta),
            self.rope_params(pos_index, self.axes_dim[1], self.theta),
            self.rope_params(pos_index, self.axes_dim[2], self.theta),
        ], dim=1)
        self.neg_freqs = torch.cat([
            self.rope_params(neg_index, self.axes_dim[0], self.theta),
            self.rope_params(neg_index, self.axes_dim[1], self.theta),
            self.rope_params(neg_index, self.axes_dim[2], self.theta),
        ], dim=1)
        self.rope_cache = {}
        self.scale_rope = scale_rope

    def rope_params(self, index, dim, theta=10000):
        """
        Args:
            index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
        """
        assert dim % 2 == 0
        freqs = torch.outer(
            index,
            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
        )
        freqs = torch.polar(torch.ones_like(freqs), freqs)
        return freqs

    def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
        if isinstance(video_fhw, list):
            video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3))
        _, height, width = video_fhw
        if self.scale_rope:
            max_vid_index = max(height // 2, width // 2)
        else:
            max_vid_index = max(height, width)
        required_len = max_vid_index + max(txt_seq_lens)
        cur_max_len = self.pos_freqs.shape[0]
        if required_len <= cur_max_len:
            return

        new_max_len = math.ceil(required_len / 512) * 512
        pos_index = torch.arange(new_max_len)
        neg_index = torch.arange(new_max_len).flip(0) * -1 - 1
        self.pos_freqs = torch.cat([
            self.rope_params(pos_index, self.axes_dim[0], self.theta),
            self.rope_params(pos_index, self.axes_dim[1], self.theta),
            self.rope_params(pos_index, self.axes_dim[2], self.theta),
        ], dim=1)
        self.neg_freqs = torch.cat([
            self.rope_params(neg_index, self.axes_dim[0], self.theta),
            self.rope_params(neg_index, self.axes_dim[1], self.theta),
            self.rope_params(neg_index, self.axes_dim[2], self.theta),
        ], dim=1)
        return

    def forward(self, video_fhw, txt_seq_lens, device):
        self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
        if self.pos_freqs.device != device:
            self.pos_freqs = self.pos_freqs.to(device)
            self.neg_freqs = self.neg_freqs.to(device)

        vid_freqs = []
        max_vid_index = 0
        for idx, fhw in enumerate(video_fhw):
            frame, height, width = fhw
            rope_key = f"{idx}_{height}_{width}"

            if rope_key not in self.rope_cache:
                seq_lens = frame * height * width
                freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
                freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
                freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
                if self.scale_rope:
                    freqs_height = torch.cat(
                        [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
                    )
                    freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
                    freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
                    freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
                else:
                    freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
                    freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

                freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
                self.rope_cache[rope_key] = freqs.clone().contiguous()
            vid_freqs.append(self.rope_cache[rope_key])

            if self.scale_rope:
                max_vid_index = max(height // 2, width // 2, max_vid_index)
            else:
                max_vid_index = max(height, width, max_vid_index)

        max_len = max(txt_seq_lens)
        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
        vid_freqs = torch.cat(vid_freqs, dim=0)

        return vid_freqs, txt_freqs

    def forward_sampling(self, video_fhw, txt_seq_lens, device):
        self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
        if self.pos_freqs.device != device:
            self.pos_freqs = self.pos_freqs.to(device)
            self.neg_freqs = self.neg_freqs.to(device)

        vid_freqs = []
        max_vid_index = 0
        for idx, fhw in enumerate(video_fhw):
            frame, height, width = fhw
            rope_key = f"{idx}_{height}_{width}"
            if idx > 0 and f"{0}_{height}_{width}" not in self.rope_cache:
                frame_0, height_0, width_0 = video_fhw[0]

                rope_key_0 = f"0_{height_0}_{width_0}"
                spatial_freqs_0 = self.rope_cache[rope_key_0].reshape(frame_0, height_0, width_0, -1)
                h_indices = torch.linspace(0, height_0 - 1, height).long()
                w_indices = torch.linspace(0, width_0 - 1, width).long()
                h_grid, w_grid = torch.meshgrid(h_indices, w_indices, indexing='ij')
                sampled_rope = spatial_freqs_0[:, h_grid, w_grid, :]

                freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
                freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
                sampled_rope[:, :, :, :freqs_frame.shape[-1]] = freqs_frame

                seq_lens = frame * height * width
                self.rope_cache[rope_key] = sampled_rope.reshape(seq_lens, -1).clone()
            if rope_key not in self.rope_cache:
                seq_lens = frame * height * width
                freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
                freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
                freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
                if self.scale_rope:
                    freqs_height = torch.cat(
                        [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
                    )
                    freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
                    freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
                    freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
                else:
                    freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
                    freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

                freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
                self.rope_cache[rope_key] = freqs.clone()
            vid_freqs.append(self.rope_cache[rope_key].contiguous())

            if self.scale_rope:
                max_vid_index = max(height // 2, width // 2, max_vid_index)
            else:
                max_vid_index = max(height, width, max_vid_index)

        max_len = max(txt_seq_lens)
        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
        vid_freqs = torch.cat(vid_freqs, dim=0)

        return vid_freqs, txt_freqs
class QwenEmbedLayer3DRope(nn.Module):
|
||||||
|
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
|
||||||
|
super().__init__()
|
||||||
|
self.theta = theta
|
||||||
|
self.axes_dim = axes_dim
|
||||||
|
pos_index = torch.arange(4096)
|
||||||
|
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
||||||
|
self.pos_freqs = torch.cat(
|
||||||
|
[
|
||||||
|
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||||
|
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||||
|
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
self.neg_freqs = torch.cat(
|
||||||
|
[
|
||||||
|
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||||
|
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||||
|
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.scale_rope = scale_rope
|
||||||
|
|
||||||
|
def rope_params(self, index, dim, theta=10000):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
|
||||||
|
"""
|
||||||
|
assert dim % 2 == 0
|
||||||
|
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
|
||||||
|
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||||
|
return freqs
|
||||||
|
|
||||||
|
def forward(self, video_fhw, txt_seq_lens, device):
|
||||||
|
"""
|
||||||
|
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
|
||||||
|
txt_length: [bs] a list of 1 integers representing the length of the text
|
||||||
|
"""
|
||||||
|
if self.pos_freqs.device != device:
|
||||||
|
self.pos_freqs = self.pos_freqs.to(device)
|
||||||
|
self.neg_freqs = self.neg_freqs.to(device)
|
||||||
|
|
||||||
|
video_fhw = [video_fhw]
|
||||||
|
if isinstance(video_fhw, list):
|
||||||
|
video_fhw = video_fhw[0]
|
||||||
|
if not isinstance(video_fhw, list):
|
||||||
|
video_fhw = [video_fhw]
|
||||||
|
|
||||||
|
vid_freqs = []
|
||||||
|
max_vid_index = 0
|
||||||
|
layer_num = len(video_fhw) - 1
|
||||||
|
for idx, fhw in enumerate(video_fhw):
|
||||||
|
frame, height, width = fhw
|
||||||
|
if idx != layer_num:
|
||||||
|
video_freq = self._compute_video_freqs(frame, height, width, idx)
|
||||||
|
else:
|
||||||
|
### For the condition image, we set the layer index to -1
|
||||||
|
video_freq = self._compute_condition_freqs(frame, height, width)
|
||||||
|
video_freq = video_freq.to(device)
|
||||||
|
vid_freqs.append(video_freq)
|
||||||
|
|
||||||
|
if self.scale_rope:
|
||||||
|
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
||||||
|
else:
|
||||||
|
max_vid_index = max(height, width, max_vid_index)
|
||||||
|
|
||||||
|
max_vid_index = max(max_vid_index, layer_num)
|
||||||
|
max_len = max(txt_seq_lens)
|
||||||
|
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||||
|
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||||
|
|
||||||
|
return vid_freqs, txt_freqs
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=None)
|
||||||
|
def _compute_video_freqs(self, frame, height, width, idx=0):
|
||||||
|
seq_lens = frame * height * width
|
||||||
|
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||||
|
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||||
|
|
||||||
|
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||||
|
if self.scale_rope:
|
||||||
|
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
||||||
|
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||||
|
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||||
|
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||||
|
else:
|
||||||
|
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||||
|
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||||
|
|
||||||
|
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||||
|
return freqs.clone().contiguous()
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=None)
|
||||||
|
def _compute_condition_freqs(self, frame, height, width):
|
||||||
|
seq_lens = frame * height * width
|
||||||
|
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||||
|
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||||
|
|
||||||
|
freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||||
|
if self.scale_rope:
|
||||||
|
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
||||||
|
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||||
|
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||||
|
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||||
|
else:
|
||||||
|
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||||
|
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||||
|
|
||||||
|
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||||
|
return freqs.clone().contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
class QwenFeedForward(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
dim_out: Optional[int] = None,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = int(dim * 4)
|
||||||
|
self.net = nn.ModuleList([])
|
||||||
|
self.net.append(ApproximateGELU(dim, inner_dim))
|
||||||
|
self.net.append(nn.Dropout(dropout))
|
||||||
|
self.net.append(nn.Linear(inner_dim, dim_out))
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||||
|
for module in self.net:
|
||||||
|
hidden_states = module(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
class QwenDoubleStreamAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim_a,
|
||||||
|
dim_b,
|
||||||
|
num_heads,
|
||||||
|
head_dim,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(dim_a, dim_a)
|
||||||
|
self.to_k = nn.Linear(dim_a, dim_a)
|
||||||
|
self.to_v = nn.Linear(dim_a, dim_a)
|
||||||
|
self.norm_q = RMSNorm(head_dim, eps=1e-6)
|
||||||
|
self.norm_k = RMSNorm(head_dim, eps=1e-6)
|
||||||
|
|
||||||
|
self.add_q_proj = nn.Linear(dim_b, dim_b)
|
||||||
|
self.add_k_proj = nn.Linear(dim_b, dim_b)
|
||||||
|
self.add_v_proj = nn.Linear(dim_b, dim_b)
|
||||||
|
self.norm_added_q = RMSNorm(head_dim, eps=1e-6)
|
||||||
|
self.norm_added_k = RMSNorm(head_dim, eps=1e-6)
|
||||||
|
|
||||||
|
self.to_out = torch.nn.Sequential(nn.Linear(dim_a, dim_a))
|
||||||
|
self.to_add_out = nn.Linear(dim_b, dim_b)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
image: torch.FloatTensor,
|
||||||
|
text: torch.FloatTensor,
|
||||||
|
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
enable_fp8_attention: bool = False,
|
||||||
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||||
|
img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
|
||||||
|
txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
|
||||||
|
seq_txt = txt_q.shape[1]
|
||||||
|
|
||||||
|
img_q = rearrange(img_q, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||||
|
img_k = rearrange(img_k, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||||
|
img_v = rearrange(img_v, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||||
|
|
||||||
|
txt_q = rearrange(txt_q, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||||
|
txt_k = rearrange(txt_k, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||||
|
txt_v = rearrange(txt_v, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||||
|
|
||||||
|
img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
|
||||||
|
txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
|
||||||
|
|
||||||
|
if image_rotary_emb is not None:
|
||||||
|
img_freqs, txt_freqs = image_rotary_emb
|
||||||
|
img_q = apply_rotary_emb_qwen(img_q, img_freqs)
|
||||||
|
img_k = apply_rotary_emb_qwen(img_k, img_freqs)
|
||||||
|
txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
|
||||||
|
txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)
|
||||||
|
|
||||||
|
joint_q = torch.cat([txt_q, img_q], dim=2)
|
||||||
|
joint_k = torch.cat([txt_k, img_k], dim=2)
|
||||||
|
joint_v = torch.cat([txt_v, img_v], dim=2)
|
||||||
|
|
||||||
|
joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
|
||||||
|
|
||||||
|
txt_attn_output = joint_attn_out[:, :seq_txt, :]
|
||||||
|
img_attn_output = joint_attn_out[:, seq_txt:, :]
|
||||||
|
|
||||||
|
img_attn_output = self.to_out(img_attn_output)
|
||||||
|
txt_attn_output = self.to_add_out(txt_attn_output)
|
||||||
|
|
||||||
|
return img_attn_output, txt_attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageTransformerBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
num_attention_heads: int,
|
||||||
|
attention_head_dim: int,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.attention_head_dim = attention_head_dim
|
||||||
|
|
||||||
|
self.img_mod = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(dim, 6 * dim),
|
||||||
|
)
|
||||||
|
self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||||
|
self.attn = QwenDoubleStreamAttention(
|
||||||
|
dim_a=dim,
|
||||||
|
dim_b=dim,
|
||||||
|
num_heads=num_attention_heads,
|
||||||
|
head_dim=attention_head_dim,
|
||||||
|
)
|
||||||
|
self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||||
|
self.img_mlp = QwenFeedForward(dim=dim, dim_out=dim)
|
||||||
|
|
||||||
|
self.txt_mod = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(dim, 6 * dim, bias=True),
|
||||||
|
)
|
||||||
|
self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||||
|
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||||
|
self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)
|
||||||
|
|
||||||
|
def _modulate(self, x, mod_params, index=None):
|
||||||
|
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
||||||
|
if index is not None:
|
||||||
|
# Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
|
||||||
|
# So shift, scale, gate have shape [2*actual_batch, d]
|
||||||
|
actual_batch = shift.size(0) // 2
|
||||||
|
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
|
||||||
|
scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
|
||||||
|
gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
|
||||||
|
|
||||||
|
# index: [b, l] where b is actual batch size
|
||||||
|
# Expand to [b, l, 1] to match feature dimension
|
||||||
|
index_expanded = index.unsqueeze(-1) # [b, l, 1]
|
||||||
|
|
||||||
|
# Expand chunks to [b, 1, d] then broadcast to [b, l, d]
|
||||||
|
shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
|
||||||
|
shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
|
||||||
|
scale_0_exp = scale_0.unsqueeze(1)
|
||||||
|
scale_1_exp = scale_1.unsqueeze(1)
|
||||||
|
gate_0_exp = gate_0.unsqueeze(1)
|
||||||
|
gate_1_exp = gate_1.unsqueeze(1)
|
||||||
|
|
||||||
|
# Use torch.where to select based on index
|
||||||
|
shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
|
||||||
|
scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
|
||||||
|
gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
|
||||||
|
else:
|
||||||
|
shift_result = shift.unsqueeze(1)
|
||||||
|
scale_result = scale.unsqueeze(1)
|
||||||
|
gate_result = gate.unsqueeze(1)
|
||||||
|
|
||||||
|
return x * (1 + scale_result) + shift_result, gate_result
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
image: torch.Tensor,
|
||||||
|
text: torch.Tensor,
|
||||||
|
temb: torch.Tensor,
|
||||||
|
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
enable_fp8_attention = False,
|
||||||
|
modulate_index: Optional[List[int]] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
|
||||||
|
if modulate_index is not None:
|
||||||
|
temb = torch.chunk(temb, 2, dim=0)[0]
|
||||||
|
txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
|
||||||
|
|
||||||
|
img_normed = self.img_norm1(image)
|
||||||
|
img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, index=modulate_index)
|
||||||
|
|
||||||
|
txt_normed = self.txt_norm1(text)
|
||||||
|
txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
|
||||||
|
|
||||||
|
img_attn_out, txt_attn_out = self.attn(
|
||||||
|
image=img_modulated,
|
||||||
|
text=txt_modulated,
|
||||||
|
image_rotary_emb=image_rotary_emb,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
enable_fp8_attention=enable_fp8_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = image + img_gate * img_attn_out
|
||||||
|
text = text + txt_gate * txt_attn_out
|
||||||
|
|
||||||
|
img_normed_2 = self.img_norm2(image)
|
||||||
|
img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, index=modulate_index)
|
||||||
|
|
||||||
|
txt_normed_2 = self.txt_norm2(text)
|
||||||
|
txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
|
||||||
|
|
||||||
|
img_mlp_out = self.img_mlp(img_modulated_2)
|
||||||
|
txt_mlp_out = self.txt_mlp(txt_modulated_2)
|
||||||
|
|
||||||
|
image = image + img_gate_2 * img_mlp_out
|
||||||
|
text = text + txt_gate_2 * txt_mlp_out
|
||||||
|
|
||||||
|
return text, image
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageDiT(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_layers: int = 60,
|
||||||
|
use_layer3d_rope: bool = False,
|
||||||
|
use_additional_t_cond: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if not use_layer3d_rope:
|
||||||
|
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
|
||||||
|
else:
|
||||||
|
self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
|
||||||
|
|
||||||
|
        self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=False, use_additional_t_cond=use_additional_t_cond)
        self.txt_norm = RMSNorm(3584, eps=1e-6)

        self.img_in = nn.Linear(64, 3072)
        self.txt_in = nn.Linear(3584, 3072)

        self.transformer_blocks = nn.ModuleList(
            [
                QwenImageTransformerBlock(
                    dim=3072,
                    num_attention_heads=24,
                    attention_head_dim=128,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_out = AdaLayerNorm(3072, single=True)
        self.proj_out = nn.Linear(3072, 64)

    def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes):
        # prompt_emb
        all_prompt_emb = entity_prompt_emb + [prompt_emb]
        all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
        all_prompt_emb = torch.cat(all_prompt_emb, dim=1)

        # image_rotary_emb
        txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
        entity_seq_lens = [emb_mask.sum(dim=1).tolist() for emb_mask in entity_prompt_emb_mask]
        entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens]
        txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
        image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)

        # attention_mask
        repeat_dim = latents.shape[1]
        max_masks = entity_masks.shape[1]
        entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
        entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
        global_mask = torch.ones_like(entity_masks[0]).to(device=latents.device, dtype=latents.dtype)
        entity_masks = entity_masks + [global_mask]

        N = len(entity_masks)
        batch_size = entity_masks[0].shape[0]
        seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]
        total_seq_len = sum(seq_lens) + image.shape[1]
        patched_masks = []
        for i in range(N):
            patched_mask = rearrange(entity_masks[i], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
            patched_masks.append(patched_mask)
        attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)

        # prompt-image attention mask
        image_start = sum(seq_lens)
        image_end = total_seq_len
        cumsum = [0]
        single_image_seq = image_end - image_start
        for length in seq_lens:
            cumsum.append(cumsum[-1] + length)
        for i in range(N):
            prompt_start = cumsum[i]
            prompt_end = cumsum[i+1]
            image_mask = torch.sum(patched_masks[i], dim=-1) > 0
            image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
            # repeat image mask to match the single image sequence length
            repeat_time = single_image_seq // image_mask.shape[-1]
            image_mask = image_mask.repeat(1, 1, repeat_time)
            # each entity prompt attends only to the image region its mask covers
            attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
            # each image token attends to an entity prompt only where that entity's mask covers it
            attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
        # prompt-prompt attention mask: prompt tokens of different entities do not attend to each other
        for i in range(N):
            for j in range(N):
                if i == j:
                    continue
                start_i, end_i = cumsum[i], cumsum[i+1]
                start_j, end_j = cumsum[j], cumsum[j+1]
                attention_mask[:, start_i:end_i, start_j:end_j] = False

        # convert the boolean mask to an additive bias: 0 where attention is allowed, -inf where blocked
        attention_mask = attention_mask.float()
        attention_mask[attention_mask == 0] = float('-inf')
        attention_mask[attention_mask == 1] = 0
        attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)

        return all_prompt_emb, image_rotary_emb, attention_mask

    def forward(
        self,
        latents=None,
        timestep=None,
        prompt_emb=None,
        prompt_emb_mask=None,
        height=None,
        width=None,
    ):
        img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
        txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()

        image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
        image = self.img_in(image)
        text = self.txt_in(self.txt_norm(prompt_emb))

        conditioning = self.time_text_embed(timestep, image.dtype)

        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)

        for block in self.transformer_blocks:
            text, image = block(
                image=image,
                text=text,
                temb=conditioning,
                image_rotary_emb=image_rotary_emb,
            )

        image = self.norm_out(image, conditioning)
        image = self.proj_out(image)

        # un-patchify back to the latent layout and return it
        latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
        return latents
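A minimal, self-contained sketch (not part of the diff; all sizes are made up) of the additive-mask convention used by process_entity_masks above: a boolean "may attend" matrix becomes a bias of 0 / -inf that scaled_dot_product_attention adds to the attention logits.

import torch

# hypothetical sizes: two entity prompts and one global prompt of length 3, plus 4 image tokens
seq_lens, image_len = [3, 3, 3], 4
total = sum(seq_lens) + image_len

allowed = torch.ones(total, total, dtype=torch.bool)  # True = may attend
allowed[0:3, 3:6] = False                             # entity prompts do not attend to each other
allowed[3:6, 0:3] = False

# boolean -> additive bias: 0 where allowed, -inf where blocked
bias = torch.zeros(total, total).masked_fill(~allowed, float("-inf"))

q = k = v = torch.randn(1, 1, total, 8)               # (batch, heads, seq, head_dim)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias)
print(out.shape)                                      # torch.Size([1, 1, 13, 8])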
128  diffsynth/models/qwen_image_image2lora.py  Normal file
@@ -0,0 +1,128 @@
import torch


class CompressedMLP(torch.nn.Module):
    def __init__(self, in_dim, mid_dim, out_dim, bias=False):
        super().__init__()
        self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)
        self.proj_out = torch.nn.Linear(mid_dim, out_dim, bias=bias)

    def forward(self, x, residual=None):
        x = self.proj_in(x)
        if residual is not None: x = x + residual
        x = self.proj_out(x)
        return x


class ImageEmbeddingToLoraMatrix(torch.nn.Module):
    def __init__(self, in_dim, compress_dim, lora_a_dim, lora_b_dim, rank):
        super().__init__()
        self.proj_a = CompressedMLP(in_dim, compress_dim, lora_a_dim * rank)
        self.proj_b = CompressedMLP(in_dim, compress_dim, lora_b_dim * rank)
        self.lora_a_dim = lora_a_dim
        self.lora_b_dim = lora_b_dim
        self.rank = rank

    def forward(self, x, residual=None):
        lora_a = self.proj_a(x, residual).view(self.rank, self.lora_a_dim)
        lora_b = self.proj_b(x, residual).view(self.lora_b_dim, self.rank)
        return lora_a, lora_b


class SequencialMLP(torch.nn.Module):
    def __init__(self, length, in_dim, mid_dim, out_dim, bias=False):
        super().__init__()
        self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)
        self.proj_out = torch.nn.Linear(length * mid_dim, out_dim, bias=bias)
        self.length = length
        self.in_dim = in_dim
        self.mid_dim = mid_dim

    def forward(self, x):
        x = x.view(self.length, self.in_dim)
        x = self.proj_in(x)
        x = x.view(1, self.length * self.mid_dim)
        x = self.proj_out(x)
        return x


class LoRATrainerBlock(torch.nn.Module):
    def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024):
        super().__init__()
        self.lora_patterns = lora_patterns
        self.block_id = block_id
        self.layers = []
        for name, lora_a_dim, lora_b_dim in self.lora_patterns:
            self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank))
        self.layers = torch.nn.ModuleList(self.layers)
        if use_residual:
            self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)
        else:
            self.proj_residual = None

    def forward(self, x, residual=None):
        lora = {}
        if self.proj_residual is not None: residual = self.proj_residual(residual)
        for lora_pattern, layer in zip(self.lora_patterns, self.layers):
            name = lora_pattern[0]
            lora_a, lora_b = layer(x, residual=residual)
            lora[f"transformer_blocks.{self.block_id}.{name}.lora_A.default.weight"] = lora_a
            lora[f"transformer_blocks.{self.block_id}.{name}.lora_B.default.weight"] = lora_b
        return lora


class QwenImageImage2LoRAModel(torch.nn.Module):
    def __init__(self, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):
        super().__init__()
        self.lora_patterns = [
            [
                ("attn.to_q", 3072, 3072),
                ("attn.to_k", 3072, 3072),
                ("attn.to_v", 3072, 3072),
                ("attn.to_out.0", 3072, 3072),
            ],
            [
                ("img_mlp.net.2", 3072*4, 3072),
                ("img_mod.1", 3072, 3072*6),
            ],
            [
                ("attn.add_q_proj", 3072, 3072),
                ("attn.add_k_proj", 3072, 3072),
                ("attn.add_v_proj", 3072, 3072),
                ("attn.to_add_out", 3072, 3072),
            ],
            [
                ("txt_mlp.net.2", 3072*4, 3072),
                ("txt_mod.1", 3072, 3072*6),
            ],
        ]
        self.num_blocks = num_blocks
        self.blocks = []
        for lora_patterns in self.lora_patterns:
            for block_id in range(self.num_blocks):
                self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim))
        self.blocks = torch.nn.ModuleList(self.blocks)
        self.residual_scale = 0.05
        self.use_residual = use_residual

    def forward(self, x, residual=None):
        if residual is not None:
            if self.use_residual:
                residual = residual * self.residual_scale
            else:
                residual = None
        lora = {}
        for block in self.blocks:
            lora.update(block(x, residual))
        return lora

    def initialize_weights(self):
        state_dict = self.state_dict()
        for name in state_dict:
            if ".proj_a." in name:
                state_dict[name] = state_dict[name] * 0.3
            elif ".proj_b.proj_out." in name:
                state_dict[name] = state_dict[name] * 0
            elif ".proj_residual.proj_out." in name:
                state_dict[name] = state_dict[name] * 0.3
        self.load_state_dict(state_dict)
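The dictionary returned by QwenImageImage2LoRAModel.forward uses PEFT-style lora_A / lora_B keys, so (as a hedged usage sketch with a hypothetical 3072x3072 base layer) each predicted pair can be merged into a base weight with the standard LoRA update W' = W + B @ A:

import torch

base = torch.nn.Linear(3072, 3072, bias=False)   # hypothetical base layer
lora_a = torch.randn(4, 3072)                    # lora_A: (rank, in_dim)
lora_b = torch.zeros(3072, 4)                    # lora_B: (out_dim, rank); initialize_weights zeroes
                                                 # proj_b.proj_out, so predictions start as a no-op
with torch.no_grad():
    base.weight += lora_b @ lora_a               # standard LoRA merge: W' = W + B @ A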
190  diffsynth/models/qwen_image_text_encoder.py  Normal file
@@ -0,0 +1,190 @@
import torch
from typing import Optional, Union


class QwenImageTextEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel
        config = Qwen2_5_VLConfig(**{
            "architectures": ["Qwen2_5_VLForConditionalGeneration"],
            "attention_dropout": 0.0,
            "bos_token_id": 151643,
            "eos_token_id": 151645,
            "hidden_act": "silu",
            "hidden_size": 3584,
            "image_token_id": 151655,
            "initializer_range": 0.02,
            "intermediate_size": 18944,
            "max_position_embeddings": 128000,
            "max_window_layers": 28,
            "model_type": "qwen2_5_vl",
            "num_attention_heads": 28,
            "num_hidden_layers": 28,
            "num_key_value_heads": 4,
            "rms_norm_eps": 1e-06,
            "rope_scaling": {
                "mrope_section": [16, 24, 24],
                "rope_type": "default",
                "type": "default"
            },
            "rope_theta": 1000000.0,
            "sliding_window": 32768,
            "text_config": {
                "architectures": ["Qwen2_5_VLForConditionalGeneration"],
                "attention_dropout": 0.0,
                "bos_token_id": 151643,
                "eos_token_id": 151645,
                "hidden_act": "silu",
                "hidden_size": 3584,
                "image_token_id": None,
                "initializer_range": 0.02,
                "intermediate_size": 18944,
                "layer_types": ["full_attention"] * 28,
                "max_position_embeddings": 128000,
                "max_window_layers": 28,
                "model_type": "qwen2_5_vl_text",
                "num_attention_heads": 28,
                "num_hidden_layers": 28,
                "num_key_value_heads": 4,
                "rms_norm_eps": 1e-06,
                "rope_scaling": {
                    "mrope_section": [16, 24, 24],
                    "rope_type": "default",
                    "type": "default"
                },
                "rope_theta": 1000000.0,
                "sliding_window": None,
                "torch_dtype": "float32",
                "use_cache": True,
                "use_sliding_window": False,
                "video_token_id": None,
                "vision_end_token_id": 151653,
                "vision_start_token_id": 151652,
                "vision_token_id": 151654,
                "vocab_size": 152064
            },
            "tie_word_embeddings": False,
            "torch_dtype": "float32",
            "transformers_version": "4.54.0",
            "use_cache": True,
            "use_sliding_window": False,
            "video_token_id": 151656,
            "vision_config": {
                "depth": 32,
                "fullatt_block_indexes": [7, 15, 23, 31],
                "hidden_act": "silu",
                "hidden_size": 1280,
                "in_channels": 3,
                "in_chans": 3,
                "initializer_range": 0.02,
                "intermediate_size": 3420,
                "model_type": "qwen2_5_vl",
                "num_heads": 16,
                "out_hidden_size": 3584,
                "patch_size": 14,
                "spatial_merge_size": 2,
                "spatial_patch_size": 14,
                "temporal_patch_size": 2,
                "tokens_per_second": 2,
                "torch_dtype": "float32",
                "window_size": 112
            },
            "vision_end_token_id": 151653,
            "vision_start_token_id": 151652,
            "vision_token_id": 151654,
            "vocab_size": 152064
        })
        self.model = Qwen2_5_VLModel(config)
        self.lm_head = torch.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.config = config

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ):
        output_attentions = False
        output_hidden_states = True

        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )
        return outputs.hidden_states
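A hedged usage sketch (illustrative only; a real caller would load pretrained weights rather than instantiate a randomly initialized 7B model): forward returns outputs.hidden_states, the tuple of per-layer hidden states from the wrapped Qwen2_5_VLModel, whose last entry is the final layer's hidden state.

import torch

encoder = QwenImageTextEncoder()                      # randomly initialized here, for shape checking only
input_ids = torch.randint(0, 152064, (1, 16))         # hypothetical tokenized prompt
attention_mask = torch.ones(1, 16, dtype=torch.long)

hidden_states = encoder(input_ids=input_ids, attention_mask=attention_mask)
prompt_emb = hidden_states[-1]                        # final layer, shape (1, 16, 3584)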
726  diffsynth/models/qwen_image_vae.py  Normal file
@@ -0,0 +1,726 @@
import torch
from typing import List, Optional, Tuple, Union
from torch import nn


CACHE_T = 2

class QwenImageCausalConv3d(torch.nn.Conv3d):
    r"""
    A custom 3D causal convolution layer with feature caching support.

    This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
    caching for efficient inference.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        # Set up causal padding: all temporal padding goes before the first frame
        self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            # prepend cached frames from the previous chunk instead of zero padding
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = torch.nn.functional.pad(x, padding)
        return super().forward(x)
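A small check (not from the diff) of the causality the docstring describes: because all temporal padding is placed before the first frame, output frame t depends only on input frames at or before t.

import torch

conv = QwenImageCausalConv3d(1, 1, kernel_size=3, padding=1)
x = torch.randn(1, 1, 8, 4, 4)                    # (batch, channels, time, height, width)

y_full = conv(x)
y_trunc = conv(x[:, :, :5])                       # drop the last 3 frames

print(torch.allclose(y_full[:, :, :5], y_trunc))  # True: earlier outputs never see later frames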
class QwenImageRMS_norm(nn.Module):
    r"""
    A custom RMS normalization layer.

    Args:
        dim (int): The number of dimensions to normalize over.
        channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
            Default is True.
        images (bool, optional): Whether the input represents image data. Default is True.
        bias (bool, optional): Whether to include a learnable bias term. Default is False.
    """

    def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
        return torch.nn.functional.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
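A sanity sketch (not from the diff): normalize(x) * sqrt(dim) equals x / RMS(x), since the L2 norm is sqrt(dim) times the root-mean-square; with the default all-ones gamma and zero bias the layer reduces to plain RMS normalization.

import torch

norm = QwenImageRMS_norm(dim=8, channel_first=False)
x = torch.randn(2, 5, 8)

rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
print(torch.allclose(norm(x), x / rms, atol=1e-5))  # True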
class QwenImageResidualBlock(nn.Module):
    r"""
    A custom residual block module.

    Args:
        in_dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
        non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        dropout: float = 0.0,
        non_linearity: str = "silu",
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.nonlinearity = torch.nn.SiLU()

        # layers
        self.norm1 = QwenImageRMS_norm(in_dim, images=False)
        self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
        self.norm2 = QwenImageRMS_norm(out_dim, images=False)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
        self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # Apply shortcut connection
        h = self.conv_shortcut(x)

        # First normalization and activation
        x = self.norm1(x)
        x = self.nonlinearity(x)

        if feat_cache is not None:
            idx = feat_idx[0]
            # remember the trailing CACHE_T frames for the next chunk
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        # Second normalization and activation
        x = self.norm2(x)
        x = self.nonlinearity(x)

        # Dropout
        x = self.dropout(x)

        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

            x = self.conv2(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv2(x)

        # Add residual connection
        return x + h
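An illustrative sketch (not from the diff) of the feat_cache / feat_idx protocol used throughout this file: the caller supplies one None slot per cached convolution plus a mutable index reset to [0] per chunk; each conv consumes its slot in order and stores its trailing CACHE_T frames, so chunk-by-chunk processing reproduces the full-clip result.

import torch

block = QwenImageResidualBlock(in_dim=4, out_dim=4)
x = torch.randn(1, 4, 6, 8, 8)                       # (batch, channels, time, height, width)

feat_cache = [None, None]                            # one slot per causal conv in the block
out_chunks = [block(chunk, feat_cache=feat_cache, feat_idx=[0]) for chunk in x.split(3, dim=2)]

print(torch.allclose(block(x), torch.cat(out_chunks, dim=2), atol=1e-5))  # True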
class QwenImageAttentionBlock(nn.Module):
    r"""
    Single-head self-attention over the spatial positions of each frame.

    Args:
        dim (int): The number of channels in the input tensor.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = QwenImageRMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        identity = x
        batch_size, channels, time, height, width = x.size()

        x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
        x = self.norm(x)

        # compute query, key, value
        qkv = self.to_qkv(x)
        qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
        qkv = qkv.permute(0, 1, 3, 2).contiguous()
        q, k, v = qkv.chunk(3, dim=-1)

        # apply attention
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v)

        x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)

        # output projection
        x = self.proj(x)

        # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
        x = x.view(batch_size, time, channels, height, width)
        x = x.permute(0, 2, 1, 3, 4)

        return x + identity
class QwenImageUpsample(nn.Upsample):
    r"""
    Perform upsampling while ensuring the output tensor has the same data type as the input.

    Args:
        x (torch.Tensor): Input tensor to be upsampled.

    Returns:
        torch.Tensor: Upsampled tensor with the same data type as the input.
    """

    def forward(self, x):
        return super().forward(x.float()).type_as(x)


class QwenImageResample(nn.Module):
    r"""
    A custom resampling module for 2D and 3D data.

    Args:
        dim (int): The number of input/output channels.
        mode (str): The resampling mode. Must be one of:
            - 'none': No resampling (identity operation).
            - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
            - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
            - 'downsample2d': 2D downsampling with zero-padding and convolution.
            - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
    """

    def __init__(self, dim: int, mode: str) -> None:
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == "upsample2d":
            self.resample = nn.Sequential(
                QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
            )
        elif mode == "upsample3d":
            self.resample = nn.Sequential(
                QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
            )
            self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

        elif mode == "downsample2d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == "downsample3d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        b, c, t, h, w = x.size()
        if self.mode == "upsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = "Rep"
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
                        # cache the last frame of the last two chunks
                        cache_x = torch.cat(
                            [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
                        )
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
                        cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
                    if feat_cache[idx] == "Rep":
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        x = self.resample(x)
        x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)

        if self.mode == "downsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x
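A shape sketch (not from the diff) of the two spatial modes: 'upsample2d' doubles H and W while halving channels, and 'downsample2d' halves H and W at constant channels.

import torch

x = torch.randn(1, 64, 4, 16, 16)                           # (batch, channels, time, height, width)

print(QwenImageResample(64, mode="upsample2d")(x).shape)    # torch.Size([1, 32, 4, 32, 32])
print(QwenImageResample(64, mode="downsample2d")(x).shape)  # torch.Size([1, 64, 4, 8, 8])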
class QwenImageMidBlock(nn.Module):
    """
    Middle block for WanVAE encoder and decoder.

    Args:
        dim (int): Number of input/output channels.
        dropout (float): Dropout rate.
        non_linearity (str): Type of non-linearity to use.
    """

    def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
        super().__init__()
        self.dim = dim

        # Create the components
        resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]
        attentions = []
        for _ in range(num_layers):
            attentions.append(QwenImageAttentionBlock(dim))
            resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # First residual block
        x = self.resnets[0](x, feat_cache, feat_idx)

        # Process through attention and residual blocks
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                x = attn(x)

            x = resnet(x, feat_cache, feat_idx)

        return x
class QwenImageEncoder3d(nn.Module):
    r"""
    A 3D encoder module.

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): The dimensionality of the latent space.
        dim_mult (list of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
        temperal_downsample (list of bool): Whether to downsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        non_linearity (str): Type of non-linearity to use.
    """

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
        non_linearity: str = "silu",
        image_channels=3
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.nonlinearity = torch.nn.SiLU()

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv_in = QwenImageCausalConv3d(image_channels, dims[0], 3, padding=1)

        # downsample blocks
        self.down_blocks = torch.nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    self.down_blocks.append(QwenImageAttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
                self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
                scale /= 2.0

        # middle blocks
        self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1)

        # output blocks
        self.norm_out = QwenImageRMS_norm(out_dim, images=False)
        self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the last two chunks
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_in(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_in(x)

        ## downsamples
        for layer in self.down_blocks:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)

        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the last two chunks
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_out(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_out(x)
        return x
class QwenImageUpBlock(nn.Module):
    """
    A block that handles upsampling for the WanVAE decoder.

    Args:
        in_dim (int): Input dimension
        out_dim (int): Output dimension
        num_res_blocks (int): Number of residual blocks
        dropout (float): Dropout rate
        upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
        non_linearity (str): Type of non-linearity to use
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_res_blocks: int,
        dropout: float = 0.0,
        upsample_mode: Optional[str] = None,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Create layers list
        resnets = []
        # Add residual blocks and attention if needed
        current_dim = in_dim
        for _ in range(num_res_blocks + 1):
            resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity))
            current_dim = out_dim

        self.resnets = nn.ModuleList(resnets)

        # Add upsampling layer if needed
        self.upsamplers = None
        if upsample_mode is not None:
            self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)])

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        """
        Forward pass through the upsampling block.

        Args:
            x (torch.Tensor): Input tensor
            feat_cache (list, optional): Feature cache for causal convolutions
            feat_idx (list, optional): Feature index for cache management

        Returns:
            torch.Tensor: Output tensor
        """
        for resnet in self.resnets:
            if feat_cache is not None:
                x = resnet(x, feat_cache, feat_idx)
            else:
                x = resnet(x)

        if self.upsamplers is not None:
            if feat_cache is not None:
                x = self.upsamplers[0](x, feat_cache, feat_idx)
            else:
                x = self.upsamplers[0](x)
        return x
class QwenImageDecoder3d(nn.Module):
    r"""
    A 3D decoder module.

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): The dimensionality of the latent space.
        dim_mult (list of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
        temperal_upsample (list of bool): Whether to upsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        non_linearity (str): Type of non-linearity to use.
    """

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_upsample=[False, True, True],
        dropout=0.0,
        non_linearity: str = "silu",
        image_channels=3,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        self.nonlinearity = torch.nn.SiLU()

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2 ** (len(dim_mult) - 2)

        # init block
        self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1)

        # upsample blocks
        self.up_blocks = nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i > 0:
                in_dim = in_dim // 2

            # Determine if we need upsampling
            upsample_mode = None
            if i != len(dim_mult) - 1:
                upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"

            # Create and add the upsampling block
            up_block = QwenImageUpBlock(
                in_dim=in_dim,
                out_dim=out_dim,
                num_res_blocks=num_res_blocks,
                dropout=dropout,
                upsample_mode=upsample_mode,
                non_linearity=non_linearity,
            )
            self.up_blocks.append(up_block)

            # Update scale for next iteration
            if upsample_mode is not None:
                scale *= 2.0

        # output blocks
        self.norm_out = QwenImageRMS_norm(out_dim, images=False)
        self.conv_out = QwenImageCausalConv3d(out_dim, image_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the last two chunks
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_in(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_in(x)

        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)

        ## upsamples
        for up_block in self.up_blocks:
            x = up_block(x, feat_cache, feat_idx)

        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the last two chunks
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_out(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_out(x)
        return x
class QwenImageVAE(torch.nn.Module):
    def __init__(
        self,
        base_dim: int = 96,
        z_dim: int = 16,
        dim_mult: Tuple[int] = [1, 2, 4, 4],
        num_res_blocks: int = 2,
        attn_scales: List[float] = [],
        temperal_downsample: List[bool] = [False, True, True],
        dropout: float = 0.0,
        image_channels: int = 3,
    ) -> None:
        super().__init__()

        self.z_dim = z_dim
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        self.encoder = QwenImageEncoder3d(
            base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, image_channels=image_channels,
        )
        self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)

        self.decoder = QwenImageDecoder3d(
            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, image_channels=image_channels,
        )

        # per-channel statistics of the 16-dim latent space
        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
        ]
        self.mean = torch.tensor(mean).view(1, 16, 1, 1, 1)
        self.std = 1 / torch.tensor(std).view(1, 16, 1, 1, 1)  # stored as the reciprocal so encode can multiply

    def encode(self, x, **kwargs):
        x = x.unsqueeze(2)
        x = self.encoder(x)
        x = self.quant_conv(x)
        x = x[:, :16]  # keep only the mean half of the (mean, logvar) encoder output
        mean, std = self.mean.to(dtype=x.dtype, device=x.device), self.std.to(dtype=x.dtype, device=x.device)
        x = (x - mean) * std
        x = x.squeeze(2)
        return x

    def decode(self, x, **kwargs):
        x = x.unsqueeze(2)
        mean, std = self.mean.to(dtype=x.dtype, device=x.device), self.std.to(dtype=x.dtype, device=x.device)
        x = x / std + mean
        x = self.post_quant_conv(x)
        x = self.decoder(x)
        x = x.squeeze(2)
        return x
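A round-trip sketch (not from the diff): encode whitens the latent per channel as (z - mean) / std and decode applies the inverse, so the two normalizations cancel exactly; note that self.std stores the reciprocal of the channel std.

import torch

vae = QwenImageVAE()
z = torch.randn(1, 16, 1, 1, 1)

normalized = (z - vae.mean) * vae.std             # what encode applies after quant_conv
restored = normalized / vae.std + vae.mean        # what decode applies before post_quant_conv
print(torch.allclose(restored, z, atol=1e-6))     # True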
@@ -1,798 +0,0 @@
import torch
from einops import rearrange
from .svd_unet import TemporalTimesteps
from .tiler import TileWorker


class PatchEmbed(torch.nn.Module):
    def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
        super().__init__()
        self.pos_embed_max_size = pos_embed_max_size
        self.patch_size = patch_size

        self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size)
        self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, 1536))

    def cropped_pos_embed(self, height, width):
        height = height // self.patch_size
        width = width // self.patch_size
        top = (self.pos_embed_max_size - height) // 2
        left = (self.pos_embed_max_size - width) // 2
        spatial_pos_embed = self.pos_embed[:, top : top + height, left : left + width, :].flatten(1, 2)
        return spatial_pos_embed

    def forward(self, latent):
        height, width = latent.shape[-2:]
        latent = self.proj(latent)
        latent = latent.flatten(2).transpose(1, 2)
        pos_embed = self.cropped_pos_embed(height, width)
        return latent + pos_embed


class TimestepEmbeddings(torch.nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = torch.nn.Sequential(
            torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
        )

    def forward(self, timestep, dtype):
        time_emb = self.time_proj(timestep).to(dtype)
        time_emb = self.timestep_embedder(time_emb)
        return time_emb


class AdaLayerNorm(torch.nn.Module):
    def __init__(self, dim, single=False):
        super().__init__()
        self.single = single
        self.linear = torch.nn.Linear(dim, dim * (2 if single else 6))
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(torch.nn.functional.silu(emb))
        if self.single:
            scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
            x = self.norm(x) * (1 + scale) + shift
            return x
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
            x = self.norm(x) * (1 + scale_msa) + shift_msa
            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class JointAttention(torch.nn.Module):
    def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.only_out_a = only_out_a

        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
        self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)

        self.a_to_out = torch.nn.Linear(dim_a, dim_a)
        if not only_out_a:
            self.b_to_out = torch.nn.Linear(dim_b, dim_b)

    def forward(self, hidden_states_a, hidden_states_b):
        batch_size = hidden_states_a.shape[0]

        qkv = torch.concat([self.a_to_qkv(hidden_states_a), self.b_to_qkv(hidden_states_b)], dim=1)
        qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = qkv.chunk(3, dim=1)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        hidden_states_a, hidden_states_b = hidden_states[:, :hidden_states_a.shape[1]], hidden_states[:, hidden_states_a.shape[1]:]
        hidden_states_a = self.a_to_out(hidden_states_a)
        if self.only_out_a:
            return hidden_states_a
        else:
            hidden_states_b = self.b_to_out(hidden_states_b)
            return hidden_states_a, hidden_states_b


class JointTransformerBlock(torch.nn.Module):
    def __init__(self, dim, num_attention_heads):
        super().__init__()
        self.norm1_a = AdaLayerNorm(dim)
        self.norm1_b = AdaLayerNorm(dim)

        self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_a = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

        self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_b = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

    def forward(self, hidden_states_a, hidden_states_b, temb):
        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
        norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)

        # Attention
        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)

        # Part A
        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
        norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)

        # Part B
        hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
        norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
        hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)

        return hidden_states_a, hidden_states_b


class JointTransformerFinalBlock(torch.nn.Module):
    def __init__(self, dim, num_attention_heads):
        super().__init__()
        self.norm1_a = AdaLayerNorm(dim)
        self.norm1_b = AdaLayerNorm(dim, single=True)

        self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True)

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_a = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

    def forward(self, hidden_states_a, hidden_states_b, temb):
        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
        norm_hidden_states_b = self.norm1_b(hidden_states_b, emb=temb)

        # Attention
        attn_output_a = self.attn(norm_hidden_states_a, norm_hidden_states_b)

        # Part A
        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
        norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)

        return hidden_states_a, hidden_states_b


class SD3DiT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192)
        self.time_embedder = TimestepEmbeddings(256, 1536)
        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, 1536), torch.nn.SiLU(), torch.nn.Linear(1536, 1536))
        self.context_embedder = torch.nn.Linear(4096, 1536)
        self.blocks = torch.nn.ModuleList([JointTransformerBlock(1536, 24) for _ in range(23)] + [JointTransformerFinalBlock(1536, 24)])
        self.norm_out = AdaLayerNorm(1536, single=True)
        self.proj_out = torch.nn.Linear(1536, 64)

    def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
        # Due to the global positional embedding, we cannot implement layer-wise tiled forward.
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb),
            hidden_states,
            tile_size,
            tile_stride,
            tile_device=hidden_states.device,
            tile_dtype=hidden_states.dtype
        )
        return hidden_states

    def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
        if tiled:
            return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)
        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
        prompt_emb = self.context_embedder(prompt_emb)

        height, width = hidden_states.shape[-2:]
        hidden_states = self.pos_embedder(hidden_states)

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states, prompt_emb, conditioning,
                    use_reentrant=False,
                )
            else:
                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)

        hidden_states = self.norm_out(hidden_states, conditioning)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
        return hidden_states

    @staticmethod
    def state_dict_converter():
        return SD3DiTStateDictConverter()


class SD3DiTStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "context_embedder": "context_embedder",
            "pos_embed.pos_embed": "pos_embedder.pos_embed",
            "pos_embed.proj": "pos_embedder.proj",
            "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
            "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
            "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
            "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
            "norm_out.linear": "norm_out.linear",
            "proj_out": "proj_out",

            "norm1.linear": "norm1_a.linear",
            "norm1_context.linear": "norm1_b.linear",
            "attn.to_q": "attn.a_to_q",
            "attn.to_k": "attn.a_to_k",
            "attn.to_v": "attn.a_to_v",
            "attn.to_out.0": "attn.a_to_out",
            "attn.add_q_proj": "attn.b_to_q",
            "attn.add_k_proj": "attn.b_to_k",
            "attn.add_v_proj": "attn.b_to_v",
            "attn.to_add_out": "attn.b_to_out",
            "ff.net.0.proj": "ff_a.0",
            "ff.net.2": "ff_a.2",
            "ff_context.net.0.proj": "ff_b.0",
            "ff_context.net.2": "ff_b.2",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                if name == "pos_embed.pos_embed":
                    param = param.reshape((1, 192, 192, 1536))
                state_dict_[rename_dict[name]] = param
            elif name.endswith(".weight") or name.endswith(".bias"):
                suffix = ".weight" if name.endswith(".weight") else ".bias"
                prefix = name[:-len(suffix)]
                if prefix in rename_dict:
                    state_dict_[rename_dict[prefix] + suffix] = param
                elif prefix.startswith("transformer_blocks."):
                    names = prefix.split(".")
                    names[0] = "blocks"
                    middle = ".".join(names[2:])
                    if middle in rename_dict:
                        name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
                        state_dict_[name_] = param
        return state_dict_

    def from_civitai(self, state_dict):
        rename_dict = {
            "model.diffusion_model.context_embedder.bias": "context_embedder.bias",
            "model.diffusion_model.context_embedder.weight": "context_embedder.weight",
            "model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
            "model.diffusion_model.final_layer.linear.weight": "proj_out.weight",
            "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias": "blocks.0.norm1_b.linear.bias",
            "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.weight": "blocks.0.norm1_b.linear.weight",
            "model.diffusion_model.joint_blocks.0.context_block.attn.proj.bias": "blocks.0.attn.b_to_out.bias",
            "model.diffusion_model.joint_blocks.0.context_block.attn.proj.weight": "blocks.0.attn.b_to_out.weight",
            "model.diffusion_model.joint_blocks.0.context_block.attn.qkv.bias": ['blocks.0.attn.b_to_q.bias', 'blocks.0.attn.b_to_k.bias', 'blocks.0.attn.b_to_v.bias'],
            "model.diffusion_model.joint_blocks.0.context_block.attn.qkv.weight": ['blocks.0.attn.b_to_q.weight', 'blocks.0.attn.b_to_k.weight', 'blocks.0.attn.b_to_v.weight'],
            "model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.bias": "blocks.0.ff_b.0.bias",
            "model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.weight": "blocks.0.ff_b.0.weight",
            "model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.bias": "blocks.0.ff_b.2.bias",
            "model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.weight": "blocks.0.ff_b.2.weight",
            "model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.bias": "blocks.0.norm1_a.linear.bias",
            "model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.weight": "blocks.0.norm1_a.linear.weight",
            "model.diffusion_model.joint_blocks.0.x_block.attn.proj.bias": "blocks.0.attn.a_to_out.bias",
            "model.diffusion_model.joint_blocks.0.x_block.attn.proj.weight": "blocks.0.attn.a_to_out.weight",
            "model.diffusion_model.joint_blocks.0.x_block.attn.qkv.bias": ['blocks.0.attn.a_to_q.bias', 'blocks.0.attn.a_to_k.bias', 'blocks.0.attn.a_to_v.bias'],
            "model.diffusion_model.joint_blocks.0.x_block.attn.qkv.weight": ['blocks.0.attn.a_to_q.weight', 'blocks.0.attn.a_to_k.weight', 'blocks.0.attn.a_to_v.weight'],
            "model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.bias": "blocks.0.ff_a.0.bias",
            "model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.weight": "blocks.0.ff_a.0.weight",
            "model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.bias": "blocks.0.ff_a.2.bias",
            "model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.weight": "blocks.0.ff_a.2.weight",
            "model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.bias": "blocks.1.norm1_b.linear.bias",
            "model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.weight": "blocks.1.norm1_b.linear.weight",
            "model.diffusion_model.joint_blocks.1.context_block.attn.proj.bias": "blocks.1.attn.b_to_out.bias",
            "model.diffusion_model.joint_blocks.1.context_block.attn.proj.weight": "blocks.1.attn.b_to_out.weight",
            "model.diffusion_model.joint_blocks.1.context_block.attn.qkv.bias": ['blocks.1.attn.b_to_q.bias', 'blocks.1.attn.b_to_k.bias', 'blocks.1.attn.b_to_v.bias'],
            "model.diffusion_model.joint_blocks.1.context_block.attn.qkv.weight": ['blocks.1.attn.b_to_q.weight', 'blocks.1.attn.b_to_k.weight', 'blocks.1.attn.b_to_v.weight'],
            "model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.bias": "blocks.1.ff_b.0.bias",
            "model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.weight": "blocks.1.ff_b.0.weight",
            "model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.bias": "blocks.1.ff_b.2.bias",
            "model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.weight": "blocks.1.ff_b.2.weight",
            "model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.bias": "blocks.1.norm1_a.linear.bias",
            "model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.weight": "blocks.1.norm1_a.linear.weight",
            "model.diffusion_model.joint_blocks.1.x_block.attn.proj.bias": "blocks.1.attn.a_to_out.bias",
            "model.diffusion_model.joint_blocks.1.x_block.attn.proj.weight": "blocks.1.attn.a_to_out.weight",
            "model.diffusion_model.joint_blocks.1.x_block.attn.qkv.bias": ['blocks.1.attn.a_to_q.bias', 'blocks.1.attn.a_to_k.bias', 'blocks.1.attn.a_to_v.bias'],
            "model.diffusion_model.joint_blocks.1.x_block.attn.qkv.weight": ['blocks.1.attn.a_to_q.weight', 'blocks.1.attn.a_to_k.weight', 'blocks.1.attn.a_to_v.weight'],
            "model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.bias": "blocks.1.ff_a.0.bias",
            "model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.weight": "blocks.1.ff_a.0.weight",
            "model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.bias": "blocks.1.ff_a.2.bias",
            "model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.weight": "blocks.1.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.bias": "blocks.10.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.weight": "blocks.10.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.bias": "blocks.10.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.weight": "blocks.10.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.bias": ['blocks.10.attn.b_to_q.bias', 'blocks.10.attn.b_to_k.bias', 'blocks.10.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.weight": ['blocks.10.attn.b_to_q.weight', 'blocks.10.attn.b_to_k.weight', 'blocks.10.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.bias": "blocks.10.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.weight": "blocks.10.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.bias": "blocks.10.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.weight": "blocks.10.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.bias": "blocks.10.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.weight": "blocks.10.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.bias": "blocks.10.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.weight": "blocks.10.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.bias": ['blocks.10.attn.a_to_q.bias', 'blocks.10.attn.a_to_k.bias', 'blocks.10.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.weight": ['blocks.10.attn.a_to_q.weight', 'blocks.10.attn.a_to_k.weight', 'blocks.10.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.bias": "blocks.10.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.weight": "blocks.10.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.bias": "blocks.10.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.weight": "blocks.10.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.bias": "blocks.11.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.weight": "blocks.11.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.bias": "blocks.11.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.weight": "blocks.11.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.bias": ['blocks.11.attn.b_to_q.bias', 'blocks.11.attn.b_to_k.bias', 'blocks.11.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.weight": ['blocks.11.attn.b_to_q.weight', 'blocks.11.attn.b_to_k.weight', 'blocks.11.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.bias": "blocks.11.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.weight": "blocks.11.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.bias": "blocks.11.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.weight": "blocks.11.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.bias": "blocks.11.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.weight": "blocks.11.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.bias": "blocks.11.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.weight": "blocks.11.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.bias": ['blocks.11.attn.a_to_q.bias', 'blocks.11.attn.a_to_k.bias', 'blocks.11.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.weight": ['blocks.11.attn.a_to_q.weight', 'blocks.11.attn.a_to_k.weight', 'blocks.11.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.bias": "blocks.11.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.weight": "blocks.11.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.bias": "blocks.11.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.weight": "blocks.11.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.bias": "blocks.12.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.weight": "blocks.12.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.bias": "blocks.12.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.weight": "blocks.12.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.bias": ['blocks.12.attn.b_to_q.bias', 'blocks.12.attn.b_to_k.bias', 'blocks.12.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.weight": ['blocks.12.attn.b_to_q.weight', 'blocks.12.attn.b_to_k.weight', 'blocks.12.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.bias": "blocks.12.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.weight": "blocks.12.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.bias": "blocks.12.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.weight": "blocks.12.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.bias": "blocks.12.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.weight": "blocks.12.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.bias": "blocks.12.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.weight": "blocks.12.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.bias": ['blocks.12.attn.a_to_q.bias', 'blocks.12.attn.a_to_k.bias', 'blocks.12.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.weight": ['blocks.12.attn.a_to_q.weight', 'blocks.12.attn.a_to_k.weight', 'blocks.12.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.bias": "blocks.12.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.weight": "blocks.12.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.bias": "blocks.12.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.weight": "blocks.12.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.bias": "blocks.13.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.weight": "blocks.13.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.bias": "blocks.13.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.weight": "blocks.13.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.bias": ['blocks.13.attn.b_to_q.bias', 'blocks.13.attn.b_to_k.bias', 'blocks.13.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.weight": ['blocks.13.attn.b_to_q.weight', 'blocks.13.attn.b_to_k.weight', 'blocks.13.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.bias": "blocks.13.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.weight": "blocks.13.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.bias": "blocks.13.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.weight": "blocks.13.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.bias": "blocks.13.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.weight": "blocks.13.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.bias": "blocks.13.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.weight": "blocks.13.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.bias": ['blocks.13.attn.a_to_q.bias', 'blocks.13.attn.a_to_k.bias', 'blocks.13.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.weight": ['blocks.13.attn.a_to_q.weight', 'blocks.13.attn.a_to_k.weight', 'blocks.13.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.bias": "blocks.13.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.weight": "blocks.13.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.bias": "blocks.13.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.weight": "blocks.13.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.bias": "blocks.14.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.weight": "blocks.14.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.bias": "blocks.14.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.weight": "blocks.14.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.bias": ['blocks.14.attn.b_to_q.bias', 'blocks.14.attn.b_to_k.bias', 'blocks.14.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.weight": ['blocks.14.attn.b_to_q.weight', 'blocks.14.attn.b_to_k.weight', 'blocks.14.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.bias": "blocks.14.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.weight": "blocks.14.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.bias": "blocks.14.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.weight": "blocks.14.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.bias": "blocks.14.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.weight": "blocks.14.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.bias": "blocks.14.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.weight": "blocks.14.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.bias": ['blocks.14.attn.a_to_q.bias', 'blocks.14.attn.a_to_k.bias', 'blocks.14.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.weight": ['blocks.14.attn.a_to_q.weight', 'blocks.14.attn.a_to_k.weight', 'blocks.14.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.bias": "blocks.14.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.weight": "blocks.14.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.bias": "blocks.14.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.weight": "blocks.14.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.bias": "blocks.15.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.weight": "blocks.15.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.bias": "blocks.15.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.weight": "blocks.15.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.bias": ['blocks.15.attn.b_to_q.bias', 'blocks.15.attn.b_to_k.bias', 'blocks.15.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.weight": ['blocks.15.attn.b_to_q.weight', 'blocks.15.attn.b_to_k.weight', 'blocks.15.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.bias": "blocks.15.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.weight": "blocks.15.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.bias": "blocks.15.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.weight": "blocks.15.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.bias": "blocks.15.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.weight": "blocks.15.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.bias": "blocks.15.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.weight": "blocks.15.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.bias": ['blocks.15.attn.a_to_q.bias', 'blocks.15.attn.a_to_k.bias', 'blocks.15.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.weight": ['blocks.15.attn.a_to_q.weight', 'blocks.15.attn.a_to_k.weight', 'blocks.15.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.bias": "blocks.15.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.weight": "blocks.15.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.bias": "blocks.15.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.weight": "blocks.15.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.bias": "blocks.16.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.weight": "blocks.16.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.bias": "blocks.16.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.weight": "blocks.16.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.bias": ['blocks.16.attn.b_to_q.bias', 'blocks.16.attn.b_to_k.bias', 'blocks.16.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.weight": ['blocks.16.attn.b_to_q.weight', 'blocks.16.attn.b_to_k.weight', 'blocks.16.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.bias": "blocks.16.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.weight": "blocks.16.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.bias": "blocks.16.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.weight": "blocks.16.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.bias": "blocks.16.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.weight": "blocks.16.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.bias": "blocks.16.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.weight": "blocks.16.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.bias": ['blocks.16.attn.a_to_q.bias', 'blocks.16.attn.a_to_k.bias', 'blocks.16.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.weight": ['blocks.16.attn.a_to_q.weight', 'blocks.16.attn.a_to_k.weight', 'blocks.16.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.bias": "blocks.16.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.weight": "blocks.16.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.bias": "blocks.16.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.weight": "blocks.16.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.bias": "blocks.17.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.weight": "blocks.17.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.bias": "blocks.17.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.weight": "blocks.17.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.bias": ['blocks.17.attn.b_to_q.bias', 'blocks.17.attn.b_to_k.bias', 'blocks.17.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.weight": ['blocks.17.attn.b_to_q.weight', 'blocks.17.attn.b_to_k.weight', 'blocks.17.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.bias": "blocks.17.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.weight": "blocks.17.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.bias": "blocks.17.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.weight": "blocks.17.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.bias": "blocks.17.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.weight": "blocks.17.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.bias": "blocks.17.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.weight": "blocks.17.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.bias": ['blocks.17.attn.a_to_q.bias', 'blocks.17.attn.a_to_k.bias', 'blocks.17.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.weight": ['blocks.17.attn.a_to_q.weight', 'blocks.17.attn.a_to_k.weight', 'blocks.17.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.bias": "blocks.17.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.weight": "blocks.17.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.bias": "blocks.17.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.weight": "blocks.17.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.bias": "blocks.18.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.weight": "blocks.18.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.bias": "blocks.18.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.weight": "blocks.18.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.bias": ['blocks.18.attn.b_to_q.bias', 'blocks.18.attn.b_to_k.bias', 'blocks.18.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.weight": ['blocks.18.attn.b_to_q.weight', 'blocks.18.attn.b_to_k.weight', 'blocks.18.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.bias": "blocks.18.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.weight": "blocks.18.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.bias": "blocks.18.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.weight": "blocks.18.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.bias": "blocks.18.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.weight": "blocks.18.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.bias": "blocks.18.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.weight": "blocks.18.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.bias": ['blocks.18.attn.a_to_q.bias', 'blocks.18.attn.a_to_k.bias', 'blocks.18.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.weight": ['blocks.18.attn.a_to_q.weight', 'blocks.18.attn.a_to_k.weight', 'blocks.18.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.bias": "blocks.18.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.weight": "blocks.18.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.bias": "blocks.18.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.weight": "blocks.18.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.bias": "blocks.19.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.weight": "blocks.19.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.bias": "blocks.19.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.weight": "blocks.19.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.bias": ['blocks.19.attn.b_to_q.bias', 'blocks.19.attn.b_to_k.bias', 'blocks.19.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.weight": ['blocks.19.attn.b_to_q.weight', 'blocks.19.attn.b_to_k.weight', 'blocks.19.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.bias": "blocks.19.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.weight": "blocks.19.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.bias": "blocks.19.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.weight": "blocks.19.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.bias": "blocks.19.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.weight": "blocks.19.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.bias": "blocks.19.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.weight": "blocks.19.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.bias": ['blocks.19.attn.a_to_q.bias', 'blocks.19.attn.a_to_k.bias', 'blocks.19.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.weight": ['blocks.19.attn.a_to_q.weight', 'blocks.19.attn.a_to_k.weight', 'blocks.19.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.bias": "blocks.19.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.weight": "blocks.19.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.bias": "blocks.19.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.weight": "blocks.19.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.bias": "blocks.2.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.weight": "blocks.2.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.bias": "blocks.2.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.weight": "blocks.2.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.bias": ['blocks.2.attn.b_to_q.bias', 'blocks.2.attn.b_to_k.bias', 'blocks.2.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.weight": ['blocks.2.attn.b_to_q.weight', 'blocks.2.attn.b_to_k.weight', 'blocks.2.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.bias": "blocks.2.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.weight": "blocks.2.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.bias": "blocks.2.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.weight": "blocks.2.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.bias": "blocks.2.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.weight": "blocks.2.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.bias": "blocks.2.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.weight": "blocks.2.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.bias": ['blocks.2.attn.a_to_q.bias', 'blocks.2.attn.a_to_k.bias', 'blocks.2.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.weight": ['blocks.2.attn.a_to_q.weight', 'blocks.2.attn.a_to_k.weight', 'blocks.2.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.bias": "blocks.2.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.weight": "blocks.2.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.bias": "blocks.2.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.weight": "blocks.2.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.bias": "blocks.20.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.weight": "blocks.20.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.bias": "blocks.20.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.weight": "blocks.20.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.bias": ['blocks.20.attn.b_to_q.bias', 'blocks.20.attn.b_to_k.bias', 'blocks.20.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.weight": ['blocks.20.attn.b_to_q.weight', 'blocks.20.attn.b_to_k.weight', 'blocks.20.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.bias": "blocks.20.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.weight": "blocks.20.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.bias": "blocks.20.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.weight": "blocks.20.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.bias": "blocks.20.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.weight": "blocks.20.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.bias": "blocks.20.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.weight": "blocks.20.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.bias": ['blocks.20.attn.a_to_q.bias', 'blocks.20.attn.a_to_k.bias', 'blocks.20.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.weight": ['blocks.20.attn.a_to_q.weight', 'blocks.20.attn.a_to_k.weight', 'blocks.20.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.bias": "blocks.20.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.weight": "blocks.20.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.bias": "blocks.20.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.weight": "blocks.20.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.bias": "blocks.21.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.weight": "blocks.21.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.bias": "blocks.21.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.weight": "blocks.21.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.bias": ['blocks.21.attn.b_to_q.bias', 'blocks.21.attn.b_to_k.bias', 'blocks.21.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.weight": ['blocks.21.attn.b_to_q.weight', 'blocks.21.attn.b_to_k.weight', 'blocks.21.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.bias": "blocks.21.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.weight": "blocks.21.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.bias": "blocks.21.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.weight": "blocks.21.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.bias": "blocks.21.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.weight": "blocks.21.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.bias": "blocks.21.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.weight": "blocks.21.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.bias": ['blocks.21.attn.a_to_q.bias', 'blocks.21.attn.a_to_k.bias', 'blocks.21.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.weight": ['blocks.21.attn.a_to_q.weight', 'blocks.21.attn.a_to_k.weight', 'blocks.21.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.bias": "blocks.21.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.weight": "blocks.21.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.bias": "blocks.21.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.weight": "blocks.21.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.bias": "blocks.22.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.weight": "blocks.22.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.bias": "blocks.22.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.weight": "blocks.22.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.bias": ['blocks.22.attn.b_to_q.bias', 'blocks.22.attn.b_to_k.bias', 'blocks.22.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.weight": ['blocks.22.attn.b_to_q.weight', 'blocks.22.attn.b_to_k.weight', 'blocks.22.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.bias": "blocks.22.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.weight": "blocks.22.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.bias": "blocks.22.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.weight": "blocks.22.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.bias": "blocks.22.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.weight": "blocks.22.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.bias": "blocks.22.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.weight": "blocks.22.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.bias": ['blocks.22.attn.a_to_q.bias', 'blocks.22.attn.a_to_k.bias', 'blocks.22.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.weight": ['blocks.22.attn.a_to_q.weight', 'blocks.22.attn.a_to_k.weight', 'blocks.22.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.bias": "blocks.22.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.weight": "blocks.22.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.bias": "blocks.22.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.weight": "blocks.22.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.bias": ['blocks.23.attn.b_to_q.bias', 'blocks.23.attn.b_to_k.bias', 'blocks.23.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.weight": ['blocks.23.attn.b_to_q.weight', 'blocks.23.attn.b_to_k.weight', 'blocks.23.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.bias": "blocks.23.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.weight": "blocks.23.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.bias": "blocks.23.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.weight": "blocks.23.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.bias": ['blocks.23.attn.a_to_q.bias', 'blocks.23.attn.a_to_k.bias', 'blocks.23.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.weight": ['blocks.23.attn.a_to_q.weight', 'blocks.23.attn.a_to_k.weight', 'blocks.23.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.bias": "blocks.23.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.weight": "blocks.23.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.bias": "blocks.23.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.weight": "blocks.23.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.bias": "blocks.3.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.weight": "blocks.3.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.bias": "blocks.3.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.weight": "blocks.3.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.bias": ['blocks.3.attn.b_to_q.bias', 'blocks.3.attn.b_to_k.bias', 'blocks.3.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.weight": ['blocks.3.attn.b_to_q.weight', 'blocks.3.attn.b_to_k.weight', 'blocks.3.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.bias": "blocks.3.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.weight": "blocks.3.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.bias": "blocks.3.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.weight": "blocks.3.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.bias": "blocks.3.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.weight": "blocks.3.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.bias": "blocks.3.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.weight": "blocks.3.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.bias": ['blocks.3.attn.a_to_q.bias', 'blocks.3.attn.a_to_k.bias', 'blocks.3.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.weight": ['blocks.3.attn.a_to_q.weight', 'blocks.3.attn.a_to_k.weight', 'blocks.3.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.bias": "blocks.3.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.weight": "blocks.3.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.bias": "blocks.3.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.weight": "blocks.3.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.bias": "blocks.4.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.weight": "blocks.4.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.bias": "blocks.4.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.weight": "blocks.4.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.bias": ['blocks.4.attn.b_to_q.bias', 'blocks.4.attn.b_to_k.bias', 'blocks.4.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.weight": ['blocks.4.attn.b_to_q.weight', 'blocks.4.attn.b_to_k.weight', 'blocks.4.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.bias": "blocks.4.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.weight": "blocks.4.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.bias": "blocks.4.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.weight": "blocks.4.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.bias": "blocks.4.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.weight": "blocks.4.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.bias": "blocks.4.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.weight": "blocks.4.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.bias": ['blocks.4.attn.a_to_q.bias', 'blocks.4.attn.a_to_k.bias', 'blocks.4.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.weight": ['blocks.4.attn.a_to_q.weight', 'blocks.4.attn.a_to_k.weight', 'blocks.4.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.bias": "blocks.4.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.weight": "blocks.4.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.bias": "blocks.4.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.weight": "blocks.4.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.bias": "blocks.5.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.weight": "blocks.5.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.bias": "blocks.5.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.weight": "blocks.5.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.bias": ['blocks.5.attn.b_to_q.bias', 'blocks.5.attn.b_to_k.bias', 'blocks.5.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.weight": ['blocks.5.attn.b_to_q.weight', 'blocks.5.attn.b_to_k.weight', 'blocks.5.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.bias": "blocks.5.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.weight": "blocks.5.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.bias": "blocks.5.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.weight": "blocks.5.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.bias": "blocks.5.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.weight": "blocks.5.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.bias": "blocks.5.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.weight": "blocks.5.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.bias": ['blocks.5.attn.a_to_q.bias', 'blocks.5.attn.a_to_k.bias', 'blocks.5.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.weight": ['blocks.5.attn.a_to_q.weight', 'blocks.5.attn.a_to_k.weight', 'blocks.5.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.bias": "blocks.5.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.weight": "blocks.5.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.bias": "blocks.5.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.weight": "blocks.5.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.bias": "blocks.6.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.weight": "blocks.6.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.bias": "blocks.6.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.weight": "blocks.6.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.bias": ['blocks.6.attn.b_to_q.bias', 'blocks.6.attn.b_to_k.bias', 'blocks.6.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.weight": ['blocks.6.attn.b_to_q.weight', 'blocks.6.attn.b_to_k.weight', 'blocks.6.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.bias": "blocks.6.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.weight": "blocks.6.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.bias": "blocks.6.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.weight": "blocks.6.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.bias": "blocks.6.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.weight": "blocks.6.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.bias": "blocks.6.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.weight": "blocks.6.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.bias": ['blocks.6.attn.a_to_q.bias', 'blocks.6.attn.a_to_k.bias', 'blocks.6.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.weight": ['blocks.6.attn.a_to_q.weight', 'blocks.6.attn.a_to_k.weight', 'blocks.6.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.bias": "blocks.6.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.weight": "blocks.6.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.bias": "blocks.6.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.weight": "blocks.6.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.bias": "blocks.7.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.weight": "blocks.7.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.bias": "blocks.7.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.weight": "blocks.7.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.bias": ['blocks.7.attn.b_to_q.bias', 'blocks.7.attn.b_to_k.bias', 'blocks.7.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.weight": ['blocks.7.attn.b_to_q.weight', 'blocks.7.attn.b_to_k.weight', 'blocks.7.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.bias": "blocks.7.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.weight": "blocks.7.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.bias": "blocks.7.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.weight": "blocks.7.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.bias": "blocks.7.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.weight": "blocks.7.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.bias": "blocks.7.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.weight": "blocks.7.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.bias": ['blocks.7.attn.a_to_q.bias', 'blocks.7.attn.a_to_k.bias', 'blocks.7.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.weight": ['blocks.7.attn.a_to_q.weight', 'blocks.7.attn.a_to_k.weight', 'blocks.7.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.bias": "blocks.7.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.weight": "blocks.7.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.bias": "blocks.7.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.weight": "blocks.7.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.bias": "blocks.8.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.weight": "blocks.8.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.bias": "blocks.8.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.weight": "blocks.8.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.bias": ['blocks.8.attn.b_to_q.bias', 'blocks.8.attn.b_to_k.bias', 'blocks.8.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.weight": ['blocks.8.attn.b_to_q.weight', 'blocks.8.attn.b_to_k.weight', 'blocks.8.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.bias": "blocks.8.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.weight": "blocks.8.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.bias": "blocks.8.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.weight": "blocks.8.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.bias": "blocks.8.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.weight": "blocks.8.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.bias": "blocks.8.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.weight": "blocks.8.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.bias": ['blocks.8.attn.a_to_q.bias', 'blocks.8.attn.a_to_k.bias', 'blocks.8.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.weight": ['blocks.8.attn.a_to_q.weight', 'blocks.8.attn.a_to_k.weight', 'blocks.8.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.bias": "blocks.8.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.weight": "blocks.8.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.bias": "blocks.8.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.weight": "blocks.8.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.bias": "blocks.9.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.weight": "blocks.9.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.bias": "blocks.9.attn.b_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.weight": "blocks.9.attn.b_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.bias": ['blocks.9.attn.b_to_q.bias', 'blocks.9.attn.b_to_k.bias', 'blocks.9.attn.b_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.weight": ['blocks.9.attn.b_to_q.weight', 'blocks.9.attn.b_to_k.weight', 'blocks.9.attn.b_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.bias": "blocks.9.ff_b.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.weight": "blocks.9.ff_b.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.bias": "blocks.9.ff_b.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.weight": "blocks.9.ff_b.2.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.bias": "blocks.9.norm1_a.linear.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.weight": "blocks.9.norm1_a.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.bias": "blocks.9.attn.a_to_out.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.weight": "blocks.9.attn.a_to_out.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.bias": ['blocks.9.attn.a_to_q.bias', 'blocks.9.attn.a_to_k.bias', 'blocks.9.attn.a_to_v.bias'],
|
|
||||||
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.weight": ['blocks.9.attn.a_to_q.weight', 'blocks.9.attn.a_to_k.weight', 'blocks.9.attn.a_to_v.weight'],
|
|
||||||
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.bias": "blocks.9.ff_a.0.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.weight": "blocks.9.ff_a.0.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.bias": "blocks.9.ff_a.2.bias",
|
|
||||||
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight": "blocks.9.ff_a.2.weight",
|
|
||||||
"model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
|
|
||||||
"model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
|
|
||||||
"model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
|
|
||||||
"model.diffusion_model.t_embedder.mlp.2.bias": "time_embedder.timestep_embedder.2.bias",
|
|
||||||
"model.diffusion_model.t_embedder.mlp.2.weight": "time_embedder.timestep_embedder.2.weight",
|
|
||||||
"model.diffusion_model.x_embedder.proj.bias": "pos_embedder.proj.bias",
|
|
||||||
"model.diffusion_model.x_embedder.proj.weight": "pos_embedder.proj.weight",
|
|
||||||
"model.diffusion_model.y_embedder.mlp.0.bias": "pooled_text_embedder.0.bias",
|
|
||||||
"model.diffusion_model.y_embedder.mlp.0.weight": "pooled_text_embedder.0.weight",
|
|
||||||
"model.diffusion_model.y_embedder.mlp.2.bias": "pooled_text_embedder.2.bias",
|
|
||||||
"model.diffusion_model.y_embedder.mlp.2.weight": "pooled_text_embedder.2.weight",
|
|
||||||
|
|
||||||
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.weight": "blocks.23.norm1_b.linear.weight",
|
|
||||||
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.bias": "blocks.23.norm1_b.linear.bias",
|
|
||||||
"model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
|
|
||||||
"model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
|
|
||||||
}
|
|
||||||
state_dict_ = {}
|
|
||||||
for name in state_dict:
|
|
||||||
if name in rename_dict:
|
|
||||||
param = state_dict[name]
|
|
||||||
if name.startswith("model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1."):
|
|
||||||
param = torch.concat([param[1536:], param[:1536]], axis=0)
|
|
||||||
elif name.startswith("model.diffusion_model.final_layer.adaLN_modulation.1."):
|
|
||||||
param = torch.concat([param[1536:], param[:1536]], axis=0)
|
|
||||||
elif name == "model.diffusion_model.pos_embed":
|
|
||||||
param = param.reshape((1, 192, 192, 1536))
|
|
||||||
if isinstance(rename_dict[name], str):
|
|
||||||
state_dict_[rename_dict[name]] = param
|
|
||||||
else:
|
|
||||||
name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
|
|
||||||
state_dict_[name_] = param
|
|
||||||
return state_dict_
|
|
||||||
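Editor's note: the sketch below is not part of the diff; it illustrates how a from_civitai converter like the one above is typically applied. The checkpoint path and the converter class name are assumptions.

    # Hypothetical usage of the from_civitai converter shown above.
    import torch

    state_dict = torch.load("sd3_medium.ckpt", map_location="cpu")  # assumed path
    converter = SD3DiTStateDictConverter()  # assumed class name for the converter above
    state_dict_ = converter.from_civitai(state_dict)  # keys renamed to DiffSynth's layout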
File diff suppressed because it is too large
@@ -1,81 +0,0 @@
import torch
from .sd_vae_decoder import VAEAttentionBlock, SDVAEDecoderStateDictConverter
from .sd_unet import ResnetBlock, UpSampler
from .tiler import TileWorker


class SD3VAEDecoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 1.5305 # Different from SD 1.x
        self.shift_factor = 0.0609 # Different from SD 1.x
        self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x

        self.blocks = torch.nn.ModuleList([
            # UNetMidBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            # UpDecoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            UpSampler(512),
            # UpDecoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            UpSampler(512),
            # UpDecoderBlock2D
            ResnetBlock(512, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            UpSampler(256),
            # UpDecoderBlock2D
            ResnetBlock(256, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
        ])

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x),
            sample,
            tile_size,
            tile_stride,
            tile_device=sample.device,
            tile_dtype=sample.dtype
        )
        return hidden_states

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        # For the VAE decoder, we do not need to apply the tiler on each layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process
        hidden_states = sample / self.scaling_factor + self.shift_factor
        hidden_states = self.conv_in(hidden_states)
        time_emb = None
        text_emb = None
        res_stack = None

        # 2. blocks
        for i, block in enumerate(self.blocks):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states

    @staticmethod
    def state_dict_converter():
        return SDVAEDecoderStateDictConverter()
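Editor's note: a minimal decoding sketch for the class above (not part of the diff); shapes follow SD3's 16-channel, 8x-downsampled latent convention.

    import torch

    decoder = SD3VAEDecoder()
    latents = torch.zeros(1, 16, 128, 128)  # latent for a 1024x1024 image
    image = decoder(latents, tiled=True)    # tiled decoding bounds peak memory
    print(image.shape)                      # torch.Size([1, 3, 1024, 1024])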
@@ -1,95 +0,0 @@
import torch
from .sd_unet import ResnetBlock, DownSampler
from .sd_vae_encoder import VAEAttentionBlock, SDVAEEncoderStateDictConverter
from .tiler import TileWorker
from einops import rearrange


class SD3VAEEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 1.5305 # Different from SD 1.x
        self.shift_factor = 0.0609 # Different from SD 1.x
        self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)

        self.blocks = torch.nn.ModuleList([
            # DownEncoderBlock2D
            ResnetBlock(128, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
            DownSampler(128, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(128, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            DownSampler(256, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(256, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            DownSampler(512, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            # UNetMidBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
        ])

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x),
            sample,
            tile_size,
            tile_stride,
            tile_device=sample.device,
            tile_dtype=sample.dtype
        )
        return hidden_states

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        # For the VAE encoder, we do not need to apply the tiler on each layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process
        hidden_states = self.conv_in(sample)
        time_emb = None
        text_emb = None
        res_stack = None

        # 2. blocks
        for i, block in enumerate(self.blocks):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        # conv_out yields 32 channels (16 mean + 16 logvar); keep the means.
        hidden_states = hidden_states[:, :16]
        hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor

        return hidden_states

    def encode_video(self, sample, batch_size=8):
        B = sample.shape[0]
        hidden_states = []

        for i in range(0, sample.shape[2], batch_size):
            j = min(i + batch_size, sample.shape[2])
            sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")

            hidden_states_batch = self(sample_batch)
            hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)

            hidden_states.append(hidden_states_batch)

        hidden_states = torch.concat(hidden_states, dim=2)
        return hidden_states

    @staticmethod
    def state_dict_converter():
        return SDVAEEncoderStateDictConverter()
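Editor's note: a minimal sketch of the encoder's chunked video path (not part of the diff); the clip length and resolution are arbitrary.

    import torch

    encoder = SD3VAEEncoder()
    video = torch.zeros(1, 3, 16, 512, 512)              # (B, C, T, H, W)
    latents = encoder.encode_video(video, batch_size=8)  # encodes 8 frames at a time
    print(latents.shape)                                 # torch.Size([1, 16, 16, 64, 64])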
@@ -1,588 +0,0 @@
import torch
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
from .tiler import TileWorker


class ControlNetConditioningLayer(torch.nn.Module):
    def __init__(self, channels = (3, 16, 32, 96, 256, 320)):
        super().__init__()
        self.blocks = torch.nn.ModuleList([])
        self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
        self.blocks.append(torch.nn.SiLU())
        for i in range(1, len(channels) - 2):
            self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1))
            self.blocks.append(torch.nn.SiLU())
            # Each stride-2 convolution halves the spatial resolution.
            self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2))
            self.blocks.append(torch.nn.SiLU())
        self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1))

    def forward(self, conditioning):
        for block in self.blocks:
            conditioning = block(conditioning)
        return conditioning

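# --- Editor's note: illustrative sketch, not part of the original file. ---
# With the default channels, the three stride-2 convolutions downsample by 8x,
# so a 512x512 hint image lands at the UNet's 64x64 latent resolution:
if __name__ == "__main__":
    layer = ControlNetConditioningLayer()
    print(layer(torch.zeros(1, 3, 512, 512)).shape)  # torch.Size([1, 320, 64, 64])
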
class SDControlNet(torch.nn.Module):
    def __init__(self, global_pool=False):
        super().__init__()
        self.time_proj = Timesteps(320)
        self.time_embedding = torch.nn.Sequential(
            torch.nn.Linear(320, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280)
        )
        self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)

        self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))

        self.blocks = torch.nn.ModuleList([
            # CrossAttnDownBlock2D
            ResnetBlock(320, 320, 1280),
            AttentionBlock(8, 40, 320, 1, 768),
            PushBlock(),
            ResnetBlock(320, 320, 1280),
            AttentionBlock(8, 40, 320, 1, 768),
            PushBlock(),
            DownSampler(320),
            PushBlock(),
            # CrossAttnDownBlock2D
            ResnetBlock(320, 640, 1280),
            AttentionBlock(8, 80, 640, 1, 768),
            PushBlock(),
            ResnetBlock(640, 640, 1280),
            AttentionBlock(8, 80, 640, 1, 768),
            PushBlock(),
            DownSampler(640),
            PushBlock(),
            # CrossAttnDownBlock2D
            ResnetBlock(640, 1280, 1280),
            AttentionBlock(8, 160, 1280, 1, 768),
            PushBlock(),
            ResnetBlock(1280, 1280, 1280),
            AttentionBlock(8, 160, 1280, 1, 768),
            PushBlock(),
            DownSampler(1280),
            PushBlock(),
            # DownBlock2D
            ResnetBlock(1280, 1280, 1280),
            PushBlock(),
            ResnetBlock(1280, 1280, 1280),
            PushBlock(),
            # UNetMidBlock2DCrossAttn
            ResnetBlock(1280, 1280, 1280),
            AttentionBlock(8, 160, 1280, 1, 768),
            ResnetBlock(1280, 1280, 1280),
            PushBlock()
        ])

        self.controlnet_blocks = torch.nn.ModuleList([
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
        ])

        self.global_pool = global_pool

    def forward(
        self,
        sample, timestep, encoder_hidden_states, conditioning,
        tiled=False, tile_size=64, tile_stride=32,
    ):
        # 1. time
        time_emb = self.time_proj(timestep).to(sample.dtype)
        time_emb = self.time_embedding(time_emb)
        time_emb = time_emb.repeat(sample.shape[0], 1)

        # 2. pre-process
        height, width = sample.shape[2], sample.shape[3]
        hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
        text_emb = encoder_hidden_states
        res_stack = [hidden_states]

        # 3. blocks (each PushBlock appends the current features to res_stack)
        for i, block in enumerate(self.blocks):
            if tiled and not isinstance(block, PushBlock):
                _, _, inter_height, _ = hidden_states.shape
                resize_scale = inter_height / height
                hidden_states = TileWorker().tiled_forward(
                    lambda x: block(x, time_emb, text_emb, res_stack)[0],
                    hidden_states,
                    int(tile_size * resize_scale),
                    int(tile_stride * resize_scale),
                    tile_device=hidden_states.device,
                    tile_dtype=hidden_states.dtype
                )
            else:
                hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)

        # 4. ControlNet blocks (zero-convs projecting each pushed residual)
        controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]

        # pool
        if self.global_pool:
            controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]

        return controlnet_res_stack

    @staticmethod
    def state_dict_converter():
        return SDControlNetStateDictConverter()

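# --- Editor's note: illustrative sketch, not part of the original file. ---
# Dummy forward pass with SD v1.5 shapes (4-channel latents, 77x768 CLIP text
# embeddings, a 3-channel hint image); exact timestep handling is assumed.
if __name__ == "__main__":
    net = SDControlNet()
    res_stack = net(
        sample=torch.zeros(1, 4, 64, 64),
        timestep=torch.tensor(500),
        encoder_hidden_states=torch.zeros(1, 77, 768),
        conditioning=torch.zeros(1, 3, 512, 512),
    )
    print(len(res_stack))  # 13 feature maps, one per zero-conv in controlnet_blocks
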
class SDControlNetStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # architecture
        block_types = [
            'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
            'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
            'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
            'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
            'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
            'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
            'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
            'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
            'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
        ]

        # controlnet_rename_dict
        controlnet_rename_dict = {
            "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
            "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
            "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
            "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
            "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
            "controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
            "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
            "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
            "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
            "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
            "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
            "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
            "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
            "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
            "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
            "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
        }

        # Rename each parameter
        name_list = sorted([name for name in state_dict])
        rename_dict = {}
        block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
        last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
        for name in name_list:
            names = name.split(".")
            if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
                pass
            elif name in controlnet_rename_dict:
                names = controlnet_rename_dict[name].split(".")
            elif names[0] == "controlnet_down_blocks":
                names[0] = "controlnet_blocks"
            elif names[0] == "controlnet_mid_block":
                names = ["controlnet_blocks", "12", names[-1]]
            elif names[0] in ["time_embedding", "add_embedding"]:
                if names[0] == "add_embedding":
                    names[0] = "add_time_embedding"
                names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
            elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
                if names[0] == "mid_block":
                    names.insert(1, "0")
                block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
                block_type_with_id = ".".join(names[:4])
                if block_type_with_id != last_block_type_with_id[block_type]:
                    block_id[block_type] += 1
                    last_block_type_with_id[block_type] = block_type_with_id
                while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
                    block_id[block_type] += 1
                block_type_with_id = ".".join(names[:4])
                names = ["blocks", str(block_id[block_type])] + names[4:]
                if "ff" in names:
                    ff_index = names.index("ff")
                    component = ".".join(names[ff_index:ff_index+3])
                    component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
                    names = names[:ff_index] + [component] + names[ff_index+3:]
                if "to_out" in names:
                    names.pop(names.index("to_out") + 1)
            else:
                raise ValueError(f"Unknown parameters: {name}")
            rename_dict[name] = ".".join(names)

        # Convert state_dict
        state_dict_ = {}
        for name, param in state_dict.items():
            if ".proj_in." in name or ".proj_out." in name:
                param = param.squeeze()
            # These zero-convs were created with bias=False above, so their biases are dropped.
            if rename_dict[name] in [
                "controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
                "controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
            ]:
                continue
            state_dict_[rename_dict[name]] = param
        return state_dict_

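    # --- Editor's note: illustrative example, not part of the original file. ---
    # The renaming above maps diffusers keys onto this repo's flat block list, e.g.
    #   "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight"
    #     -> "blocks.1.transformer_blocks.0.attn1.to_q.weight"
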
    def from_civitai(self, state_dict):
        if "mid_block.resnets.1.time_emb_proj.weight" in state_dict:
            # For controlnets in diffusers format
            return self.from_diffusers(state_dict)
        rename_dict = {
            # time embedding and latent input
            "control_model.time_embed.0.weight": "time_embedding.0.weight",
            "control_model.time_embed.0.bias": "time_embedding.0.bias",
            "control_model.time_embed.2.weight": "time_embedding.2.weight",
            "control_model.time_embed.2.bias": "time_embedding.2.bias",
            "control_model.input_blocks.0.0.weight": "conv_in.weight",
            "control_model.input_blocks.0.0.bias": "conv_in.bias",
            # down blocks
            "control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
            "control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
            "control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
            "control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
            "control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
            "control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
            "control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
            "control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
            "control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
            "control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
            "control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight",
            "control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias",
            "control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
            "control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias",
            "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight",
            "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight",
            "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight",
            "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight",
            "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias",
            "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight",
            "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias",
            "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight",
            "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias",
            "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight",
            "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight",
            "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight",
            "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight",
            "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias",
            "control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight",
            "control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias",
            "control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight",
            "control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias",
            "control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight",
            "control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias",
            "control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight",
            "control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias",
            "control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight",
            "control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias",
            "control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight",
            "control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias",
            "control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight",
            "control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias",
            "control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight",
            "control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias",
            "control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight",
            "control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias",
            "control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight",
            "control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias",
            "control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight",
            "control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias",
            "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight",
            "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight",
            "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight",
            "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight",
            "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias",
            "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight",
            "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias",
            "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight",
            "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias",
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight",
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight",
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight",
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight",
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias",
            "control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight",
            "control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias",
            "control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight",
            "control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias",
            "control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight",
            "control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias",
            "control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight",
            "control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias",
            "control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight",
            "control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias",
            "control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight",
            "control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias",
            "control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight",
            "control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias",
            "control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight",
            "control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias",
            "control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight",
            "control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias",
            "control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight",
            "control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias",
            "control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight",
            "control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias",
            "control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight",
            "control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias",
            "control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight",
            "control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias",
            "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight",
            "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight",
            "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight",
            "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight",
            "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias",
            "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight",
            "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias",
            "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight",
            "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias",
            "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight",
            "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight",
            "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight",
            "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight",
            "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias",
            "control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight",
            "control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias",
            "control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight",
            "control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias",
            "control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight",
            "control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias",
            "control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight",
            "control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias",
            "control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight",
            "control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias",
            "control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight",
            "control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias",
            "control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight",
            "control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias",
            "control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight",
            "control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias",
            "control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight",
            "control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias",
            "control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight",
            "control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias",
            "control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight",
            "control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias",
            "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight",
            "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight",
            "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight",
            "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight",
            "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias",
            "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight",
            "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias",
            "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight",
            "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias",
            "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight",
            "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight",
            "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight",
            "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight",
            "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias",
            "control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight",
            "control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias",
            "control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight",
            "control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias",
            "control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight",
            "control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias",
            "control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight",
            "control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias",
            "control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight",
            "control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias",
            "control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight",
            "control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias",
            "control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight",
            "control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias",
            "control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight",
            "control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias",
            "control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight",
            "control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias",
            "control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight",
            "control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias",
            "control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight",
            "control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias",
            "control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight",
            "control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias",
            "control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight",
            "control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias",
            "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight",
            "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight",
            "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight",
            "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight",
            "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias",
            "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight",
            "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias",
            "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight",
            "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias",
            "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight",
            "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight",
            "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight",
            "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight",
            "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias",
            "control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight",
            "control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias",
            "control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight",
            "control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias",
            "control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight",
            "control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias",
            "control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight",
            "control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias",
            "control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight",
            "control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias",
            "control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight",
            "control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias",
            "control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight",
            "control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias",
            "control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight",
            "control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias",
            "control_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight",
            "control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias",
            "control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight",
            "control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias",
            "control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight",
            "control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias",
            "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight",
            "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight",
            "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight",
            "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight",
            "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias",
            "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight",
            "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias",
            "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight",
            "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias",
            "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight",
            "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight",
            "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight",
            "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight",
            "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias",
            "control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight",
            "control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias",
            "control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight",
            "control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias",
            "control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight",
            "control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias",
            "control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight",
            "control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias",
            "control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight",
            "control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias",
            "control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight",
            "control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias",
            "control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight",
            "control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias",
            "control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
            "control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
            "control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight",
            "control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias",
            "control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight",
            "control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias",
            "control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight",
            "control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias",
            "control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight",
            "control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias",
            "control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight",
            "control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias",
            "control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight",
            "control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias",
            "control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight",
            "control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias",
            # zero convs (bias-free 1x1 convs share the bias slot of the nearest bias-carrying conv)
            "control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight",
            "control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias",
            "control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight",
            "control_model.zero_convs.1.0.bias": "controlnet_blocks.0.bias",
            "control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight",
            "control_model.zero_convs.2.0.bias": "controlnet_blocks.0.bias",
            "control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight",
            "control_model.zero_convs.3.0.bias": "controlnet_blocks.0.bias",
            "control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight",
            "control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias",
            "control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight",
            "control_model.zero_convs.5.0.bias": "controlnet_blocks.4.bias",
            "control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight",
            "control_model.zero_convs.6.0.bias": "controlnet_blocks.4.bias",
            "control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight",
            "control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias",
            "control_model.zero_convs.8.0.weight": "controlnet_blocks.8.weight",
            "control_model.zero_convs.8.0.bias": "controlnet_blocks.7.bias",
            "control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight",
            "control_model.zero_convs.9.0.bias": "controlnet_blocks.7.bias",
            "control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight",
            "control_model.zero_convs.10.0.bias": "controlnet_blocks.7.bias",
            "control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight",
            "control_model.zero_convs.11.0.bias": "controlnet_blocks.7.bias",
            # conditioning embedder (input hint block)
            "control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight",
            "control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias",
            "control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight",
            "control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias",
            "control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight",
            "control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias",
            "control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight",
            "control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias",
            "control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight",
            "control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias",
            "control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight",
            "control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias",
            "control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight",
            "control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias",
            "control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight",
            "control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias",
            # middle block
            "control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight",
            "control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias",
            "control_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight",
            "control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias",
            "control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight",
            "control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias",
            "control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight",
            "control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias",
            "control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight",
            "control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias",
            "control_model.middle_block.1.norm.weight": "blocks.29.norm.weight",
            "control_model.middle_block.1.norm.bias": "blocks.29.norm.bias",
            "control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
            "control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
            "control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight",
            "control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight",
            "control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight",
            "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight",
            "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias",
            "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight",
            "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias",
            "control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight",
            "control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias",
            "control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight",
            "control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight",
            "control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight",
            "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight",
            "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias",
            "control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight",
            "control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias",
            "control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight",
            "control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias",
            "control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight",
            "control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias",
            "control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
            "control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias",
            "control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight",
            "control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias",
            "control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight",
            "control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias",
            "control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
            "control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
            "control_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight",
            "control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias",
            "control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight",
            "control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias",
            "control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight",
            "control_model.middle_block_out.0.bias": "controlnet_blocks.7.bias",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if ".proj_in." in name or ".proj_out." in name:
                    param = param.squeeze()
                state_dict_[rename_dict[name]] = param
        return state_dict_
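Editor's note: a loading sketch for the ControlNet above (not part of the diff); the checkpoint file name is hypothetical.

    import torch

    state_dict = torch.load("control_sd15_canny.pth", map_location="cpu")
    net = SDControlNet()
    # from_civitai also routes diffusers-format checkpoints automatically.
    net.load_state_dict(net.state_dict_converter().from_civitai(state_dict))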
@@ -1,57 +0,0 @@
from .svd_image_encoder import SVDImageEncoder
from .sdxl_ipadapter import IpAdapterImageProjModel, IpAdapterModule, SDXLIpAdapterStateDictConverter
from transformers import CLIPImageProcessor
import torch


class IpAdapterCLIPImageEmbedder(SVDImageEncoder):
    def __init__(self):
        super().__init__()
        self.image_processor = CLIPImageProcessor()

    def forward(self, image):
        pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
        return super().forward(pixel_values)


class SDIpAdapter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 16 (cross-attention dim, hidden dim) pairs, one per patched attention block.
        shape_list = [(768, 320)] * 2 + [(768, 640)] * 2 + [(768, 1280)] * 5 + [(768, 640)] * 3 + [(768, 320)] * 3 + [(768, 1280)] * 1
        self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
        self.image_proj = IpAdapterImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
        self.set_full_adapter()

    def set_full_adapter(self):
        block_ids = [1, 4, 9, 12, 17, 20, 40, 43, 46, 50, 53, 56, 60, 63, 66, 29]
        self.call_block_id = {(i, 0): j for j, i in enumerate(block_ids)}

    def set_less_adapter(self):
        # IP-Adapter for SD v1.5 doesn't support this feature.
        self.set_full_adapter()

    def forward(self, hidden_states, scale=1.0):
        hidden_states = self.image_proj(hidden_states)
        hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
        ip_kv_dict = {}
        for (block_id, transformer_id) in self.call_block_id:
            ipadapter_id = self.call_block_id[(block_id, transformer_id)]
            ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
            if block_id not in ip_kv_dict:
                ip_kv_dict[block_id] = {}
            ip_kv_dict[block_id][transformer_id] = {
                "ip_k": ip_k,
                "ip_v": ip_v,
                "scale": scale
            }
        return ip_kv_dict

    @staticmethod
    def state_dict_converter():
        return SDIpAdapterStateDictConverter()


class SDIpAdapterStateDictConverter(SDXLIpAdapterStateDictConverter):
    def __init__(self):
        pass
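Editor's note: a minimal sketch of how SDIpAdapter produces per-block key/value pairs (not part of the diff); the 1024-dim pooled CLIP embedding shape is an assumption based on clip_embeddings_dim above.

    import torch

    adapter = SDIpAdapter()
    image_emb = torch.zeros(1, 1024)       # pooled CLIP image embedding (assumed shape)
    ip_kv = adapter(image_emb, scale=0.8)  # {block_id: {transformer_id: {"ip_k", "ip_v", "scale"}}}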
Some files were not shown because too many files have changed in this diff